avian3d/spatial_query/ray_caster.rs
1use crate::prelude::*;
2use bevy::{
3    ecs::{
4        entity::{EntityMapper, MapEntities},
5        lifecycle::HookContext,
6        world::DeferredWorld,
7    },
8    prelude::*,
9};
10#[cfg(all(
11    feature = "default-collider",
12    any(feature = "parry-f32", feature = "parry-f64")
13))]
14use parry::{partitioning::BvhNode, query::RayCast};
15
16/// A component used for [raycasting](spatial_query#raycasting).
17///
18/// **Raycasting** is a type of [spatial query](spatial_query) that finds one or more hits
19/// between a ray and a set of colliders.
20///
21/// Each ray is defined by a local `origin` and a `direction`. The [`RayCaster`] will find each hit
22/// and add them to the [`RayHits`] component. Each hit has a `distance` property which refers to
23/// how far the ray travelled, along with a `normal` for the point of intersection.
24///
25/// The [`RayCaster`] is the easiest way to handle simple raycasts. If you want more control and don't want to
26/// perform raycasts every frame, consider using the [`SpatialQuery`] system parameter.
27///
28/// # Hit Count and Order
29///
30/// The results of a raycast are in an arbitrary order by default. You can iterate over them in the order of
31/// distance with the [`RayHits::iter_sorted`] method.
32///
33/// You can configure the maximum amount of hits for a ray using `max_hits`. By default this is unbounded,
34/// so you will get all hits. When the number or complexity of colliders is large, this can be very
35/// expensive computationally. Set the value to whatever works best for your case.
36///
37/// Note that when there are more hits than `max_hits`, **some hits will be missed**.
38/// To guarantee that the closest hit is included, you should set `max_hits` to one or a value that
39/// is enough to contain all hits.
40///
41/// # Example
42///
43/// ```
44/// # #[cfg(feature = "2d")]
45/// # use avian2d::prelude::*;
46/// # #[cfg(feature = "3d")]
47/// use avian3d::prelude::*;
48/// use bevy::prelude::*;
49///
50/// # #[cfg(all(feature = "3d", feature = "f32"))]
51/// fn setup(mut commands: Commands) {
52///     // Spawn a ray at the center going right
53///     commands.spawn(RayCaster::new(Vec3::ZERO, Dir3::X));
54///     // ...spawn colliders and other things
55/// }
56///
57/// # #[cfg(all(feature = "3d", feature = "f32"))]
58/// fn print_hits(query: Query<(&RayCaster, &RayHits)>) {
59///     for (ray, hits) in &query {
60///         // For the faster iterator that isn't sorted, use `.iter()`
61///         for hit in hits.iter_sorted() {
62///             println!(
63///                 "Hit entity {} at {} with normal {}",
64///                 hit.entity,
65///                 ray.origin + *ray.direction * hit.distance,
66///                 hit.normal,
67///             );
68///         }
69///     }
70/// }
71/// ```
72#[derive(Component, Clone, Debug, PartialEq, Reflect)]
73#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
74#[cfg_attr(feature = "serialize", reflect(Serialize, Deserialize))]
75#[reflect(Debug, Component, PartialEq)]
76#[component(on_add = on_add_ray_caster)]
77#[require(RayHits)]
78pub struct RayCaster {
79    /// Controls if the ray caster is enabled.
80    pub enabled: bool,
81
82    /// The local origin of the ray relative to the [`Position`] and [`Rotation`] of the ray entity or its parent.
83    ///
84    /// To get the global origin, use the `global_origin` method.
85    pub origin: Vector,
86
87    /// The global origin of the ray.
88    global_origin: Vector,
89
90    /// The local direction of the ray relative to the [`Rotation`] of the ray entity or its parent.
91    ///
92    /// To get the global direction, use the `global_direction` method.
93    pub direction: Dir,
94
95    /// The global direction of the ray.
96    global_direction: Dir,
97
98    /// The maximum number of hits allowed.
99    ///
100    /// When there are more hits than `max_hits`, **some hits will be missed**.
101    /// To guarantee that the closest hit is included, you should set `max_hits` to one or a value that
102    /// is enough to contain all hits.
103    pub max_hits: u32,
104
105    /// The maximum distance the ray can travel.
106    ///
107    /// By default this is infinite, so the ray will travel until all hits up to `max_hits` have been checked.
108    #[doc(alias = "max_time_of_impact")]
109    pub max_distance: Scalar,
110
111    /// Controls how the ray behaves when the ray origin is inside of a [collider](Collider).
112    ///
113    /// If `true`, shapes will be treated as solid, and the ray cast will return with a distance of `0.0`
114    /// if the ray origin is inside of the shape. Otherwise, shapes will be treated as hollow, and the ray
115    /// will always return a hit at the shape's boundary.
116    pub solid: bool,
117
118    /// If true, the ray caster ignores hits against its own [`Collider`]. This is the default.
119    pub ignore_self: bool,
120
121    /// Rules that determine which colliders are taken into account in the ray cast.
122    pub query_filter: SpatialQueryFilter,
123}
124
125impl Default for RayCaster {
126    fn default() -> Self {
127        Self {
128            enabled: true,
129            origin: Vector::ZERO,
130            global_origin: Vector::ZERO,
131            direction: Dir::X,
132            global_direction: Dir::X,
133            max_distance: Scalar::MAX,
134            max_hits: u32::MAX,
135            solid: true,
136            ignore_self: true,
137            query_filter: SpatialQueryFilter::default(),
138        }
139    }
140}
141
142impl From<Ray> for RayCaster {
143    fn from(ray: Ray) -> Self {
144        RayCaster::from_ray(ray)
145    }
146}
147
148impl RayCaster {
149    /// Creates a new [`RayCaster`] with a given origin and direction.
150    pub fn new(origin: Vector, direction: Dir) -> Self {
151        Self {
152            origin,
153            direction,
154            ..default()
155        }
156    }
157
158    /// Creates a new [`RayCaster`] from a ray.
159    pub fn from_ray(ray: Ray) -> Self {
160        Self {
161            origin: ray.origin.adjust_precision(),
162            direction: ray.direction,
163            ..default()
164        }
165    }
166
167    /// Sets the ray origin.
168    pub fn with_origin(mut self, origin: Vector) -> Self {
169        self.origin = origin;
170        self
171    }
172
173    /// Sets the ray direction.
174    pub fn with_direction(mut self, direction: Dir) -> Self {
175        self.direction = direction;
176        self
177    }
178
179    /// Controls how the ray behaves when the ray origin is inside of a [collider](Collider).
180    ///
181    /// If `true`, shapes will be treated as solid, and the ray cast will return with a distance of `0.0`
182    /// if the ray origin is inside of the shape. Otherwise, shapes will be treated as hollow, and the ray
183    /// will always return a hit at the shape's boundary.
184    pub fn with_solidness(mut self, solid: bool) -> Self {
185        self.solid = solid;
186        self
187    }
188
189    /// Sets if the ray caster should ignore hits against its own [`Collider`].
190    ///
191    /// The default is `true`.
192    pub fn with_ignore_self(mut self, ignore: bool) -> Self {
193        self.ignore_self = ignore;
194        self
195    }
196
197    /// Sets the maximum distance the ray can travel.
198    pub fn with_max_distance(mut self, max_distance: Scalar) -> Self {
199        self.max_distance = max_distance;
200        self
201    }
202
203    /// Sets the maximum number of allowed hits.
204    pub fn with_max_hits(mut self, max_hits: u32) -> Self {
205        self.max_hits = max_hits;
206        self
207    }
208
209    /// Sets the ray caster's [query filter](SpatialQueryFilter) that controls which colliders
210    /// should be included or excluded by raycasts.
211    pub fn with_query_filter(mut self, query_filter: SpatialQueryFilter) -> Self {
212        self.query_filter = query_filter;
213        self
214    }
215
216    /// Enables the [`RayCaster`].
217    pub fn enable(&mut self) {
218        self.enabled = true;
219    }
220
221    /// Disables the [`RayCaster`].
222    pub fn disable(&mut self) {
223        self.enabled = false;
224    }
225
226    /// Returns the global origin of the ray.
227    pub fn global_origin(&self) -> Vector {
228        self.global_origin
229    }
230
231    /// Returns the global direction of the ray.
232    pub fn global_direction(&self) -> Dir {
233        self.global_direction
234    }
235
236    /// Sets the global origin of the ray.
237    pub(crate) fn set_global_origin(&mut self, global_origin: Vector) {
238        self.global_origin = global_origin;
239    }
240
241    /// Sets the global direction of the ray.
242    pub(crate) fn set_global_direction(&mut self, global_direction: Dir) {
243        self.global_direction = global_direction;
244    }
245
246    #[cfg(all(
247        feature = "default-collider",
248        any(feature = "parry-f32", feature = "parry-f64")
249    ))]
250    pub(crate) fn cast(
251        &mut self,
252        caster_entity: Entity,
253        hits: &mut RayHits,
254        query_pipeline: &SpatialQueryPipeline,
255    ) {
256        if self.ignore_self {
257            self.query_filter.excluded_entities.insert(caster_entity);
258        } else {
259            self.query_filter.excluded_entities.remove(&caster_entity);
260        }
261
262        hits.clear();
263
264        if self.max_hits == 1 {
265            let first_hit = query_pipeline.cast_ray(
266                self.global_origin(),
267                self.global_direction(),
268                self.max_distance,
269                self.solid,
270                &self.query_filter,
271            );
272
273            if let Some(hit) = first_hit {
274                hits.push(hit);
275            }
276        } else {
277            let ray = parry::query::Ray::new(
278                self.global_origin().into(),
279                self.global_direction().adjust_precision().into(),
280            );
281
282            let found_hits = query_pipeline
283                .bvh
284                .leaves(|node: &BvhNode| node.aabb().intersects_local_ray(&ray, self.max_distance))
285                .filter_map(|leaf| {
286                    let proxy = query_pipeline.proxies.get(leaf as usize)?;
287
288                    if !self.query_filter.test(proxy.entity, proxy.layers) {
289                        return None;
290                    }
291
292                    let hit = proxy.collider.shape_scaled().cast_ray_and_get_normal(
293                        &proxy.isometry,
294                        &ray,
295                        self.max_distance,
296                        self.solid,
297                    )?;
298
299                    Some(RayHitData {
300                        entity: proxy.entity,
301                        distance: hit.time_of_impact,
302                        normal: hit.normal.into(),
303                    })
304                })
305                .take(self.max_hits as usize);
306
307            hits.extend(found_hits);
308        }
309    }
310}
311
312fn on_add_ray_caster(mut world: DeferredWorld, ctx: HookContext) {
313    let ray_caster = world.get::<RayCaster>(ctx.entity).unwrap();
314    let max_hits = if ray_caster.max_hits == u32::MAX {
315        10
316    } else {
317        ray_caster.max_hits as usize
318    };
319
320    // Initialize capacity for hits
321    world.get_mut::<RayHits>(ctx.entity).unwrap().0 = Vec::with_capacity(max_hits);
322}
323
324/// Contains the hits of a ray cast by a [`RayCaster`].
325///
326/// The maximum number of hits depends on the value of `max_hits` in [`RayCaster`].
327///
328/// # Order
329///
330/// By default, the order of the hits is not guaranteed.
331///
332/// You can iterate the hits in the order of distance with `iter_sorted`.
333/// Note that this will create and sort a new vector instead of iterating over the existing one.
334///
335/// **Note**: When there are more hits than `max_hits`, **some hits will be missed**.
336/// If you want to guarantee that the closest hit is included, set `max_hits` to one.
337///
338/// # Example
339///
340/// ```
341#[cfg_attr(feature = "2d", doc = "use avian2d::prelude::*;")]
342#[cfg_attr(feature = "3d", doc = "use avian3d::prelude::*;")]
343/// use bevy::prelude::*;
344///
345/// fn print_hits(query: Query<&RayHits, With<RayCaster>>) {
346///     for hits in &query {
347///         // For the faster iterator that isn't sorted, use `.iter()`.
348///         for hit in hits.iter_sorted() {
349///             println!("Hit entity {} with distance {}", hit.entity, hit.distance);
350///         }
351///     }
352/// }
353/// ```
354#[derive(Component, Clone, Debug, Default, Deref, DerefMut, PartialEq, Reflect)]
355#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
356#[cfg_attr(feature = "serialize", reflect(Serialize, Deserialize))]
357#[reflect(Component, Debug, Default, PartialEq)]
358pub struct RayHits(pub Vec<RayHitData>);
359
360impl RayHits {
361    /// Returns an iterator over the hits, sorted in ascending order according to the distance.
362    ///
363    /// Note that this allocates a new vector. If you don't need the hits in order, use `iter`.
364    pub fn iter_sorted(&self) -> alloc::vec::IntoIter<RayHitData> {
365        let mut vector = self.as_slice().to_vec();
366        vector.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
367        vector.into_iter()
368    }
369}
370
371impl IntoIterator for RayHits {
372    type Item = RayHitData;
373    type IntoIter = alloc::vec::IntoIter<RayHitData>;
374
375    fn into_iter(self) -> Self::IntoIter {
376        self.0.into_iter()
377    }
378}
379
380impl<'a> IntoIterator for &'a RayHits {
381    type Item = &'a RayHitData;
382    type IntoIter = core::slice::Iter<'a, RayHitData>;
383
384    fn into_iter(self) -> Self::IntoIter {
385        self.0.iter()
386    }
387}
388
389impl<'a> IntoIterator for &'a mut RayHits {
390    type Item = &'a mut RayHitData;
391    type IntoIter = core::slice::IterMut<'a, RayHitData>;
392
393    fn into_iter(self) -> Self::IntoIter {
394        self.0.iter_mut()
395    }
396}
397
398impl MapEntities for RayHits {
399    fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
400        for hit in self {
401            hit.map_entities(entity_mapper);
402        }
403    }
404}
405
406/// Data related to a hit during a [raycast](spatial_query#raycasting).
407#[derive(Clone, Copy, Debug, PartialEq, Reflect)]
408#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
409#[cfg_attr(feature = "serialize", reflect(Serialize, Deserialize))]
410#[reflect(Debug, PartialEq)]
411pub struct RayHitData {
412    /// The entity of the collider that was hit by the ray.
413    pub entity: Entity,
414
415    /// How far the ray travelled. This is the distance between the ray origin and the point of intersection.
416    pub distance: Scalar,
417
418    /// The normal at the point of intersection, expressed in world space.
419    pub normal: Vector,
420}
421
422impl MapEntities for RayHitData {
423    fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
424        self.entity = entity_mapper.get_mapped(self.entity);
425    }
426}