1use alloc::{vec, vec::Vec};
4use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
5use num::{One, Zero};
6use num_traits::AsPrimitive;
7use simba::scalar::RealField;
8
9pub trait IntervalFunction<T> {
11 fn eval(&self, t: T) -> T;
13 fn eval_interval(&self, t: Interval<T>) -> Interval<T>;
15 fn eval_interval_gradient(&self, t: Interval<T>) -> Interval<T>;
17}
18
19pub fn find_root_intervals_to<T: Copy + PartialOrd + RealField>(
24 function: &impl IntervalFunction<T>,
25 init: Interval<T>,
26 min_interval_width: T,
27 min_image_width: T,
28 max_recursions: usize,
29 results: &mut Vec<Interval<T>>,
30 candidates: &mut Vec<(Interval<T>, usize)>,
31) where
32 f64: AsPrimitive<T>,
33{
34 candidates.clear();
35
36 let push_candidate = |candidate,
37 recursion,
38 results: &mut Vec<Interval<T>>,
39 candidates: &mut Vec<(Interval<T>, usize)>| {
40 let candidate_image = function.eval_interval(candidate);
41 let is_small_range =
42 candidate.width() < min_interval_width || candidate_image.width() < min_image_width;
43
44 if candidate_image.contains(T::zero()) {
45 if recursion == max_recursions || is_small_range {
46 results.push(candidate);
47 } else {
48 candidates.push((candidate, recursion + 1));
49 }
50 } else if is_small_range
51 && function.eval(candidate.midpoint()).abs() < T::default_epsilon().sqrt()
52 {
53 results.push(candidate);
56 }
57 };
58
59 push_candidate(init, 0, results, candidates);
60
61 while let Some((candidate, recursion)) = candidates.pop() {
62 let mid = candidate.midpoint();
73 let f_mid = function.eval(mid);
74 let gradient = function.eval_interval_gradient(candidate);
75 let (shift1, shift2) = Interval(f_mid, f_mid) / gradient;
76
77 let new_candidates = [
78 (Interval(mid, mid) - shift1).intersect(candidate),
79 shift2.and_then(|shift2| (Interval(mid, mid) - shift2).intersect(candidate)),
80 ];
81
82 let prev_width = candidate.width();
83
84 for new_candidate in new_candidates.iter().flatten() {
85 if new_candidate.width() > prev_width * 0.75.as_() {
86 let [a, b] = new_candidate.split();
89 push_candidate(a, recursion, results, candidates);
90 push_candidate(b, recursion, results, candidates);
91 } else {
92 push_candidate(*new_candidate, recursion, results, candidates);
93 }
94 }
95 }
96}
97
98pub fn find_root_intervals<T: Copy + PartialOrd + RealField>(
100 function: &impl IntervalFunction<T>,
101 init: Interval<T>,
102 min_interval_width: T,
103 min_image_width: T,
104 max_recursions: usize,
105) -> Vec<Interval<T>>
106where
107 f64: AsPrimitive<T>,
108{
109 let mut results = vec![];
110 let mut candidates = vec![];
111 find_root_intervals_to(
112 function,
113 init,
114 min_interval_width,
115 min_image_width,
116 max_recursions,
117 &mut results,
118 &mut candidates,
119 );
120 results
121}
122
/// A closed interval stored as `(lower, upper)`.
///
/// Field `.0` is the lower bound and `.1` the upper bound. The tuple constructor does not
/// check that `.0 <= .1`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Interval<T>(pub T, pub T);
126
127impl<T> Interval<T> {
128 #[must_use]
130 pub fn sort(a: T, b: T) -> Self
131 where
132 T: PartialOrd,
133 {
134 if a < b {
135 Self(a, b)
136 } else {
137 Self(b, a)
138 }
139 }
140
141 #[must_use]
143 pub fn splat(e: T) -> Self
144 where
145 T: Clone,
146 {
147 Self(e.clone(), e)
148 }
149
150 #[must_use]
152 pub fn contains(&self, t: T) -> bool
153 where
154 T: PartialOrd<T>,
155 {
156 self.0 <= t && self.1 >= t
157 }
158
159 #[must_use]
161 pub fn width(self) -> T::Output
162 where
163 T: Sub<T>,
164 {
165 self.1 - self.0
166 }
167
168 #[must_use]
170 pub fn midpoint(self) -> T
171 where
172 T: Copy + Add<T, Output = T> + Div<T, Output = T> + One,
173 {
174 let two = T::one() + T::one();
175 (self.0 + self.1) / two
176 }
177
178 #[must_use]
180 pub fn split(self) -> [Self; 2]
181 where
182 T: Copy + Add<T, Output = T> + Div<T, Output = T> + One,
183 {
184 let mid = self.midpoint();
185 [Interval(self.0, mid), Interval(mid, self.1)]
186 }
187
188 #[must_use]
190 pub fn enclose(self, t: T) -> Self
191 where
192 T: PartialOrd,
193 {
194 if t < self.0 {
195 Interval(t, self.1)
196 } else if t > self.1 {
197 Interval(self.0, t)
198 } else {
199 self
200 }
201 }
202
203 #[must_use]
207 pub fn intersect(self, rhs: Self) -> Option<Self>
208 where
209 T: PartialOrd + RealField,
210 {
211 let result = Interval(self.0.max(rhs.0), self.1.min(rhs.1));
212
213 if result.0 > result.1 {
214 None
216 } else {
217 Some(result)
218 }
219 }
220
221 #[must_use]
223 pub fn sin_cos(self) -> (Self, Self)
224 where
225 T: Copy + RealField,
226 f64: AsPrimitive<T>,
227 {
228 (self.sin(), self.cos())
229 }
230
231 #[must_use]
233 pub fn sin(self) -> Self
234 where
235 T: Copy + RealField,
236 f64: AsPrimitive<T>,
237 {
238 let two_pi = T::two_pi();
239 let one = T::one();
240 let pi = T::pi();
241 let frac_pi_2 = T::frac_pi_2();
242
243 if self.width() >= two_pi {
244 Interval(one * (-1.0).as_(), one)
245 } else {
246 let sin0 = self.0.sin();
247 let sin1 = self.1.sin();
248 let mut result = Interval::sort(sin0, sin1);
249
250 let orig = (self.0 / two_pi).floor() * two_pi;
251 let crit = [orig + frac_pi_2, orig + pi + frac_pi_2];
252 let crit_vals = [one, one * (-1.0).as_()];
253
254 for i in 0..2 {
255 if self.contains(crit[i]) || self.contains(crit[i] + two_pi) {
256 result = result.enclose(crit_vals[i])
257 }
258 }
259
260 result
261 }
262 }
263
264 #[must_use]
266 pub fn cos(self) -> Self
267 where
268 T: Copy + RealField,
269 f64: AsPrimitive<T>,
270 {
271 let two_pi = T::two_pi();
272 let one = T::one();
273 let pi = T::pi();
274
275 if self.width() >= two_pi {
276 Interval(one * (-1.0).as_(), one)
277 } else {
278 let cos0 = self.0.cos();
279 let cos1 = self.1.cos();
280 let mut result = Interval::sort(cos0, cos1);
281
282 let orig = (self.0 / two_pi).floor() * two_pi;
283 let crit = [orig, orig + pi];
284 let crit_vals = [one, one * (-1.0).as_()];
285
286 for i in 0..2 {
287 if self.contains(crit[i]) || self.contains(crit[i] + two_pi) {
288 result = result.enclose(crit_vals[i])
289 }
290 }
291
292 result
293 }
294 }
295}
296
297impl<T: Add<T> + Copy> Add<T> for Interval<T> {
298 type Output = Interval<<T as Add<T>>::Output>;
299
300 fn add(self, rhs: T) -> Self::Output {
301 Interval(self.0 + rhs, self.1 + rhs)
302 }
303}
304
305impl<T: Add<T>> Add<Interval<T>> for Interval<T> {
306 type Output = Interval<<T as Add<T>>::Output>;
307
308 fn add(self, rhs: Self) -> Self::Output {
309 Interval(self.0 + rhs.0, self.1 + rhs.1)
310 }
311}
312
313impl<T: Sub<T> + Copy> Sub<T> for Interval<T> {
314 type Output = Interval<<T as Sub<T>>::Output>;
315
316 fn sub(self, rhs: T) -> Self::Output {
317 Interval(self.0 - rhs, self.1 - rhs)
318 }
319}
320
321impl<T: Sub<T> + Copy> Sub<Interval<T>> for Interval<T> {
322 type Output = Interval<<T as Sub<T>>::Output>;
323
324 fn sub(self, rhs: Self) -> Self::Output {
325 Interval(self.0 - rhs.1, self.1 - rhs.0)
326 }
327}
328
329impl<T: Neg> Neg for Interval<T> {
330 type Output = Interval<T::Output>;
331
332 fn neg(self) -> Self::Output {
333 Interval(-self.1, -self.0)
334 }
335}
336
337impl<T: Mul<T>> Mul<T> for Interval<T>
338where
339 T: Copy + PartialOrd + Zero,
340{
341 type Output = Interval<<T as Mul<T>>::Output>;
342
343 fn mul(self, rhs: T) -> Self::Output {
344 if rhs < T::zero() {
345 Interval(self.1 * rhs, self.0 * rhs)
346 } else {
347 Interval(self.0 * rhs, self.1 * rhs)
348 }
349 }
350}
351
352impl<T: Mul<T>> Mul<Interval<T>> for Interval<T>
353where
354 T: Copy + PartialOrd + Zero + RealField,
355 <T as Mul<T>>::Output: PartialOrd + RealField,
356{
357 type Output = Interval<<T as Mul<T>>::Output>;
358
359 fn mul(self, rhs: Self) -> Self::Output {
360 let Interval(a1, a2) = self;
361 let Interval(b1, b2) = rhs;
362
363 if a2 <= T::zero() {
364 if b2 <= T::zero() {
365 Interval(a2 * b2, a1 * b1)
366 } else if b1 < T::zero() {
367 Interval(a1 * b2, a1 * b1)
368 } else {
369 Interval(a1 * b2, a2 * b1)
370 }
371 } else if a1 < T::zero() {
372 if b2 <= T::zero() {
373 Interval(a2 * b1, a1 * b1)
374 } else if b1 < T::zero() {
375 Interval((a1 * b2).min(b2 * b1), (a1 * b1).max(a2 * b2))
376 } else {
377 Interval(a1 * b2, a2 * b2)
378 }
379 } else if b2 <= T::zero() {
380 Interval(a2 * b1, a1 * b2)
381 } else if b1 < T::zero() {
382 Interval(a2 * b1, a2 * b2)
383 } else {
384 Interval(a1 * b1, a2 * b2)
385 }
386 }
387}
388
389impl<T: Div<T>> Div<Interval<T>> for Interval<T>
390where
391 T: Copy + One + Zero + PartialOrd + 'static,
392 <T as Div<T>>::Output: PartialOrd + RealField,
393 f64: AsPrimitive<T>,
394{
395 type Output = (
396 Interval<<T as Div<T>>::Output>,
397 Option<Interval<<T as Div<T>>::Output>>,
398 );
399
400 fn div(self, rhs: Self) -> Self::Output {
401 let infinity = T::one() / T::zero();
402 let neg_one: T = (-1.0).as_();
403 let neg_infinity = neg_one / T::zero();
404
405 let Interval(a1, a2) = self;
406 let Interval(b1, b2) = rhs;
407
408 if b1 <= T::zero() && b2 >= T::zero() {
409 if a2 < T::zero() {
412 if b2 == T::zero() {
413 (Interval(a2 / b1, infinity), None)
414 } else if b1 != T::zero() {
415 (
416 Interval(neg_infinity, a2 / b2),
417 Some(Interval(a2 / b1, infinity)),
418 )
419 } else {
420 (Interval(neg_infinity, a2 / b2), None)
421 }
422 } else if a1 <= T::zero() {
423 (Interval(neg_infinity, infinity), None)
424 } else if b2 == T::zero() {
425 (Interval(neg_infinity, a1 / b1), None)
426 } else if b1 != T::zero() {
427 (
428 Interval(neg_infinity, a1 / b1),
429 Some(Interval(a1 / b2, infinity)),
430 )
431 } else {
432 (Interval(a1 / b2, infinity), None)
433 }
434 } else if a2 <= T::zero() {
435 if b2 < T::zero() {
436 (Interval(a2 / b1, a1 / b2), None)
437 } else {
438 (Interval(a1 / b1, a2 / b2), None)
439 }
440 } else if a1 < T::zero() {
441 if b2 < T::zero() {
442 (Interval(a2 / b2, a1 / b2), None)
443 } else {
444 (Interval(a1 / b1, a2 / b1), None)
445 }
446 } else if b2 < T::zero() {
447 (Interval(a2 / b2, a1 / b1), None)
448 } else {
449 (Interval(a1 / b2, a2 / b1), None)
450 }
451 }
452}
453
454impl<T: Copy + Add<T, Output = T>> AddAssign<Interval<T>> for Interval<T> {
455 fn add_assign(&mut self, rhs: Interval<T>) {
456 *self = *self + rhs;
457 }
458}
459
460impl<T: Copy + Sub<T, Output = T>> SubAssign<Interval<T>> for Interval<T> {
461 fn sub_assign(&mut self, rhs: Interval<T>) {
462 *self = *self - rhs;
463 }
464}
465
466impl<T: Mul<T, Output = T>> MulAssign<Interval<T>> for Interval<T>
467where
468 T: Copy + PartialOrd + Zero + RealField,
469 <T as Mul<T>>::Output: PartialOrd + RealField,
470{
471 fn mul_assign(&mut self, rhs: Interval<T>) {
472 *self = *self * rhs;
473 }
474}
475
476impl<T: Zero + Add<T>> Zero for Interval<T> {
477 fn zero() -> Self {
478 Self(T::zero(), T::zero())
479 }
480
481 fn is_zero(&self) -> bool {
482 self.0.is_zero() && self.1.is_zero()
483 }
484}
485
486impl<T: One + Mul<T>> One for Interval<T>
487where
488 Interval<T>: Mul<Interval<T>, Output = Interval<T>>,
489{
490 fn one() -> Self {
491 Self(T::one(), T::one())
492 }
493}
494
#[cfg(test)]
mod test {
    use super::{Interval, IntervalFunction};
    use core::f32::consts::{FRAC_PI_2, PI, TAU};

