use na::{RealField, SimdPartialOrd};
use num::{One, Zero};
use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
/// A function of one scalar variable that can be evaluated at a point and over an
/// [`Interval`].
///
/// [`find_root_intervals_to`] calls all three methods while searching for roots.
pub trait IntervalFunction<T> {
/// Evaluates the function at the single point `t`.
fn eval(&self, t: T) -> T;
/// Evaluates the function over the interval `t`.
///
/// The root finder only keeps refining an interval while this result contains zero,
/// so the returned interval is presumably expected to enclose every value the
/// function takes on `t` (not enforced here).
fn eval_interval(&self, t: Interval<T>) -> Interval<T>;
/// Evaluates the function's gradient over the interval `t`.
///
/// The root finder divides the function value at the midpoint of `t` by this
/// interval to narrow `t`.
fn eval_interval_gradient(&self, t: Interval<T>) -> Interval<T>;
}
/// Searches `init` for roots of `function` and appends enclosing intervals to `results`.
///
/// Every pending interval is narrowed with an interval Newton step: the function value
/// at the midpoint is divided by the gradient over the whole interval, the quotient is
/// subtracted from the midpoint, and the outcome is clipped to the interval. A piece
/// that is still wider than 75% of the interval it came from is bisected.
///
/// An interval whose image contains zero is reported once it is narrower than
/// `min_interval_width`, its image is narrower than `min_image_width`, or it has been
/// refined `max_recursions` times. A small interval whose image excludes zero is still
/// reported when the absolute function value at its midpoint is below
/// `sqrt(T::default_epsilon())`.
///
/// `candidates` is scratch storage for the work stack and is cleared on entry;
/// `results` is only appended to.
pub fn find_root_intervals_to<T: RealField + Copy>(
    function: &impl IntervalFunction<T>,
    init: Interval<T>,
    min_interval_width: T,
    min_image_width: T,
    max_recursions: usize,
    results: &mut Vec<Interval<T>>,
    candidates: &mut Vec<(Interval<T>, usize)>,
) {
    candidates.clear();

    // Files `interval` under the accepted roots, the pending work, or nowhere.
    let classify = |interval: Interval<T>,
                    depth: usize,
                    accepted: &mut Vec<Interval<T>>,
                    pending: &mut Vec<(Interval<T>, usize)>| {
        let image = function.eval_interval(interval);
        let is_tiny = interval.width() < min_interval_width || image.width() < min_image_width;

        if !image.contains(T::zero()) {
            // The image excludes zero; keep the interval only when it is small and the
            // function is numerically zero at its midpoint.
            if is_tiny && function.eval(interval.midpoint()).abs() < T::default_epsilon().sqrt() {
                accepted.push(interval);
            }
            return;
        }

        if is_tiny || depth == max_recursions {
            accepted.push(interval);
        } else {
            pending.push((interval, depth + 1));
        }
    };

    classify(init, 0, results, candidates);

    while let Some((current, depth)) = candidates.pop() {
        let center = current.midpoint();
        let center_point = Interval(center, center);
        let value_at_center = function.eval(center);
        let slope = function.eval_interval_gradient(current);

        // The division yields a second piece when the gradient straddles zero.
        let (first_step, second_step) = Interval(value_at_center, value_at_center) / slope;
        let refined = [
            (center_point - first_step).intersect(current),
            second_step.and_then(|step| (center_point - step).intersect(current)),
        ];

        let three_quarters: T = na::convert(0.75);
        let shrink_limit = current.width() * three_quarters;
        for piece in refined.iter().flatten() {
            if piece.width() > shrink_limit {
                // The Newton step barely helped: fall back to bisection.
                let [lower, upper] = piece.split();
                classify(lower, depth, results, candidates);
                classify(upper, depth, results, candidates);
            } else {
                classify(*piece, depth, results, candidates);
            }
        }
    }
}
/// Convenience wrapper around [`find_root_intervals_to`] that allocates its own buffers.
///
/// Returns the intervals reported by the search; the scratch stack of pending
/// candidates is dropped when the function returns.
pub fn find_root_intervals<T: RealField + Copy>(
    function: &impl IntervalFunction<T>,
    init: Interval<T>,
    min_interval_width: T,
    min_image_width: T,
    max_recursions: usize,
) -> Vec<Interval<T>> {
    let mut roots = Vec::new();
    let mut scratch = Vec::new();
    find_root_intervals_to(
        function,
        init,
        min_interval_width,
        min_image_width,
        max_recursions,
        &mut roots,
        &mut scratch,
    );
    roots
}
/// A closed interval stored as a pair of bounds: `.0` is the lower bound and `.1` the
/// upper bound (e.g. `width` is `self.1 - self.0` and `contains` tests
/// `self.0 <= t && self.1 >= t`).
///
/// The tuple constructor does not check that `.0 <= .1`; use `Interval::sort` when the
/// order of the two endpoints is not known.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Interval<T>(pub T, pub T);
impl<T> Interval<T> {
#[must_use]
pub fn sort(a: T, b: T) -> Self
where
T: PartialOrd,
{
if a < b {
Self(a, b)
} else {
Self(b, a)
}
}
#[must_use]
pub fn splat(e: T) -> Self
where
T: Clone,
{
Self(e.clone(), e)
}
#[must_use]
pub fn contains(&self, t: T) -> bool
where
T: PartialOrd<T>,
{
self.0 <= t && self.1 >= t
}
#[must_use]
pub fn width(self) -> T::Output
where
T: Sub<T>,
{
self.1 - self.0
}
#[must_use]
pub fn midpoint(self) -> T
where
T: RealField + Copy,
{
let two: T = na::convert(2.0);
(self.0 + self.1) / two
}
#[must_use]
pub fn split(self) -> [Self; 2]
where
T: RealField + Copy,
{
let mid = self.midpoint();
[Interval(self.0, mid), Interval(mid, self.1)]
}
#[must_use]
pub fn enclose(self, t: T) -> Self
where
T: PartialOrd,
{
if t < self.0 {
Interval(t, self.1)
} else if t > self.1 {
Interval(self.0, t)
} else {
self
}
}
#[must_use]
pub fn intersect(self, rhs: Self) -> Option<Self>
where
T: PartialOrd + SimdPartialOrd, {
let result = Interval(self.0.simd_max(rhs.0), self.1.simd_min(rhs.1));
if result.0 > result.1 {
None
} else {
Some(result)
}
}
#[must_use]
pub fn sin_cos(self) -> (Self, Self)
where
T: RealField + Copy,
{
(self.sin(), self.cos())
}
#[must_use]
pub fn sin(self) -> Self
where
T: RealField + Copy,
{
if self.width() >= T::two_pi() {
Interval(-T::one(), T::one())
} else {
let sin0 = self.0.sin();
let sin1 = self.1.sin();
let mut result = Interval::sort(sin0, sin1);
let orig = (self.0 / T::two_pi()).floor() * T::two_pi();
let crit = [orig + T::frac_pi_2(), orig + T::pi() + T::frac_pi_2()];
let crit_vals = [T::one(), -T::one()];
for i in 0..2 {
if self.contains(crit[i]) || self.contains(crit[i] + T::two_pi()) {
result = result.enclose(crit_vals[i])
}
}
result
}
}
#[must_use]
pub fn cos(self) -> Self
where
T: RealField + Copy,
{
if self.width() >= T::two_pi() {
Interval(-T::one(), T::one())
} else {
let cos0 = self.0.cos();
let cos1 = self.1.cos();
let mut result = Interval::sort(cos0, cos1);
let orig = (self.0 / T::two_pi()).floor() * T::two_pi();
let crit = [orig, orig + T::pi()];
let crit_vals = [T::one(), -T::one()];
for i in 0..2 {
if self.contains(crit[i]) || self.contains(crit[i] + T::two_pi()) {
result = result.enclose(crit_vals[i])
}
}
result
}
}
}
impl<T: Add<T> + Copy> Add<T> for Interval<T> {
type Output = Interval<<T as Add<T>>::Output>;
fn add(self, rhs: T) -> Self::Output {
Interval(self.0 + rhs, self.1 + rhs)
}
}
impl<T: Add<T>> Add<Interval<T>> for Interval<T> {
type Output = Interval<<T as Add<T>>::Output>;
fn add(self, rhs: Self) -> Self::Output {
Interval(self.0 + rhs.0, self.1 + rhs.1)
}
}
impl<T: Sub<T> + Copy> Sub<T> for Interval<T> {
type Output = Interval<<T as Sub<T>>::Output>;
fn sub(self, rhs: T) -> Self::Output {
Interval(self.0 - rhs, self.1 - rhs)
}
}
impl<T: Sub<T> + Copy> Sub<Interval<T>> for Interval<T> {
type Output = Interval<<T as Sub<T>>::Output>;
fn sub(self, rhs: Self) -> Self::Output {
Interval(self.0 - rhs.1, self.1 - rhs.0)
}
}
impl<T: Neg> Neg for Interval<T> {
type Output = Interval<T::Output>;
fn neg(self) -> Self::Output {
Interval(-self.1, -self.0)
}
}
impl<T: Mul<T>> Mul<T> for Interval<T>
where
T: Copy + PartialOrd + Zero,
{
type Output = Interval<<T as Mul<T>>::Output>;
fn mul(self, rhs: T) -> Self::Output {
if rhs < T::zero() {
Interval(self.1 * rhs, self.0 * rhs)
} else {
Interval(self.0 * rhs, self.1 * rhs)
}
}
}
/// Product of two intervals.
///
/// The result encloses every product `x * y` with `x` taken from `self` and `y` from
/// `rhs`. The nine sign combinations of the operands (non-positive, straddling zero,
/// non-negative) are handled separately so that most cases need only two
/// multiplications; only when both operands straddle zero are all four endpoint
/// products compared.
impl<T: Mul<T>> Mul<Interval<T>> for Interval<T>
where
    T: Copy + PartialOrd + Zero,
    <T as Mul<T>>::Output: SimdPartialOrd,
{
    type Output = Interval<<T as Mul<T>>::Output>;

    fn mul(self, rhs: Self) -> Self::Output {
        let Interval(a1, a2) = self;
        let Interval(b1, b2) = rhs;
        if a2 <= T::zero() {
            // `self` is non-positive.
            if b2 <= T::zero() {
                Interval(a2 * b2, a1 * b1)
            } else if b1 < T::zero() {
                Interval(a1 * b2, a1 * b1)
            } else {
                Interval(a1 * b2, a2 * b1)
            }
        } else if a1 < T::zero() {
            // `self` straddles zero.
            if b2 <= T::zero() {
                Interval(a2 * b1, a1 * b1)
            } else if b1 < T::zero() {
                // Both operands straddle zero: the most negative product is either
                // `a1 * b2` or `a2 * b1`, the most positive either `a1 * b1` or
                // `a2 * b2`.
                Interval((a1 * b2).simd_min(a2 * b1), (a1 * b1).simd_max(a2 * b2))
            } else {
                Interval(a1 * b2, a2 * b2)
            }
        } else if b2 <= T::zero() {
            // `self` is non-negative from here on.
            Interval(a2 * b1, a1 * b2)
        } else if b1 < T::zero() {
            Interval(a2 * b1, a2 * b2)
        } else {
            Interval(a1 * b1, a2 * b2)
        }
    }
}
/// Quotient of two intervals.
///
/// The result is a pair: the first interval is always present, the second only when
/// the divisor has zero strictly inside it and the dividend does not contain zero, in
/// which case the quotient falls into two unbounded pieces. Unbounded ends are
/// represented by `T::one() / T::zero()` and its negation.
impl<T: Div<T>> Div<Interval<T>> for Interval<T>
where
    T: RealField + Copy,
    <T as Div<T>>::Output: SimdPartialOrd,
{
    type Output = (
        Interval<<T as Div<T>>::Output>,
        Option<Interval<<T as Div<T>>::Output>>,
    );

    fn div(self, rhs: Self) -> Self::Output {
        let zero = T::zero();
        let inf = T::one() / zero;
        let Interval(num_lo, num_hi) = self;
        let Interval(den_lo, den_hi) = rhs;

        if den_lo <= zero && den_hi >= zero {
            // The divisor contains zero, so the quotient is unbounded.
            return if num_hi < zero {
                // Strictly negative dividend.
                if den_hi == zero {
                    (Interval(num_hi / den_lo, inf), None)
                } else if den_lo != zero {
                    (
                        Interval(-inf, num_hi / den_hi),
                        Some(Interval(num_hi / den_lo, inf)),
                    )
                } else {
                    (Interval(-inf, num_hi / den_hi), None)
                }
            } else if num_lo <= zero {
                // The dividend contains zero as well: anything is possible.
                (Interval(-inf, inf), None)
            } else if den_hi == zero {
                // Strictly positive dividend from here on.
                (Interval(-inf, num_lo / den_lo), None)
            } else if den_lo != zero {
                (
                    Interval(-inf, num_lo / den_lo),
                    Some(Interval(num_lo / den_hi, inf)),
                )
            } else {
                (Interval(num_lo / den_hi, inf), None)
            };
        }

        // The divisor has a definite sign; the quotient is a single bounded interval.
        let quotient = if num_hi <= zero {
            if den_hi < zero {
                Interval(num_hi / den_lo, num_lo / den_hi)
            } else {
                Interval(num_lo / den_lo, num_hi / den_hi)
            }
        } else if num_lo < zero {
            if den_hi < zero {
                Interval(num_hi / den_hi, num_lo / den_hi)
            } else {
                Interval(num_lo / den_lo, num_hi / den_lo)
            }
        } else if den_hi < zero {
            Interval(num_hi / den_hi, num_lo / den_lo)
        } else {
            Interval(num_lo / den_hi, num_hi / den_lo)
        };
        (quotient, None)
    }
}
/// In-place interval addition: replaces `self` with `self + rhs`.
impl<T: Copy + Add<T, Output = T>> AddAssign<Interval<T>> for Interval<T> {
    fn add_assign(&mut self, rhs: Interval<T>) {
        let sum = *self + rhs;
        *self = sum;
    }
}
/// In-place interval subtraction: replaces `self` with `self - rhs`.
impl<T: Copy + Sub<T, Output = T>> SubAssign<Interval<T>> for Interval<T> {
    fn sub_assign(&mut self, rhs: Interval<T>) {
        let difference = *self - rhs;
        *self = difference;
    }
}
/// In-place interval multiplication: replaces `self` with `self * rhs`.
impl<T: Mul<T, Output = T>> MulAssign<Interval<T>> for Interval<T>
where
    T: Copy + PartialOrd + Zero,
    <T as Mul<T>>::Output: SimdPartialOrd,
{
    fn mul_assign(&mut self, rhs: Interval<T>) {
        let product = *self * rhs;
        *self = product;
    }
}
impl<T: Zero + Add<T>> Zero for Interval<T> {
fn zero() -> Self {
Self(T::zero(), T::zero())
}
fn is_zero(&self) -> bool {
self.0.is_zero() && self.1.is_zero()
}
}
impl<T: One + Mul<T>> One for Interval<T>
where
Interval<T>: Mul<Interval<T>, Output = Interval<T>>,
{
fn one() -> Self {
Self(T::one(), T::one())
}
}
// Unit tests for the root finder and for the interval sine/cosine ranges.
#[cfg(test)]
mod test {
use super::{Interval, IntervalFunction};
use na::RealField;
// Runs the root finder on sin over [0, 2π] and checks how many intervals it reports.
#[test]
fn roots_sin() {
// Minimal `IntervalFunction`: sin as the function, cos as its gradient.
struct Sin;
impl IntervalFunction<f32> for Sin {
fn eval(&self, t: f32) -> f32 {
t.sin()
}
fn eval_interval(&self, t: Interval<f32>) -> Interval<f32> {
t.sin()
}
fn eval_interval_gradient(&self, t: Interval<f32>) -> Interval<f32> {
t.cos()
}
}
let function = Sin;
let roots = super::find_root_intervals(
&function,
Interval(0.0, f32::two_pi()),
1.0e-5,
1.0e-5,
100,
);
// Three intervals are expected — presumably one each around 0, π and 2π.
assert_eq!(roots.len(), 3);
}
// Checks `Interval::sin` / `Interval::cos` on intervals starting at π/6 and ending
// π/2, π and 3π/2 later, using exact floating-point equality.
#[test]
fn interval_sin_cos() {
let a = f32::pi() / 6.0;
let b = f32::pi() / 2.0 + f32::pi() / 6.0;
let c = f32::pi() + f32::pi() / 6.0;
let d = f32::pi() + f32::pi() / 2.0 + f32::pi() / 6.0;
// Repeat every check 100 periods to the right and to the left of the origin.
let shifts = [0.0, f32::two_pi() * 100.0, -f32::two_pi() * 100.0];
for shift in shifts.iter() {
// [a, b] contains the maximum of sin but no extremum of cos.
assert_eq!(
Interval(a + *shift, b + *shift).sin(),
Interval((a + *shift).sin(), 1.0)
);
// [a, c] contains the maximum of sin; its lowest sine value is at c.
assert_eq!(
Interval(a + *shift, c + *shift).sin(),
Interval((c + *shift).sin(), 1.0)
);
// [a, d] contains both the maximum and the minimum of sin.
assert_eq!(Interval(a + *shift, d + *shift).sin(), Interval(-1.0, 1.0));
assert_eq!(
Interval(a + *shift, b + *shift).cos(),
Interval((b + *shift).cos(), (a + *shift).cos())
);
// [a, c] and [a, d] both contain the minimum of cos.
assert_eq!(
Interval(a + *shift, c + *shift).cos(),
Interval(-1.0, (a + *shift).cos())
);
assert_eq!(
Interval(a + *shift, d + *shift).cos(),
Interval(-1.0, (a + *shift).cos())
);
}
}
}