1use alloc::{vec, vec::Vec};
4use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
5use na::{RealField, SimdPartialOrd};
6use num::{One, Zero};
7
8pub trait IntervalFunction<T> {
10 fn eval(&self, t: T) -> T;
12 fn eval_interval(&self, t: Interval<T>) -> Interval<T>;
14 fn eval_interval_gradient(&self, t: Interval<T>) -> Interval<T>;
16}
17
18pub fn find_root_intervals_to<T: RealField + Copy>(
23 function: &impl IntervalFunction<T>,
24 init: Interval<T>,
25 min_interval_width: T,
26 min_image_width: T,
27 max_recursions: usize,
28 results: &mut Vec<Interval<T>>,
29 candidates: &mut Vec<(Interval<T>, usize)>,
30) {
31 candidates.clear();
32
33 let push_candidate = |candidate,
34 recursion,
35 results: &mut Vec<Interval<T>>,
36 candidates: &mut Vec<(Interval<T>, usize)>| {
37 let candidate_image = function.eval_interval(candidate);
38 let is_small_range =
39 candidate.width() < min_interval_width || candidate_image.width() < min_image_width;
40
41 if candidate_image.contains(T::zero()) {
42 if recursion == max_recursions || is_small_range {
43 results.push(candidate);
44 } else {
45 candidates.push((candidate, recursion + 1));
46 }
47 } else if is_small_range
48 && function.eval(candidate.midpoint()).abs() < T::default_epsilon().sqrt()
49 {
50 results.push(candidate);
53 }
54 };
55
56 push_candidate(init, 0, results, candidates);
57
58 while let Some((candidate, recursion)) = candidates.pop() {
59 let mid = candidate.midpoint();
70 let f_mid = function.eval(mid);
71 let gradient = function.eval_interval_gradient(candidate);
72 let (shift1, shift2) = Interval(f_mid, f_mid) / gradient;
73
74 let new_candidates = [
75 (Interval(mid, mid) - shift1).intersect(candidate),
76 shift2.and_then(|shift2| (Interval(mid, mid) - shift2).intersect(candidate)),
77 ];
78
79 let prev_width = candidate.width();
80
81 for new_candidate in new_candidates.iter().flatten() {
82 if new_candidate.width() > prev_width * na::convert(0.75) {
83 let [a, b] = new_candidate.split();
86 push_candidate(a, recursion, results, candidates);
87 push_candidate(b, recursion, results, candidates);
88 } else {
89 push_candidate(*new_candidate, recursion, results, candidates);
90 }
91 }
92 }
93}
94
95pub fn find_root_intervals<T: RealField + Copy>(
97 function: &impl IntervalFunction<T>,
98 init: Interval<T>,
99 min_interval_width: T,
100 min_image_width: T,
101 max_recursions: usize,
102) -> Vec<Interval<T>> {
103 let mut results = vec![];
104 let mut candidates = vec![];
105 find_root_intervals_to(
106 function,
107 init,
108 min_interval_width,
109 min_image_width,
110 max_recursions,
111 &mut results,
112 &mut candidates,
113 );
114 results
115}
116
/// A closed interval `[self.0, self.1]`.
///
/// Field `0` is the lower bound and field `1` the upper bound. Constructors such as
/// `Interval::sort` produce that ordering, and the arithmetic on this type assumes it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Interval<T>(pub T, pub T);
120
121impl<T> Interval<T> {
122 #[must_use]
124 pub fn sort(a: T, b: T) -> Self
125 where
126 T: PartialOrd,
127 {
128 if a < b {
129 Self(a, b)
130 } else {
131 Self(b, a)
132 }
133 }
134
135 #[must_use]
137 pub fn splat(e: T) -> Self
138 where
139 T: Clone,
140 {
141 Self(e.clone(), e)
142 }
143
144 #[must_use]
146 pub fn contains(&self, t: T) -> bool
147 where
148 T: PartialOrd<T>,
149 {
150 self.0 <= t && self.1 >= t
151 }
152
153 #[must_use]
155 pub fn width(self) -> T::Output
156 where
157 T: Sub<T>,
158 {
159 self.1 - self.0
160 }
161
162 #[must_use]
164 pub fn midpoint(self) -> T
165 where
166 T: RealField + Copy,
167 {
168 let two: T = na::convert(2.0);
169 (self.0 + self.1) / two
170 }
171
172 #[must_use]
174 pub fn split(self) -> [Self; 2]
175 where
176 T: RealField + Copy,
177 {
178 let mid = self.midpoint();
179 [Interval(self.0, mid), Interval(mid, self.1)]
180 }
181
182 #[must_use]
184 pub fn enclose(self, t: T) -> Self
185 where
186 T: PartialOrd,
187 {
188 if t < self.0 {
189 Interval(t, self.1)
190 } else if t > self.1 {
191 Interval(self.0, t)
192 } else {
193 self
194 }
195 }
196
197 #[must_use]
201 pub fn intersect(self, rhs: Self) -> Option<Self>
202 where
203 T: PartialOrd + SimdPartialOrd, {
205 let result = Interval(self.0.simd_max(rhs.0), self.1.simd_min(rhs.1));
206
207 if result.0 > result.1 {
208 None
210 } else {
211 Some(result)
212 }
213 }
214
215 #[must_use]
217 pub fn sin_cos(self) -> (Self, Self)
218 where
219 T: RealField + Copy,
220 {
221 (self.sin(), self.cos())
222 }
223
224 #[must_use]
226 pub fn sin(self) -> Self
227 where
228 T: RealField + Copy,
229 {
230 if self.width() >= T::two_pi() {
231 Interval(-T::one(), T::one())
232 } else {
233 let sin0 = self.0.sin();
234 let sin1 = self.1.sin();
235 let mut result = Interval::sort(sin0, sin1);
236
237 let orig = (self.0 / T::two_pi()).floor() * T::two_pi();
238 let crit = [orig + T::frac_pi_2(), orig + T::pi() + T::frac_pi_2()];
239 let crit_vals = [T::one(), -T::one()];
240
241 for i in 0..2 {
242 if self.contains(crit[i]) || self.contains(crit[i] + T::two_pi()) {
243 result = result.enclose(crit_vals[i])
244 }
245 }
246
247 result
248 }
249 }
250
251 #[must_use]
253 pub fn cos(self) -> Self
254 where
255 T: RealField + Copy,
256 {
257 if self.width() >= T::two_pi() {
258 Interval(-T::one(), T::one())
259 } else {
260 let cos0 = self.0.cos();
261 let cos1 = self.1.cos();
262 let mut result = Interval::sort(cos0, cos1);
263
264 let orig = (self.0 / T::two_pi()).floor() * T::two_pi();
265 let crit = [orig, orig + T::pi()];
266 let crit_vals = [T::one(), -T::one()];
267
268 for i in 0..2 {
269 if self.contains(crit[i]) || self.contains(crit[i] + T::two_pi()) {
270 result = result.enclose(crit_vals[i])
271 }
272 }
273
274 result
275 }
276 }
277}
278
279impl<T: Add<T> + Copy> Add<T> for Interval<T> {
280 type Output = Interval<<T as Add<T>>::Output>;
281
282 fn add(self, rhs: T) -> Self::Output {
283 Interval(self.0 + rhs, self.1 + rhs)
284 }
285}
286
287impl<T: Add<T>> Add<Interval<T>> for Interval<T> {
288 type Output = Interval<<T as Add<T>>::Output>;
289
290 fn add(self, rhs: Self) -> Self::Output {
291 Interval(self.0 + rhs.0, self.1 + rhs.1)
292 }
293}
294
295impl<T: Sub<T> + Copy> Sub<T> for Interval<T> {
296 type Output = Interval<<T as Sub<T>>::Output>;
297
298 fn sub(self, rhs: T) -> Self::Output {
299 Interval(self.0 - rhs, self.1 - rhs)
300 }
301}
302
303impl<T: Sub<T> + Copy> Sub<Interval<T>> for Interval<T> {
304 type Output = Interval<<T as Sub<T>>::Output>;
305
306 fn sub(self, rhs: Self) -> Self::Output {
307 Interval(self.0 - rhs.1, self.1 - rhs.0)
308 }
309}
310
311impl<T: Neg> Neg for Interval<T> {
312 type Output = Interval<T::Output>;
313
314 fn neg(self) -> Self::Output {
315 Interval(-self.1, -self.0)
316 }
317}
318
319impl<T: Mul<T>> Mul<T> for Interval<T>
320where
321 T: Copy + PartialOrd + Zero,
322{
323 type Output = Interval<<T as Mul<T>>::Output>;
324
325 fn mul(self, rhs: T) -> Self::Output {
326 if rhs < T::zero() {
327 Interval(self.1 * rhs, self.0 * rhs)
328 } else {
329 Interval(self.0 * rhs, self.1 * rhs)
330 }
331 }
332}
333
334impl<T: Mul<T>> Mul<Interval<T>> for Interval<T>
335where
336 T: Copy + PartialOrd + Zero,
337 <T as Mul<T>>::Output: SimdPartialOrd,
338{
339 type Output = Interval<<T as Mul<T>>::Output>;
340
341 fn mul(self, rhs: Self) -> Self::Output {
342 let Interval(a1, a2) = self;
343 let Interval(b1, b2) = rhs;
344
345 if a2 <= T::zero() {
346 if b2 <= T::zero() {
347 Interval(a2 * b2, a1 * b1)
348 } else if b1 < T::zero() {
349 Interval(a1 * b2, a1 * b1)
350 } else {
351 Interval(a1 * b2, a2 * b1)
352 }
353 } else if a1 < T::zero() {
354 if b2 <= T::zero() {
355 Interval(a2 * b1, a1 * b1)
356 } else if b1 < T::zero() {
357 Interval((a1 * b2).simd_min(b2 * b1), (a1 * b1).simd_max(a2 * b2))
358 } else {
359 Interval(a1 * b2, a2 * b2)
360 }
361 } else if b2 <= T::zero() {
362 Interval(a2 * b1, a1 * b2)
363 } else if b1 < T::zero() {
364 Interval(a2 * b1, a2 * b2)
365 } else {
366 Interval(a1 * b1, a2 * b2)
367 }
368 }
369}
370
/// Extended interval division: an enclosure of `{ a / b | a in self, b in rhs }`.
///
/// When `rhs` contains zero the quotient may be unbounded. When zero lies strictly inside
/// `rhs` and `self` is strictly negative or strictly positive, the quotient splits into two
/// disjoint unbounded pieces. The first element of the returned pair is always populated;
/// the second is `Some` only in that split case.
///
/// Unbounded ends are represented by `T::one() / T::zero()`. NOTE(review): this relies on
/// `T` producing IEEE-style infinities on division by zero; confirm for non-float `T`.
impl<T: Div<T>> Div<Interval<T>> for Interval<T>
where
    T: RealField + Copy,
    <T as Div<T>>::Output: SimdPartialOrd,
{
    type Output = (
        Interval<<T as Div<T>>::Output>,
        Option<Interval<<T as Div<T>>::Output>>,
    );

    fn div(self, rhs: Self) -> Self::Output {
        let infinity = T::one() / T::zero();

        let Interval(a1, a2) = self;
        let Interval(b1, b2) = rhs;

        if b1 <= T::zero() && b2 >= T::zero() {
            // The divisor contains zero.
            if a2 < T::zero() {
                // The dividend is strictly negative.
                if b2 == T::zero() {
                    // Divisor is `[b1, 0]`: only non-positive divisors, so the quotient is
                    // bounded below and unbounded above.
                    (Interval(a2 / b1, infinity), None)
                } else if b1 != T::zero() {
                    // Zero lies strictly inside the divisor: two disjoint unbounded pieces.
                    (
                        Interval(-infinity, a2 / b2),
                        Some(Interval(a2 / b1, infinity)),
                    )
                } else {
                    // Divisor is `[0, b2]`: only non-negative divisors.
                    (Interval(-infinity, a2 / b2), None)
                }
            } else if a1 <= T::zero() {
                // Both the dividend and the divisor contain zero: nothing can be excluded.
                (Interval(-infinity, infinity), None)
            } else if b2 == T::zero() {
                // Dividend strictly positive, divisor `[b1, 0]`.
                (Interval(-infinity, a1 / b1), None)
            } else if b1 != T::zero() {
                // Dividend strictly positive, zero strictly inside the divisor.
                (
                    Interval(-infinity, a1 / b1),
                    Some(Interval(a1 / b2, infinity)),
                )
            } else {
                // Dividend strictly positive, divisor `[0, b2]`.
                (Interval(a1 / b2, infinity), None)
            }
        } else if a2 <= T::zero() {
            // Divisor excludes zero; dividend is non-positive.
            if b2 < T::zero() {
                // Strictly negative divisor.
                (Interval(a2 / b1, a1 / b2), None)
            } else {
                // Strictly positive divisor.
                (Interval(a1 / b1, a2 / b2), None)
            }
        } else if a1 < T::zero() {
            // Divisor excludes zero; dividend straddles zero.
            if b2 < T::zero() {
                (Interval(a2 / b2, a1 / b2), None)
            } else {
                (Interval(a1 / b1, a2 / b1), None)
            }
        } else if b2 < T::zero() {
            // Non-negative dividend, strictly negative divisor.
            (Interval(a2 / b2, a1 / b1), None)
        } else {
            // Non-negative dividend, strictly positive divisor.
            (Interval(a1 / b2, a2 / b1), None)
        }
    }
}
432
433impl<T: Copy + Add<T, Output = T>> AddAssign<Interval<T>> for Interval<T> {
434 fn add_assign(&mut self, rhs: Interval<T>) {
435 *self = *self + rhs;
436 }
437}
438
439impl<T: Copy + Sub<T, Output = T>> SubAssign<Interval<T>> for Interval<T> {
440 fn sub_assign(&mut self, rhs: Interval<T>) {
441 *self = *self - rhs;
442 }
443}
444
445impl<T: Mul<T, Output = T>> MulAssign<Interval<T>> for Interval<T>
446where
447 T: Copy + PartialOrd + Zero,
448 <T as Mul<T>>::Output: SimdPartialOrd,
449{
450 fn mul_assign(&mut self, rhs: Interval<T>) {
451 *self = *self * rhs;
452 }
453}
454
455impl<T: Zero + Add<T>> Zero for Interval<T> {
456 fn zero() -> Self {
457 Self(T::zero(), T::zero())
458 }
459
460 fn is_zero(&self) -> bool {
461 self.0.is_zero() && self.1.is_zero()
462 }
463}
464
465impl<T: One + Mul<T>> One for Interval<T>
466where
467 Interval<T>: Mul<Interval<T>, Output = Interval<T>>,
468{
469 fn one() -> Self {
470 Self(T::one(), T::one())
471 }
472}
473
#[cfg(test)]
mod test {
    use super::{Interval, IntervalFunction};
    use na::RealField;

