parry2d/utils/
interval.rs

// Reference: "Complete Interval Arithmetic and its Implementation on the Computer",
// Ulrich W. Kulisch.
3use alloc::{vec, vec::Vec};
4use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
5use num::{One, Zero};
6use num_traits::AsPrimitive;
7use simba::scalar::RealField;
8
/// A differentiable function of one variable whose values and gradient can be bounded on
/// intervals.
///
/// Implementors provide a point evaluation together with interval enclosures of the
/// function and of its gradient.
pub trait IntervalFunction<T> {
    /// Evaluates the function at the single point `t`.
    fn eval(&self, t: T) -> T;
    /// Returns an interval bounding all the values taken by this function on the interval `t`.
    fn eval_interval(&self, t: Interval<T>) -> Interval<T>;
    /// Returns an interval bounding all the values taken by the gradient of this function on
    /// the interval `t`.
    fn eval_interval_gradient(&self, t: Interval<T>) -> Interval<T>;
}
18
19/// Execute the Interval Newton Method to isolate all the roots of the given nonlinear function.
20///
21/// The results are stored in `results`. The `candidate` buffer is just a workspace buffer used
22/// to avoid allocations.
23pub fn find_root_intervals_to<T: Copy + PartialOrd + RealField>(
24    function: &impl IntervalFunction<T>,
25    init: Interval<T>,
26    min_interval_width: T,
27    min_image_width: T,
28    max_recursions: usize,
29    results: &mut Vec<Interval<T>>,
30    candidates: &mut Vec<(Interval<T>, usize)>,
31) where
32    f64: AsPrimitive<T>,
33{
34    candidates.clear();
35
36    let push_candidate = |candidate,
37                          recursion,
38                          results: &mut Vec<Interval<T>>,
39                          candidates: &mut Vec<(Interval<T>, usize)>| {
40        let candidate_image = function.eval_interval(candidate);
41        let is_small_range =
42            candidate.width() < min_interval_width || candidate_image.width() < min_image_width;
43
44        if candidate_image.contains(T::zero()) {
45            if recursion == max_recursions || is_small_range {
46                results.push(candidate);
47            } else {
48                candidates.push((candidate, recursion + 1));
49            }
50        } else if is_small_range
51            && function.eval(candidate.midpoint()).abs() < T::default_epsilon().sqrt()
52        {
53            // If we have a small range, and we are close to zero,
54            // consider that we reached zero.
55            results.push(candidate);
56        }
57    };
58
59    push_candidate(init, 0, results, candidates);
60
61    while let Some((candidate, recursion)) = candidates.pop() {
62        // println!(
63        //     "Candidate: {:?}, recursion: {}, image: {:?}",
64        //     candidate,
65        //     recursion,
66        //     function.eval_interval(candidate)
67        // );
68
69        // NOTE: we don't check the max_recursions at the beginning of the
70        //       loop here because that would make us loose the candidate
71        //       we just popped.
72        let mid = candidate.midpoint();
73        let f_mid = function.eval(mid);
74        let gradient = function.eval_interval_gradient(candidate);
75        let (shift1, shift2) = Interval(f_mid, f_mid) / gradient;
76
77        let new_candidates = [
78            (Interval(mid, mid) - shift1).intersect(candidate),
79            shift2.and_then(|shift2| (Interval(mid, mid) - shift2).intersect(candidate)),
80        ];
81
82        let prev_width = candidate.width();
83
84        for new_candidate in new_candidates.iter().flatten() {
85            if new_candidate.width() > prev_width * 0.75.as_() {
86                // If the new candidate range is still quite big compared to
87                // new candidate, split it to accelerate the search.
88                let [a, b] = new_candidate.split();
89                push_candidate(a, recursion, results, candidates);
90                push_candidate(b, recursion, results, candidates);
91            } else {
92                push_candidate(*new_candidate, recursion, results, candidates);
93            }
94        }
95    }
96}
97
98/// Execute the Interval Newton Method to isolate all the roots of the given nonlinear function.
99pub fn find_root_intervals<T: Copy + PartialOrd + RealField>(
100    function: &impl IntervalFunction<T>,
101    init: Interval<T>,
102    min_interval_width: T,
103    min_image_width: T,
104    max_recursions: usize,
105) -> Vec<Interval<T>>
106where
107    f64: AsPrimitive<T>,
108{
109    let mut results = vec![];
110    let mut candidates = vec![];
111    find_root_intervals_to(
112        function,
113        init,
114        min_interval_width,
115        min_image_width,
116        max_recursions,
117        &mut results,
118        &mut candidates,
119    );
120    results
121}
122
/// An interval implementing interval arithmetic.
///
/// The first field is the lower endpoint and the second field is the upper endpoint:
/// `width` is computed as `self.1 - self.0` and `contains(t)` checks `self.0 <= t <= self.1`.
/// Both fields are public and the tuple constructor does not reorder them; use
/// `Interval::sort` to build an interval from two unordered values.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Interval<T>(pub T, pub T);
126
127impl<T> Interval<T> {
128    /// Create the interval `[min(a, b), max(a, b)]`.
129    #[must_use]
130    pub fn sort(a: T, b: T) -> Self
131    where
132        T: PartialOrd,
133    {
134        if a < b {
135            Self(a, b)
136        } else {
137            Self(b, a)
138        }
139    }
140
141    /// Create the interval `[e, e]` (single value).
142    #[must_use]
143    pub fn splat(e: T) -> Self
144    where
145        T: Clone,
146    {
147        Self(e.clone(), e)
148    }
149
150    /// Does this interval contain the given value?
151    #[must_use]
152    pub fn contains(&self, t: T) -> bool
153    where
154        T: PartialOrd<T>,
155    {
156        self.0 <= t && self.1 >= t
157    }
158
159    /// The width of this interval.
160    #[must_use]
161    pub fn width(self) -> T::Output
162    where
163        T: Sub<T>,
164    {
165        self.1 - self.0
166    }
167
168    /// The average of the two interval endpoints.
169    #[must_use]
170    pub fn midpoint(self) -> T
171    where
172        T: Copy + Add<T, Output = T> + Div<T, Output = T> + One,
173    {
174        let two = T::one() + T::one();
175        (self.0 + self.1) / two
176    }
177
178    /// Splits this interval at its midpoint.
179    #[must_use]
180    pub fn split(self) -> [Self; 2]
181    where
182        T: Copy + Add<T, Output = T> + Div<T, Output = T> + One,
183    {
184        let mid = self.midpoint();
185        [Interval(self.0, mid), Interval(mid, self.1)]
186    }
187
188    /// Computes a new interval that contains both `self` and `t`.
189    #[must_use]
190    pub fn enclose(self, t: T) -> Self
191    where
192        T: PartialOrd,
193    {
194        if t < self.0 {
195            Interval(t, self.1)
196        } else if t > self.1 {
197            Interval(self.0, t)
198        } else {
199            self
200        }
201    }
202
203    /// Computes the intersection between two intervals.
204    ///
205    /// Returns `None` if the intervals are disjoint.
206    #[must_use]
207    pub fn intersect(self, rhs: Self) -> Option<Self>
208    where
209        T: PartialOrd + RealField,
210    {
211        let result = Interval(self.0.max(rhs.0), self.1.min(rhs.1));
212
213        if result.0 > result.1 {
214            // The range is invalid if there is no intersection.
215            None
216        } else {
217            Some(result)
218        }
219    }
220
221    /// Bounds the image of the`sin` and `cos` functions on this interval.
222    #[must_use]
223    pub fn sin_cos(self) -> (Self, Self)
224    where
225        T: Copy + RealField,
226        f64: AsPrimitive<T>,
227    {
228        (self.sin(), self.cos())
229    }
230
231    /// Bounds the image of the sinus function on this interval.
232    #[must_use]
233    pub fn sin(self) -> Self
234    where
235        T: Copy + RealField,
236        f64: AsPrimitive<T>,
237    {
238        let two_pi = T::two_pi();
239        let one = T::one();
240        let pi = T::pi();
241        let frac_pi_2 = T::frac_pi_2();
242
243        if self.width() >= two_pi {
244            Interval(one * (-1.0).as_(), one)
245        } else {
246            let sin0 = self.0.sin();
247            let sin1 = self.1.sin();
248            let mut result = Interval::sort(sin0, sin1);
249
250            let orig = (self.0 / two_pi).floor() * two_pi;
251            let crit = [orig + frac_pi_2, orig + pi + frac_pi_2];
252            let crit_vals = [one, one * (-1.0).as_()];
253
254            for i in 0..2 {
255                if self.contains(crit[i]) || self.contains(crit[i] + two_pi) {
256                    result = result.enclose(crit_vals[i])
257                }
258            }
259
260            result
261        }
262    }
263
264    /// Bounds the image of the cosinus function on this interval.
265    #[must_use]
266    pub fn cos(self) -> Self
267    where
268        T: Copy + RealField,
269        f64: AsPrimitive<T>,
270    {
271        let two_pi = T::two_pi();
272        let one = T::one();
273        let pi = T::pi();
274
275        if self.width() >= two_pi {
276            Interval(one * (-1.0).as_(), one)
277        } else {
278            let cos0 = self.0.cos();
279            let cos1 = self.1.cos();
280            let mut result = Interval::sort(cos0, cos1);
281
282            let orig = (self.0 / two_pi).floor() * two_pi;
283            let crit = [orig, orig + pi];
284            let crit_vals = [one, one * (-1.0).as_()];
285
286            for i in 0..2 {
287                if self.contains(crit[i]) || self.contains(crit[i] + two_pi) {
288                    result = result.enclose(crit_vals[i])
289                }
290            }
291
292            result
293        }
294    }
295}
296
297impl<T: Add<T> + Copy> Add<T> for Interval<T> {
298    type Output = Interval<<T as Add<T>>::Output>;
299
300    fn add(self, rhs: T) -> Self::Output {
301        Interval(self.0 + rhs, self.1 + rhs)
302    }
303}
304
305impl<T: Add<T>> Add<Interval<T>> for Interval<T> {
306    type Output = Interval<<T as Add<T>>::Output>;
307
308    fn add(self, rhs: Self) -> Self::Output {
309        Interval(self.0 + rhs.0, self.1 + rhs.1)
310    }
311}
312
313impl<T: Sub<T> + Copy> Sub<T> for Interval<T> {
314    type Output = Interval<<T as Sub<T>>::Output>;
315
316    fn sub(self, rhs: T) -> Self::Output {
317        Interval(self.0 - rhs, self.1 - rhs)
318    }
319}
320
321impl<T: Sub<T> + Copy> Sub<Interval<T>> for Interval<T> {
322    type Output = Interval<<T as Sub<T>>::Output>;
323
324    fn sub(self, rhs: Self) -> Self::Output {
325        Interval(self.0 - rhs.1, self.1 - rhs.0)
326    }
327}
328
329impl<T: Neg> Neg for Interval<T> {
330    type Output = Interval<T::Output>;
331
332    fn neg(self) -> Self::Output {
333        Interval(-self.1, -self.0)
334    }
335}
336
337impl<T: Mul<T>> Mul<T> for Interval<T>
338where
339    T: Copy + PartialOrd + Zero,
340{
341    type Output = Interval<<T as Mul<T>>::Output>;
342
343    fn mul(self, rhs: T) -> Self::Output {
344        if rhs < T::zero() {
345            Interval(self.1 * rhs, self.0 * rhs)
346        } else {
347            Interval(self.0 * rhs, self.1 * rhs)
348        }
349    }
350}
351
352impl<T: Mul<T>> Mul<Interval<T>> for Interval<T>
353where
354    T: Copy + PartialOrd + Zero + RealField,
355    <T as Mul<T>>::Output: PartialOrd + RealField,
356{
357    type Output = Interval<<T as Mul<T>>::Output>;
358
359    fn mul(self, rhs: Self) -> Self::Output {
360        let Interval(a1, a2) = self;
361        let Interval(b1, b2) = rhs;
362
363        if a2 <= T::zero() {
364            if b2 <= T::zero() {
365                Interval(a2 * b2, a1 * b1)
366            } else if b1 < T::zero() {
367                Interval(a1 * b2, a1 * b1)
368            } else {
369                Interval(a1 * b2, a2 * b1)
370            }
371        } else if a1 < T::zero() {
372            if b2 <= T::zero() {
373                Interval(a2 * b1, a1 * b1)
374            } else if b1 < T::zero() {
375                Interval((a1 * b2).min(b2 * b1), (a1 * b1).max(a2 * b2))
376            } else {
377                Interval(a1 * b2, a2 * b2)
378            }
379        } else if b2 <= T::zero() {
380            Interval(a2 * b1, a1 * b2)
381        } else if b1 < T::zero() {
382            Interval(a2 * b1, a2 * b2)
383        } else {
384            Interval(a1 * b1, a2 * b2)
385        }
386    }
387}
388
389impl<T: Div<T>> Div<Interval<T>> for Interval<T>
390where
391    T: Copy + One + Zero + PartialOrd + 'static,
392    <T as Div<T>>::Output: PartialOrd + RealField,
393    f64: AsPrimitive<T>,
394{
395    type Output = (
396        Interval<<T as Div<T>>::Output>,
397        Option<Interval<<T as Div<T>>::Output>>,
398    );
399
400    fn div(self, rhs: Self) -> Self::Output {
401        let infinity = T::one() / T::zero();
402        let neg_one: T = (-1.0).as_();
403        let neg_infinity = neg_one / T::zero();
404
405        let Interval(a1, a2) = self;
406        let Interval(b1, b2) = rhs;
407
408        if b1 <= T::zero() && b2 >= T::zero() {
409            // rhs contains T::zero() so we my have to return
410            // two intervals.
411            if a2 < T::zero() {
412                if b2 == T::zero() {
413                    (Interval(a2 / b1, infinity), None)
414                } else if b1 != T::zero() {
415                    (
416                        Interval(neg_infinity, a2 / b2),
417                        Some(Interval(a2 / b1, infinity)),
418                    )
419                } else {
420                    (Interval(neg_infinity, a2 / b2), None)
421                }
422            } else if a1 <= T::zero() {
423                (Interval(neg_infinity, infinity), None)
424            } else if b2 == T::zero() {
425                (Interval(neg_infinity, a1 / b1), None)
426            } else if b1 != T::zero() {
427                (
428                    Interval(neg_infinity, a1 / b1),
429                    Some(Interval(a1 / b2, infinity)),
430                )
431            } else {
432                (Interval(a1 / b2, infinity), None)
433            }
434        } else if a2 <= T::zero() {
435            if b2 < T::zero() {
436                (Interval(a2 / b1, a1 / b2), None)
437            } else {
438                (Interval(a1 / b1, a2 / b2), None)
439            }
440        } else if a1 < T::zero() {
441            if b2 < T::zero() {
442                (Interval(a2 / b2, a1 / b2), None)
443            } else {
444                (Interval(a1 / b1, a2 / b1), None)
445            }
446        } else if b2 < T::zero() {
447            (Interval(a2 / b2, a1 / b1), None)
448        } else {
449            (Interval(a1 / b2, a2 / b1), None)
450        }
451    }
452}
453
454impl<T: Copy + Add<T, Output = T>> AddAssign<Interval<T>> for Interval<T> {
455    fn add_assign(&mut self, rhs: Interval<T>) {
456        *self = *self + rhs;
457    }
458}
459
460impl<T: Copy + Sub<T, Output = T>> SubAssign<Interval<T>> for Interval<T> {
461    fn sub_assign(&mut self, rhs: Interval<T>) {
462        *self = *self - rhs;
463    }
464}
465
466impl<T: Mul<T, Output = T>> MulAssign<Interval<T>> for Interval<T>
467where
468    T: Copy + PartialOrd + Zero + RealField,
469    <T as Mul<T>>::Output: PartialOrd + RealField,
470{
471    fn mul_assign(&mut self, rhs: Interval<T>) {
472        *self = *self * rhs;
473    }
474}
475
476impl<T: Zero + Add<T>> Zero for Interval<T> {
477    fn zero() -> Self {
478        Self(T::zero(), T::zero())
479    }
480
481    fn is_zero(&self) -> bool {
482        self.0.is_zero() && self.1.is_zero()
483    }
484}
485
486impl<T: One + Mul<T>> One for Interval<T>
487where
488    Interval<T>: Mul<Interval<T>, Output = Interval<T>>,
489{
490    fn one() -> Self {
491        Self(T::one(), T::one())
492    }
493}
494
#[cfg(test)]
mod test {
    use super::{Interval, IntervalFunction};

