parry3d/utils/
interval.rs

1// "Complete Interval Arithmetic and its Implementation on the Computer"
2// Ulrich W. Kulisch
3use alloc::{vec, vec::Vec};
4use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
5use na::{RealField, SimdPartialOrd};
6use num::{One, Zero};
7
/// A differentiable scalar function whose values and derivative can be bounded on intervals.
///
/// This is the function description consumed by [`find_root_intervals_to`] and
/// [`find_root_intervals`], which call all three methods while isolating roots.
pub trait IntervalFunction<T> {
    /// Evaluates the function at the single point `t`.
    fn eval(&self, t: T) -> T;
    /// Returns an interval bounding all the values taken by this function on the interval `t`.
    fn eval_interval(&self, t: Interval<T>) -> Interval<T>;
    /// Returns an interval bounding all the values taken by the gradient (derivative) of this
    /// function on the interval `t`.
    fn eval_interval_gradient(&self, t: Interval<T>) -> Interval<T>;
}
17
18/// Execute the Interval Newton Method to isolate all the roots of the given nonlinear function.
19///
20/// The results are stored in `results`. The `candidate` buffer is just a workspace buffer used
21/// to avoid allocations.
22pub fn find_root_intervals_to<T: RealField + Copy>(
23    function: &impl IntervalFunction<T>,
24    init: Interval<T>,
25    min_interval_width: T,
26    min_image_width: T,
27    max_recursions: usize,
28    results: &mut Vec<Interval<T>>,
29    candidates: &mut Vec<(Interval<T>, usize)>,
30) {
31    candidates.clear();
32
33    let push_candidate = |candidate,
34                          recursion,
35                          results: &mut Vec<Interval<T>>,
36                          candidates: &mut Vec<(Interval<T>, usize)>| {
37        let candidate_image = function.eval_interval(candidate);
38        let is_small_range =
39            candidate.width() < min_interval_width || candidate_image.width() < min_image_width;
40
41        if candidate_image.contains(T::zero()) {
42            if recursion == max_recursions || is_small_range {
43                results.push(candidate);
44            } else {
45                candidates.push((candidate, recursion + 1));
46            }
47        } else if is_small_range
48            && function.eval(candidate.midpoint()).abs() < T::default_epsilon().sqrt()
49        {
50            // If we have a small range, and we are close to zero,
51            // consider that we reached zero.
52            results.push(candidate);
53        }
54    };
55
56    push_candidate(init, 0, results, candidates);
57
58    while let Some((candidate, recursion)) = candidates.pop() {
59        // println!(
60        //     "Candidate: {:?}, recursion: {}, image: {:?}",
61        //     candidate,
62        //     recursion,
63        //     function.eval_interval(candidate)
64        // );
65
66        // NOTE: we don't check the max_recursions at the beginning of the
67        //       loop here because that would make us loose the candidate
68        //       we just popped.
69        let mid = candidate.midpoint();
70        let f_mid = function.eval(mid);
71        let gradient = function.eval_interval_gradient(candidate);
72        let (shift1, shift2) = Interval(f_mid, f_mid) / gradient;
73
74        let new_candidates = [
75            (Interval(mid, mid) - shift1).intersect(candidate),
76            shift2.and_then(|shift2| (Interval(mid, mid) - shift2).intersect(candidate)),
77        ];
78
79        let prev_width = candidate.width();
80
81        for new_candidate in new_candidates.iter().flatten() {
82            if new_candidate.width() > prev_width * na::convert(0.75) {
83                // If the new candidate range is still quite big compared to
84                // new candidate, split it to accelerate the search.
85                let [a, b] = new_candidate.split();
86                push_candidate(a, recursion, results, candidates);
87                push_candidate(b, recursion, results, candidates);
88            } else {
89                push_candidate(*new_candidate, recursion, results, candidates);
90            }
91        }
92    }
93}
94
95/// Execute the Interval Newton Method to isolate all the roots of the given nonlinear function.
96pub fn find_root_intervals<T: RealField + Copy>(
97    function: &impl IntervalFunction<T>,
98    init: Interval<T>,
99    min_interval_width: T,
100    min_image_width: T,
101    max_recursions: usize,
102) -> Vec<Interval<T>> {
103    let mut results = vec![];
104    let mut candidates = vec![];
105    find_root_intervals_to(
106        function,
107        init,
108        min_interval_width,
109        min_image_width,
110        max_recursions,
111        &mut results,
112        &mut candidates,
113    );
114    results
115}
116
117/// An interval implementing interval arithmetic.
118#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
119pub struct Interval<T>(pub T, pub T);
120
121impl<T> Interval<T> {
122    /// Create the interval `[min(a, b), max(a, b)]`.
123    #[must_use]
124    pub fn sort(a: T, b: T) -> Self
125    where
126        T: PartialOrd,
127    {
128        if a < b {
129            Self(a, b)
130        } else {
131            Self(b, a)
132        }
133    }
134
135    /// Create the interval `[e, e]` (single value).
136    #[must_use]
137    pub fn splat(e: T) -> Self
138    where
139        T: Clone,
140    {
141        Self(e.clone(), e)
142    }
143
144    /// Does this interval contain the given value?
145    #[must_use]
146    pub fn contains(&self, t: T) -> bool
147    where
148        T: PartialOrd<T>,
149    {
150        self.0 <= t && self.1 >= t
151    }
152
153    /// The width of this interval.
154    #[must_use]
155    pub fn width(self) -> T::Output
156    where
157        T: Sub<T>,
158    {
159        self.1 - self.0
160    }
161
162    /// The average of the two interval endpoints.
163    #[must_use]
164    pub fn midpoint(self) -> T
165    where
166        T: RealField + Copy,
167    {
168        let two: T = na::convert(2.0);
169        (self.0 + self.1) / two
170    }
171
172    /// Splits this interval at its midpoint.
173    #[must_use]
174    pub fn split(self) -> [Self; 2]
175    where
176        T: RealField + Copy,
177    {
178        let mid = self.midpoint();
179        [Interval(self.0, mid), Interval(mid, self.1)]
180    }
181
182    /// Computes a new interval that contains both `self` and `t`.
183    #[must_use]
184    pub fn enclose(self, t: T) -> Self
185    where
186        T: PartialOrd,
187    {
188        if t < self.0 {
189            Interval(t, self.1)
190        } else if t > self.1 {
191            Interval(self.0, t)
192        } else {
193            self
194        }
195    }
196
197    /// Computes the intersection between two intervals.
198    ///
199    /// Returns `None` if the intervals are disjoint.
200    #[must_use]
201    pub fn intersect(self, rhs: Self) -> Option<Self>
202    where
203        T: PartialOrd + SimdPartialOrd, // TODO: it is weird to have both.
204    {
205        let result = Interval(self.0.simd_max(rhs.0), self.1.simd_min(rhs.1));
206
207        if result.0 > result.1 {
208            // The range is invalid if there is no intersection.
209            None
210        } else {
211            Some(result)
212        }
213    }
214
215    /// Bounds the image of the`sin` and `cos` functions on this interval.
216    #[must_use]
217    pub fn sin_cos(self) -> (Self, Self)
218    where
219        T: RealField + Copy,
220    {
221        (self.sin(), self.cos())
222    }
223
224    /// Bounds the image of the sinus function on this interval.
225    #[must_use]
226    pub fn sin(self) -> Self
227    where
228        T: RealField + Copy,
229    {
230        if self.width() >= T::two_pi() {
231            Interval(-T::one(), T::one())
232        } else {
233            let sin0 = self.0.sin();
234            let sin1 = self.1.sin();
235            let mut result = Interval::sort(sin0, sin1);
236
237            let orig = (self.0 / T::two_pi()).floor() * T::two_pi();
238            let crit = [orig + T::frac_pi_2(), orig + T::pi() + T::frac_pi_2()];
239            let crit_vals = [T::one(), -T::one()];
240
241            for i in 0..2 {
242                if self.contains(crit[i]) || self.contains(crit[i] + T::two_pi()) {
243                    result = result.enclose(crit_vals[i])
244                }
245            }
246
247            result
248        }
249    }
250
251    /// Bounds the image of the cosinus function on this interval.
252    #[must_use]
253    pub fn cos(self) -> Self
254    where
255        T: RealField + Copy,
256    {
257        if self.width() >= T::two_pi() {
258            Interval(-T::one(), T::one())
259        } else {
260            let cos0 = self.0.cos();
261            let cos1 = self.1.cos();
262            let mut result = Interval::sort(cos0, cos1);
263
264            let orig = (self.0 / T::two_pi()).floor() * T::two_pi();
265            let crit = [orig, orig + T::pi()];
266            let crit_vals = [T::one(), -T::one()];
267
268            for i in 0..2 {
269                if self.contains(crit[i]) || self.contains(crit[i] + T::two_pi()) {
270                    result = result.enclose(crit_vals[i])
271                }
272            }
273
274            result
275        }
276    }
277}
278
279impl<T: Add<T> + Copy> Add<T> for Interval<T> {
280    type Output = Interval<<T as Add<T>>::Output>;
281
282    fn add(self, rhs: T) -> Self::Output {
283        Interval(self.0 + rhs, self.1 + rhs)
284    }
285}
286
287impl<T: Add<T>> Add<Interval<T>> for Interval<T> {
288    type Output = Interval<<T as Add<T>>::Output>;
289
290    fn add(self, rhs: Self) -> Self::Output {
291        Interval(self.0 + rhs.0, self.1 + rhs.1)
292    }
293}
294
295impl<T: Sub<T> + Copy> Sub<T> for Interval<T> {
296    type Output = Interval<<T as Sub<T>>::Output>;
297
298    fn sub(self, rhs: T) -> Self::Output {
299        Interval(self.0 - rhs, self.1 - rhs)
300    }
301}
302
303impl<T: Sub<T> + Copy> Sub<Interval<T>> for Interval<T> {
304    type Output = Interval<<T as Sub<T>>::Output>;
305
306    fn sub(self, rhs: Self) -> Self::Output {
307        Interval(self.0 - rhs.1, self.1 - rhs.0)
308    }
309}
310
311impl<T: Neg> Neg for Interval<T> {
312    type Output = Interval<T::Output>;
313
314    fn neg(self) -> Self::Output {
315        Interval(-self.1, -self.0)
316    }
317}
318
319impl<T: Mul<T>> Mul<T> for Interval<T>
320where
321    T: Copy + PartialOrd + Zero,
322{
323    type Output = Interval<<T as Mul<T>>::Output>;
324
325    fn mul(self, rhs: T) -> Self::Output {
326        if rhs < T::zero() {
327            Interval(self.1 * rhs, self.0 * rhs)
328        } else {
329            Interval(self.0 * rhs, self.1 * rhs)
330        }
331    }
332}
333
334impl<T: Mul<T>> Mul<Interval<T>> for Interval<T>
335where
336    T: Copy + PartialOrd + Zero,
337    <T as Mul<T>>::Output: SimdPartialOrd,
338{
339    type Output = Interval<<T as Mul<T>>::Output>;
340
341    fn mul(self, rhs: Self) -> Self::Output {
342        let Interval(a1, a2) = self;
343        let Interval(b1, b2) = rhs;
344
345        if a2 <= T::zero() {
346            if b2 <= T::zero() {
347                Interval(a2 * b2, a1 * b1)
348            } else if b1 < T::zero() {
349                Interval(a1 * b2, a1 * b1)
350            } else {
351                Interval(a1 * b2, a2 * b1)
352            }
353        } else if a1 < T::zero() {
354            if b2 <= T::zero() {
355                Interval(a2 * b1, a1 * b1)
356            } else if b1 < T::zero() {
357                Interval((a1 * b2).simd_min(b2 * b1), (a1 * b1).simd_max(a2 * b2))
358            } else {
359                Interval(a1 * b2, a2 * b2)
360            }
361        } else if b2 <= T::zero() {
362            Interval(a2 * b1, a1 * b2)
363        } else if b1 < T::zero() {
364            Interval(a2 * b1, a2 * b2)
365        } else {
366            Interval(a1 * b1, a2 * b2)
367        }
368    }
369}
370
371impl<T: Div<T>> Div<Interval<T>> for Interval<T>
372where
373    T: RealField + Copy,
374    <T as Div<T>>::Output: SimdPartialOrd,
375{
376    type Output = (
377        Interval<<T as Div<T>>::Output>,
378        Option<Interval<<T as Div<T>>::Output>>,
379    );
380
381    fn div(self, rhs: Self) -> Self::Output {
382        let infinity = T::one() / T::zero();
383
384        let Interval(a1, a2) = self;
385        let Interval(b1, b2) = rhs;
386
387        if b1 <= T::zero() && b2 >= T::zero() {
388            // rhs contains T::zero() so we my have to return
389            // two intervals.
390            if a2 < T::zero() {
391                if b2 == T::zero() {
392                    (Interval(a2 / b1, infinity), None)
393                } else if b1 != T::zero() {
394                    (
395                        Interval(-infinity, a2 / b2),
396                        Some(Interval(a2 / b1, infinity)),
397                    )
398                } else {
399                    (Interval(-infinity, a2 / b2), None)
400                }
401            } else if a1 <= T::zero() {
402                (Interval(-infinity, infinity), None)
403            } else if b2 == T::zero() {
404                (Interval(-infinity, a1 / b1), None)
405            } else if b1 != T::zero() {
406                (
407                    Interval(-infinity, a1 / b1),
408                    Some(Interval(a1 / b2, infinity)),
409                )
410            } else {
411                (Interval(a1 / b2, infinity), None)
412            }
413        } else if a2 <= T::zero() {
414            if b2 < T::zero() {
415                (Interval(a2 / b1, a1 / b2), None)
416            } else {
417                (Interval(a1 / b1, a2 / b2), None)
418            }
419        } else if a1 < T::zero() {
420            if b2 < T::zero() {
421                (Interval(a2 / b2, a1 / b2), None)
422            } else {
423                (Interval(a1 / b1, a2 / b1), None)
424            }
425        } else if b2 < T::zero() {
426            (Interval(a2 / b2, a1 / b1), None)
427        } else {
428            (Interval(a1 / b2, a2 / b1), None)
429        }
430    }
431}
432
433impl<T: Copy + Add<T, Output = T>> AddAssign<Interval<T>> for Interval<T> {
434    fn add_assign(&mut self, rhs: Interval<T>) {
435        *self = *self + rhs;
436    }
437}
438
439impl<T: Copy + Sub<T, Output = T>> SubAssign<Interval<T>> for Interval<T> {
440    fn sub_assign(&mut self, rhs: Interval<T>) {
441        *self = *self - rhs;
442    }
443}
444
445impl<T: Mul<T, Output = T>> MulAssign<Interval<T>> for Interval<T>
446where
447    T: Copy + PartialOrd + Zero,
448    <T as Mul<T>>::Output: SimdPartialOrd,
449{
450    fn mul_assign(&mut self, rhs: Interval<T>) {
451        *self = *self * rhs;
452    }
453}
454
455impl<T: Zero + Add<T>> Zero for Interval<T> {
456    fn zero() -> Self {
457        Self(T::zero(), T::zero())
458    }
459
460    fn is_zero(&self) -> bool {
461        self.0.is_zero() && self.1.is_zero()
462    }
463}
464
465impl<T: One + Mul<T>> One for Interval<T>
466where
467    Interval<T>: Mul<Interval<T>, Output = Interval<T>>,
468{
469    fn one() -> Self {
470        Self(T::one(), T::one())
471    }
472}
473
#[cfg(test)]
mod test {
    use super::{Interval, IntervalFunction};
    use na::RealField;

