1#[cfg(feature = "multi_threaded")]
2use crate::batching::BatchingStrategy;
3use crate::message::{Message, MessageCursor, MessageId, MessageInstance, Messages};
4use core::{iter::Chain, slice::Iter};
5
6#[derive(Debug)]
8pub struct MessageIterator<'a, M: Message> {
9 iter: MessageIteratorWithId<'a, M>,
10}
11
12impl<'a, M: Message> Iterator for MessageIterator<'a, M> {
13 type Item = &'a M;
14 fn next(&mut self) -> Option<Self::Item> {
15 self.iter.next().map(|(message, _)| message)
16 }
17
18 fn size_hint(&self) -> (usize, Option<usize>) {
19 self.iter.size_hint()
20 }
21
22 fn count(self) -> usize {
23 self.iter.count()
24 }
25
26 fn last(self) -> Option<Self::Item>
27 where
28 Self: Sized,
29 {
30 self.iter.last().map(|(message, _)| message)
31 }
32
33 fn nth(&mut self, n: usize) -> Option<Self::Item> {
34 self.iter.nth(n).map(|(message, _)| message)
35 }
36}
37
38impl<'a, M: Message> ExactSizeIterator for MessageIterator<'a, M> {
39 fn len(&self) -> usize {
40 self.iter.len()
41 }
42}
43
44#[derive(Debug)]
46pub struct MessageIteratorWithId<'a, M: Message> {
47 reader: &'a mut MessageCursor<M>,
48 chain: Chain<Iter<'a, MessageInstance<M>>, Iter<'a, MessageInstance<M>>>,
49 unread: usize,
50}
51
52impl<'a, M: Message> MessageIteratorWithId<'a, M> {
53 pub fn new(reader: &'a mut MessageCursor<M>, messages: &'a Messages<M>) -> Self {
55 let a_index = reader
56 .last_message_count
57 .saturating_sub(messages.messages_a.start_message_count);
58 let b_index = reader
59 .last_message_count
60 .saturating_sub(messages.messages_b.start_message_count);
61 let a = messages.messages_a.get(a_index..).unwrap_or_default();
62 let b = messages.messages_b.get(b_index..).unwrap_or_default();
63
64 let unread_count = a.len() + b.len();
65 debug_assert_eq!(unread_count, reader.len(messages));
67 reader.last_message_count = messages.message_count - unread_count;
68 let chain = a.iter().chain(b.iter());
70
71 Self {
72 reader,
73 chain,
74 unread: unread_count,
75 }
76 }
77
78 pub fn without_id(self) -> MessageIterator<'a, M> {
80 MessageIterator { iter: self }
81 }
82}
83
84impl<'a, M: Message> Iterator for MessageIteratorWithId<'a, M> {
85 type Item = (&'a M, MessageId<M>);
86 fn next(&mut self) -> Option<Self::Item> {
87 match self
88 .chain
89 .next()
90 .map(|instance| (&instance.message, instance.message_id))
91 {
92 Some(item) => {
93 #[cfg(feature = "detailed_trace")]
94 tracing::trace!("MessageReader::iter() -> {}", item.1);
95 self.reader.last_message_count += 1;
96 self.unread -= 1;
97 Some(item)
98 }
99 None => None,
100 }
101 }
102
103 fn size_hint(&self) -> (usize, Option<usize>) {
104 self.chain.size_hint()
105 }
106
107 fn count(self) -> usize {
108 self.reader.last_message_count += self.unread;
109 self.unread
110 }
111
112 fn last(self) -> Option<Self::Item>
113 where
114 Self: Sized,
115 {
116 let MessageInstance {
117 message_id,
118 message,
119 } = self.chain.last()?;
120 self.reader.last_message_count += self.unread;
121 Some((message, *message_id))
122 }
123
124 fn nth(&mut self, n: usize) -> Option<Self::Item> {
125 if let Some(MessageInstance {
126 message_id,
127 message,
128 }) = self.chain.nth(n)
129 {
130 self.reader.last_message_count += n + 1;
131 self.unread -= n + 1;
132 Some((message, *message_id))
133 } else {
134 self.reader.last_message_count += self.unread;
135 self.unread = 0;
136 None
137 }
138 }
139}
140
141impl<'a, M: Message> ExactSizeIterator for MessageIteratorWithId<'a, M> {
142 fn len(&self) -> usize {
143 self.unread
144 }
145}
146
#[cfg(feature = "multi_threaded")]
/// Parallel iterator over unread messages of type `M`.
///
/// The unread messages are held as two slices that are split into batches
/// according to a [`BatchingStrategy`] when iterated.
#[derive(Debug)]
pub struct MessageParIter<'a, M: Message> {
    /// Cursor to advance once the messages have been processed.
    reader: &'a mut MessageCursor<M>,
    /// Unread part of `messages_a`, then the unread part of `messages_b`.
    slices: [&'a [MessageInstance<M>]; 2],
    /// How to size the batches handed to worker tasks.
    batching_strategy: BatchingStrategy,
    /// Messages not yet marked as read. Only tracked where parallel execution is compiled in.
    #[cfg(not(target_arch = "wasm32"))]
    unread: usize,
}
157
#[cfg(feature = "multi_threaded")]
impl<'a, M: Message> MessageParIter<'a, M> {
    /// Creates a parallel iterator over the messages in `messages` that `reader` has not read yet.
    ///
    /// On return, `reader.last_message_count + unread == messages.message_count`. The
    /// remaining messages are marked as read once they have been processed.
    pub fn new(reader: &'a mut MessageCursor<M>, messages: &'a Messages<M>) -> Self {
        // Offset of the first unread entry inside each buffer (saturating: a buffer that
        // starts after the cursor is entirely unread).
        let a_index = reader
            .last_message_count
            .saturating_sub(messages.messages_a.start_message_count);
        let b_index = reader
            .last_message_count
            .saturating_sub(messages.messages_b.start_message_count);
        // An offset past the end of a buffer means nothing in it is unread.
        let a = messages.messages_a.get(a_index..).unwrap_or_default();
        let b = messages.messages_b.get(b_index..).unwrap_or_default();

        let unread_count = a.len() + b.len();
        // Sanity check: the cursor's own length computation must agree.
        debug_assert_eq!(unread_count, reader.len(messages));
        reader.last_message_count = messages.message_count - unread_count;

        Self {
            reader,
            slices: [a, b],
            batching_strategy: BatchingStrategy::default(),
            #[cfg(not(target_arch = "wasm32"))]
            unread: unread_count,
        }
    }

    /// Replaces the [`BatchingStrategy`] used to split messages into per-task batches.
    pub fn batching_strategy(mut self, strategy: BatchingStrategy) -> Self {
        self.batching_strategy = strategy;
        self
    }

    /// Calls `func` on every unread message, possibly in parallel.
    ///
    /// Messages are not visited in any guaranteed order when running in parallel.
    pub fn for_each<FN: Fn(&'a M) + Send + Sync + Clone>(self, func: FN) {
        self.for_each_with_id(move |e, _| func(e));
    }

    /// Calls `func` on every unread message and its [`MessageId`], possibly in parallel.
    ///
    /// On `wasm32`, or when the compute task pool reports at most one thread, the
    /// messages are visited sequentially through [`MessageIteratorWithId`]. Otherwise
    /// they are split into batches sized by the batching strategy. Each non-empty batch
    /// is processed in its own task inside a pool scope. All messages are marked as read
    /// after the scope returns.
    #[cfg_attr(
        target_arch = "wasm32",
        expect(unused_mut, reason = "not mutated on this target")
    )]
    pub fn for_each_with_id<FN: Fn(&'a M, MessageId<M>) + Send + Sync + Clone>(mut self, func: FN) {
        #[cfg(target_arch = "wasm32")]
        {
            self.into_iter().for_each(|(e, i)| func(e, i));
        }

        #[cfg(not(target_arch = "wasm32"))]
        {
            let pool = bevy_tasks::ComputeTaskPool::get();
            let thread_count = pool.thread_num();
            if thread_count <= 1 {
                return self.into_iter().for_each(|(e, i)| func(e, i));
            }

            // `chunks_exact` panics on a chunk size of zero. The strategy is defined
            // elsewhere, so clamp here instead of trusting it never returns 0.
            let batch_size = self
                .batching_strategy
                .calc_batch_size(|| self.len(), thread_count)
                .max(1);
            let chunks = self.slices.map(|s| s.chunks_exact(batch_size));
            let remainders = chunks.each_ref().map(core::slice::ChunksExact::remainder);

            pool.scope(|scope| {
                // Remainders are empty whenever a slice length is a multiple of the
                // batch size (including empty slices). Don't spawn tasks with nothing to do.
                for batch in chunks
                    .into_iter()
                    .flatten()
                    .chain(remainders)
                    .filter(|batch| !batch.is_empty())
                {
                    let func = func.clone();
                    scope.spawn(async move {
                        for message_instance in batch {
                            func(&message_instance.message, message_instance.message_id);
                        }
                    });
                }
            });

            // The scope has returned, so every batch has been processed.
            self.reader.last_message_count += self.unread;
            self.unread = 0;
        }
    }

    /// Returns the number of messages this iterator will visit.
    pub fn len(&self) -> usize {
        self.slices.iter().map(|s| s.len()).sum()
    }

    /// Returns `true` if there are no messages to visit.
    pub fn is_empty(&self) -> bool {
        self.slices.iter().all(|x| x.is_empty())
    }
}
268
#[cfg(feature = "multi_threaded")]
impl<'a, M: Message> IntoIterator for MessageParIter<'a, M> {
    type IntoIter = MessageIteratorWithId<'a, M>;
    type Item = <Self::IntoIter as Iterator>::Item;

    /// Turns the parallel iterator into a sequential one over the same unread messages.
    fn into_iter(self) -> Self::IntoIter {
        let [first, second] = self.slices;
        MessageIteratorWithId {
            unread: first.len() + second.len(),
            chain: first.iter().chain(second),
            reader: self.reader,
        }
    }
}