    #[test]
    fn roots_sin() {
        /// `sin` together with its interval enclosure and its derivative enclosure (`cos`).
        struct Sine;

        impl IntervalFunction<f32> for Sine {
            fn eval(&self, t: f32) -> f32 {
                t.sin()
            }

            fn eval_interval(&self, t: Interval<f32>) -> Interval<f32> {
                t.sin()
            }

            fn eval_interval_gradient(&self, t: Interval<f32>) -> Interval<f32> {
                t.cos()
            }
        }

        // sin has exactly three roots on [0, 2π]: 0, π and 2π.
        let search_range = Interval(0.0, f32::two_pi());
        let roots = super::find_root_intervals(&Sine, search_range, 1.0e-5, 1.0e-5, 100);
        assert_eq!(roots.len(), 3);
    }

    #[test]
    fn interval_sin_cos() {
        let pi = f32::pi();
        let a = pi / 6.0;
        let b = pi / 2.0 + pi / 6.0;
        let c = pi + pi / 6.0;
        let d = pi + pi / 2.0 + pi / 6.0;

        // Exercise the period reduction far away from the origin in both directions.
        let shifts = [0.0, f32::two_pi() * 100.0, -f32::two_pi() * 100.0];

        for &shift in &shifts {
            let (sa, sb, sc, sd) = (a + shift, b + shift, c + shift, d + shift);

            assert_eq!(Interval(sa, sb).sin(), Interval(sa.sin(), 1.0));
            assert_eq!(Interval(sa, sc).sin(), Interval(sc.sin(), 1.0));
            assert_eq!(Interval(sa, sd).sin(), Interval(-1.0, 1.0));

            assert_eq!(Interval(sa, sb).cos(), Interval(sb.cos(), sa.cos()));
            assert_eq!(Interval(sa, sc).cos(), Interval(-1.0, sa.cos()));
            assert_eq!(Interval(sa, sd).cos(), Interval(-1.0, sa.cos()));
        }
    }
}