bevy_ecs/message/iterators.rs

1#[cfg(feature = "multi_threaded")]
2use crate::batching::BatchingStrategy;
3use crate::message::{Message, MessageCursor, MessageId, MessageInstance, Messages};
4use core::{iter::Chain, slice::Iter};
5
6/// An iterator that yields any unread messages from a [`MessageReader`](super::MessageReader) or [`MessageCursor`].
7#[derive(Debug)]
8pub struct MessageIterator<'a, M: Message> {
9    iter: MessageIteratorWithId<'a, M>,
10}
11
12impl<'a, M: Message> Iterator for MessageIterator<'a, M> {
13    type Item = &'a M;
14    fn next(&mut self) -> Option<Self::Item> {
15        self.iter.next().map(|(message, _)| message)
16    }
17
18    fn size_hint(&self) -> (usize, Option<usize>) {
19        self.iter.size_hint()
20    }
21
22    fn count(self) -> usize {
23        self.iter.count()
24    }
25
26    fn last(self) -> Option<Self::Item>
27    where
28        Self: Sized,
29    {
30        self.iter.last().map(|(message, _)| message)
31    }
32
33    fn nth(&mut self, n: usize) -> Option<Self::Item> {
34        self.iter.nth(n).map(|(message, _)| message)
35    }
36}
37
38impl<'a, M: Message> ExactSizeIterator for MessageIterator<'a, M> {
39    fn len(&self) -> usize {
40        self.iter.len()
41    }
42}
43
44/// An iterator that yields any unread messages (and their IDs) from a [`MessageReader`](super::MessageReader) or [`MessageCursor`].
45#[derive(Debug)]
46pub struct MessageIteratorWithId<'a, M: Message> {
47    reader: &'a mut MessageCursor<M>,
48    chain: Chain<Iter<'a, MessageInstance<M>>, Iter<'a, MessageInstance<M>>>,
49    unread: usize,
50}
51
52impl<'a, M: Message> MessageIteratorWithId<'a, M> {
53    /// Creates a new iterator that yields any `messages` that have not yet been seen by `reader`.
54    pub fn new(reader: &'a mut MessageCursor<M>, messages: &'a Messages<M>) -> Self {
55        let a_index = reader
56            .last_message_count
57            .saturating_sub(messages.messages_a.start_message_count);
58        let b_index = reader
59            .last_message_count
60            .saturating_sub(messages.messages_b.start_message_count);
61        let a = messages.messages_a.get(a_index..).unwrap_or_default();
62        let b = messages.messages_b.get(b_index..).unwrap_or_default();
63
64        let unread_count = a.len() + b.len();
65        // Ensure `len` is implemented correctly
66        debug_assert_eq!(unread_count, reader.len(messages));
67        reader.last_message_count = messages.message_count - unread_count;
68        // Iterate the oldest first, then the newer messages
69        let chain = a.iter().chain(b.iter());
70
71        Self {
72            reader,
73            chain,
74            unread: unread_count,
75        }
76    }
77
78    /// Iterate over only the messages.
79    pub fn without_id(self) -> MessageIterator<'a, M> {
80        MessageIterator { iter: self }
81    }
82}
83
84impl<'a, M: Message> Iterator for MessageIteratorWithId<'a, M> {
85    type Item = (&'a M, MessageId<M>);
86    fn next(&mut self) -> Option<Self::Item> {
87        match self
88            .chain
89            .next()
90            .map(|instance| (&instance.message, instance.message_id))
91        {
92            Some(item) => {
93                #[cfg(feature = "detailed_trace")]
94                tracing::trace!("MessageReader::iter() -> {}", item.1);
95                self.reader.last_message_count += 1;
96                self.unread -= 1;
97                Some(item)
98            }
99            None => None,
100        }
101    }
102
103    fn size_hint(&self) -> (usize, Option<usize>) {
104        self.chain.size_hint()
105    }
106
107    fn count(self) -> usize {
108        self.reader.last_message_count += self.unread;
109        self.unread
110    }
111
112    fn last(self) -> Option<Self::Item>
113    where
114        Self: Sized,
115    {
116        let MessageInstance {
117            message_id,
118            message,
119        } = self.chain.last()?;
120        self.reader.last_message_count += self.unread;
121        Some((message, *message_id))
122    }
123
124    fn nth(&mut self, n: usize) -> Option<Self::Item> {
125        if let Some(MessageInstance {
126            message_id,
127            message,
128        }) = self.chain.nth(n)
129        {
130            self.reader.last_message_count += n + 1;
131            self.unread -= n + 1;
132            Some((message, *message_id))
133        } else {
134            self.reader.last_message_count += self.unread;
135            self.unread = 0;
136            None
137        }
138    }
139}
140
141impl<'a, M: Message> ExactSizeIterator for MessageIteratorWithId<'a, M> {
142    fn len(&self) -> usize {
143        self.unread
144    }
145}
146
/// Parallel iterator over the unread [`Message`]s of a [`MessageCursor`].
///
/// Drive it with [`for_each`](Self::for_each) or [`for_each_with_id`](Self::for_each_with_id),
/// or convert it into a sequential iterator through [`IntoIterator`].
#[cfg(feature = "multi_threaded")]
#[derive(Debug)]
pub struct MessageParIter<'a, M: Message> {
    /// Cursor that is advanced once the messages have been visited.
    reader: &'a mut MessageCursor<M>,
    /// Unread messages from the two underlying buffers, older buffer first.
    slices: [&'a [MessageInstance<M>]; 2],
    /// Controls how the messages are split into batches for the task pool.
    batching_strategy: BatchingStrategy,
    /// Number of messages not yet marked as read on `reader`.
    #[cfg(not(target_arch = "wasm32"))]
    unread: usize,
}
157
#[cfg(feature = "multi_threaded")]
impl<'a, M: Message> MessageParIter<'a, M> {
    /// Creates a new parallel iterator over `messages` that have not yet been seen by `reader`.
    ///
    /// `reader` is positioned just before the first unread message here. It is moved past the
    /// unread messages as they are consumed (see [`for_each_with_id`](Self::for_each_with_id)).
    pub fn new(reader: &'a mut MessageCursor<M>, messages: &'a Messages<M>) -> Self {
        let a_index = reader
            .last_message_count
            .saturating_sub(messages.messages_a.start_message_count);
        let b_index = reader
            .last_message_count
            .saturating_sub(messages.messages_b.start_message_count);
        let a = messages.messages_a.get(a_index..).unwrap_or_default();
        let b = messages.messages_b.get(b_index..).unwrap_or_default();

        let unread_count = a.len() + b.len();
        // Ensure `len` is implemented correctly
        debug_assert_eq!(unread_count, reader.len(messages));
        // Point the cursor just before the first unread message; it is advanced again once the
        // messages have actually been visited.
        reader.last_message_count = messages.message_count - unread_count;

        Self {
            reader,
            slices: [a, b],
            batching_strategy: BatchingStrategy::default(),
            #[cfg(not(target_arch = "wasm32"))]
            unread: unread_count,
        }
    }

    /// Changes the batching strategy used when iterating.
    ///
    /// For more information on how this affects the resultant iteration, see
    /// [`BatchingStrategy`].
    pub fn batching_strategy(mut self, strategy: BatchingStrategy) -> Self {
        self.batching_strategy = strategy;
        self
    }

    /// Runs the provided closure for each unread message in parallel.
    ///
    /// Unlike normal iteration, the message order is not guaranteed in any form.
    ///
    /// # Panics
    /// If the [`ComputeTaskPool`] is not initialized. If using this from a message reader that is being
    /// initialized and run from the ECS scheduler, this should never panic.
    ///
    /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
    pub fn for_each<FN: Fn(&'a M) + Send + Sync + Clone>(self, func: FN) {
        self.for_each_with_id(move |e, _| func(e));
    }

    /// Runs the provided closure for each unread message in parallel, like [`for_each`](Self::for_each),
    /// but additionally provides the [`MessageId`] to the closure.
    ///
    /// Note that the order of iteration is not guaranteed, but [`MessageId`]s are ordered by send order.
    ///
    /// # Panics
    /// If the [`ComputeTaskPool`] is not initialized. If using this from a message reader that is being
    /// initialized and run from the ECS scheduler, this should never panic.
    ///
    /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
    #[cfg_attr(
        target_arch = "wasm32",
        expect(unused_mut, reason = "not mutated on this target")
    )]
    pub fn for_each_with_id<FN: Fn(&'a M, MessageId<M>) + Send + Sync + Clone>(mut self, func: FN) {
        #[cfg(target_arch = "wasm32")]
        {
            self.into_iter().for_each(|(e, i)| func(e, i));
        }

        #[cfg(not(target_arch = "wasm32"))]
        {
            let pool = bevy_tasks::ComputeTaskPool::get();
            let thread_count = pool.thread_num();
            if thread_count <= 1 {
                // No parallelism available: fall back to sequential iteration, which marks each
                // message as read as it is yielded.
                return self.into_iter().for_each(|(e, i)| func(e, i));
            }

            let batch_size = self
                .batching_strategy
                .calc_batch_size(|| self.len(), thread_count);
            let chunks = self.slices.map(|s| s.chunks_exact(batch_size));
            let remainders = chunks.each_ref().map(core::slice::ChunksExact::remainder);

            pool.scope(|scope| {
                // Each buffer's leftover tail is empty when the buffer is unused or when
                // `batch_size` divides it evenly. Skip those so no task is spawned for zero work.
                // (`chunks_exact` itself never yields an empty chunk.)
                let tails = remainders.into_iter().filter(|tail| !tail.is_empty());
                for batch in chunks.into_iter().flatten().chain(tails) {
                    let func = func.clone();
                    scope.spawn(async move {
                        for message_instance in batch {
                            func(&message_instance.message, message_instance.message_id);
                        }
                    });
                }
            });

            // Messages are guaranteed to be read at this point.
            self.reader.last_message_count += self.unread;
            self.unread = 0;
        }
    }

    /// Returns the number of [`Message`]s to be iterated.
    pub fn len(&self) -> usize {
        self.slices.iter().map(|s| s.len()).sum()
    }

    /// Returns [`true`] if there are no messages remaining in this iterator.
    pub fn is_empty(&self) -> bool {
        self.slices.iter().all(|x| x.is_empty())
    }
}
268
#[cfg(feature = "multi_threaded")]
impl<'a, M: Message> IntoIterator for MessageParIter<'a, M> {
    type IntoIter = MessageIteratorWithId<'a, M>;
    type Item = <Self::IntoIter as Iterator>::Item;

    /// Converts into a sequential iterator that walks the first slice's unread messages, then
    /// the second's. Each message is marked as read on the cursor as it is yielded.
    fn into_iter(self) -> Self::IntoIter {
        let [older, newer] = self.slices;
        MessageIteratorWithId {
            unread: older.len() + newer.len(),
            chain: older.iter().chain(newer),
            reader: self.reader,
        }
    }
}
288}