    /// The sine function has three roots on `[0, 2π]`: `0`, `π` and `2π`.
    #[test]
    fn roots_sin() {
        struct Sine;

        impl IntervalFunction<f32> for Sine {
            fn eval(&self, t: f32) -> f32 {
                t.sin()
            }

            fn eval_interval(&self, t: Interval<f32>) -> Interval<f32> {
                t.sin()
            }

            fn eval_interval_gradient(&self, t: Interval<f32>) -> Interval<f32> {
                t.cos()
            }
        }

        let tolerance = 1.0e-5;
        let domain = Interval(0.0, f32::two_pi());
        let roots = super::find_root_intervals(&Sine, domain, tolerance, tolerance, 100);

        assert_eq!(roots.len(), 3);
    }

    /// Checks the sine/cosine bounds on intervals covering one, two and three quadrant
    /// boundaries, with and without a shift by a whole number of periods.
    #[test]
    fn interval_sin_cos() {
        let pi = f32::pi();
        let base_a = pi / 6.0;
        let base_b = pi / 2.0 + pi / 6.0;
        let base_c = pi + pi / 6.0;
        let base_d = pi + pi / 2.0 + pi / 6.0;

        for &shift in &[0.0, f32::two_pi() * 100.0, -f32::two_pi() * 100.0] {
            let a = base_a + shift;
            let b = base_b + shift;
            let c = base_c + shift;
            let d = base_d + shift;

            // Sine bounds.
            assert_eq!(Interval(a, b).sin(), Interval(a.sin(), 1.0));
            assert_eq!(Interval(a, c).sin(), Interval(c.sin(), 1.0));
            assert_eq!(Interval(a, d).sin(), Interval(-1.0, 1.0));

            // Cosine bounds.
            assert_eq!(Interval(a, b).cos(), Interval(b.cos(), a.cos()));
            assert_eq!(Interval(a, c).cos(), Interval(-1.0, a.cos()));
            assert_eq!(Interval(a, d).cos(), Interval(-1.0, a.cos()));
        }
    }
}