    /// Root search for `sin` over one full period should report 0, π and 2π.
    #[test]
    fn roots_sin() {
        // `sin` with interval extensions of itself and of its derivative `cos`.
        struct Sin;

        impl IntervalFunction<f32> for Sin {
            fn eval(&self, t: f32) -> f32 {
                t.sin()
            }

            fn eval_interval(&self, t: Interval<f32>) -> Interval<f32> {
                t.sin()
            }

            fn eval_interval_gradient(&self, t: Interval<f32>) -> Interval<f32> {
                t.cos()
            }
        }

        let domain = Interval(0.0, TAU);
        let roots = super::find_root_intervals(&Sin, domain, 1.0e-5, 1.0e-5, 100);
        assert_eq!(roots.len(), 3);
    }

    /// Interval `sin`/`cos` must include interior extrema, also after shifting by whole
    /// periods.
    #[test]
    fn interval_sin_cos() {
        let sixth = PI / 6.0;

        for &shift in [0.0, TAU * 100.0, -TAU * 100.0].iter() {
            // [a, b] and [a, c] contain π/2 (the maximum of sin).
            // [a, d] also contains 3π/2 (its minimum).
            let a = sixth + shift;
            let b = FRAC_PI_2 + sixth + shift;
            let c = PI + sixth + shift;
            let d = PI + FRAC_PI_2 + sixth + shift;

            assert_eq!(Interval(a, b).sin(), Interval(a.sin(), 1.0));
            assert_eq!(Interval(a, c).sin(), Interval(c.sin(), 1.0));
            assert_eq!(Interval(a, d).sin(), Interval(-1.0, 1.0));

            // cos decreases on [a, b]. [a, c] and [a, d] reach its minimum at π.
            assert_eq!(Interval(a, b).cos(), Interval(b.cos(), a.cos()));
            assert_eq!(Interval(a, c).cos(), Interval(-1.0, a.cos()));
            assert_eq!(Interval(a, d).cos(), Interval(-1.0, a.cos()));
        }
    }
}