    /// The sine has three roots on `[0, 2pi]`; each of them must be isolated.
    #[test]
    fn roots_sin() {
        struct Sin;

        impl IntervalFunction<f32> for Sin {
            fn eval(&self, t: f32) -> f32 {
                t.sin()
            }

            fn eval_interval(&self, t: Interval<f32>) -> Interval<f32> {
                t.sin()
            }

            fn eval_interval_gradient(&self, t: Interval<f32>) -> Interval<f32> {
                t.cos()
            }
        }

        let domain = Interval(0.0, core::f32::consts::TAU);
        let roots = super::find_root_intervals(&Sin, domain, 1.0e-5, 1.0e-5, 100);
        assert_eq!(roots.len(), 3);
    }

    /// Checks the sine and cosine enclosures on intervals starting at `pi/6` and ending in
    /// each of the three following quadrants, with and without a shift by whole periods.
    #[test]
    fn interval_sin_cos() {
        use core::f32::consts::{FRAC_PI_2, PI, TAU};

        let a = PI / 6.0;
        let b = FRAC_PI_2 + PI / 6.0;
        let c = PI + PI / 6.0;
        let d = PI + FRAC_PI_2 + PI / 6.0;
        let shifts = [0.0, TAU * 100.0, -TAU * 100.0];

        for &shift in &shifts {
            let start = a + shift;
            let end_b = b + shift;
            let end_c = c + shift;
            let end_d = d + shift;

            // Test sine.
            assert_eq!(
                Interval(start, end_b).sin(),
                Interval(start.sin(), 1.0)
            );
            assert_eq!(
                Interval(start, end_c).sin(),
                Interval(end_c.sin(), 1.0)
            );
            assert_eq!(Interval(start, end_d).sin(), Interval(-1.0, 1.0));

            // Test cosine.
            assert_eq!(
                Interval(start, end_b).cos(),
                Interval(end_b.cos(), start.cos())
            );
            assert_eq!(
                Interval(start, end_c).cos(),
                Interval(-1.0, start.cos())
            );
            assert_eq!(
                Interval(start, end_d).cos(),
                Interval(-1.0, start.cos())
            );
        }
    }
}