bevy_app/
task_pool_plugin.rs1use crate::{App, Plugin};
2
3use alloc::string::ToString;
4use bevy_platform::sync::Arc;
5use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
6use core::fmt::Debug;
7use log::trace;
8
9cfg_if::cfg_if! {
10 if #[cfg(not(all(target_arch = "wasm32", feature = "web")))] {
11 use {crate::Last, bevy_tasks::tick_global_task_pools_on_main_thread};
12 use bevy_ecs::system::NonSendMarker;
13
14 fn tick_global_task_pools(_main_thread_marker: NonSendMarker) {
19 tick_global_task_pools_on_main_thread();
20 }
21 }
22}
23
24#[derive(Default)]
26pub struct TaskPoolPlugin {
27 pub task_pool_options: TaskPoolOptions,
29}
30
31impl Plugin for TaskPoolPlugin {
32 fn build(&self, _app: &mut App) {
33 self.task_pool_options.create_default_pools();
35
36 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
37 _app.add_systems(Last, tick_global_task_pools);
38 }
39}
40
/// Rule for deciding how many threads a single task pool gets, given how many threads
/// exist in total and how many are still unclaimed by earlier pools.
#[derive(Clone)]
pub struct TaskPoolThreadAssignmentPolicy {
    /// Lower bound on the pool's thread count. It is applied last, so the pool may get
    /// more threads than are still unclaimed.
    pub min_threads: usize,
    /// Upper bound on the pool's thread count.
    pub max_threads: usize,
    /// Fraction of the total thread count this pool targets (`1.0` = all threads),
    /// before the remaining-thread cap and the `min_threads`/`max_threads` clamp.
    /// Must be non-negative.
    pub percent: f32,
    /// Optional callback registered through `TaskPoolBuilder::on_thread_spawn` when the
    /// pool is built. Not registered on `wasm32` with the `web` feature.
    pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Optional callback registered through `TaskPoolBuilder::on_thread_destroy` when the
    /// pool is built. Not registered on `wasm32` with the `web` feature.
    pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl Debug for TaskPoolThreadAssignmentPolicy {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The callbacks are opaque closures and cannot be formatted; `finish_non_exhaustive`
        // prints a trailing `..` so readers know fields were left out.
        f.debug_struct("TaskPoolThreadAssignmentPolicy")
            .field("min_threads", &self.min_threads)
            .field("max_threads", &self.max_threads)
            .field("percent", &self.percent)
            .finish_non_exhaustive()
    }
}

impl TaskPoolThreadAssignmentPolicy {
    /// Computes how many threads this pool should get.
    ///
    /// The steps, in order:
    /// 1. Take `percent` of `total_threads`, rounded half-up.
    /// 2. Cap the result at `remaining_threads`.
    /// 3. Clamp it into `[min_threads, max_threads]`.
    ///
    /// Because the clamp comes last, the result can exceed `remaining_threads`. For example,
    /// on a machine with very few cores every pool still receives `min_threads`.
    ///
    /// # Panics
    ///
    /// Panics if `percent` is negative or NaN, or if `min_threads > max_threads`.
    fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
        // Also rejects NaN, since `NaN >= 0.0` is false.
        assert!(self.percent >= 0.0);
        let proportion = total_threads as f32 * self.percent;
        // Float-to-int `as` saturates, so a huge or infinite `proportion` becomes `usize::MAX`.
        let mut desired = proportion as usize;

        // Manual round-half-up (equivalent to `round()` for non-negative values). It avoids
        // `f32::round`, presumably for `no_std` support — NOTE(review): confirm.
        // `saturating_add` keeps a saturated `desired` from overflowing.
        if proportion - desired as f32 >= 0.5 {
            desired = desired.saturating_add(1);
        }

        // Cap at the unclaimed threads, then apply this pool's bounds. The clamp may push
        // the result back above `remaining_threads`; that is intended.
        desired
            .min(remaining_threads)
            .clamp(self.min_threads, self.max_threads)
    }
}
92
93#[derive(Clone, Debug)]
96pub struct TaskPoolOptions {
97 pub min_total_threads: usize,
100 pub max_total_threads: usize,
103
104 pub io: TaskPoolThreadAssignmentPolicy,
106 pub async_compute: TaskPoolThreadAssignmentPolicy,
108 pub compute: TaskPoolThreadAssignmentPolicy,
110}
111
112impl Default for TaskPoolOptions {
113 fn default() -> Self {
114 TaskPoolOptions {
115 min_total_threads: 1,
117 max_total_threads: usize::MAX,
118
119 io: TaskPoolThreadAssignmentPolicy {
121 min_threads: 1,
122 max_threads: 4,
123 percent: 0.25,
124 on_thread_spawn: None,
125 on_thread_destroy: None,
126 },
127
128 async_compute: TaskPoolThreadAssignmentPolicy {
130 min_threads: 1,
131 max_threads: 4,
132 percent: 0.25,
133 on_thread_spawn: None,
134 on_thread_destroy: None,
135 },
136
137 compute: TaskPoolThreadAssignmentPolicy {
139 min_threads: 1,
140 max_threads: usize::MAX,
141 percent: 1.0, on_thread_spawn: None,
143 on_thread_destroy: None,
144 },
145 }
146 }
147}
148
149impl TaskPoolOptions {
150 pub fn with_num_threads(thread_count: usize) -> Self {
152 TaskPoolOptions {
153 min_total_threads: thread_count,
154 max_total_threads: thread_count,
155 ..Default::default()
156 }
157 }
158
159 pub fn create_default_pools(&self) {
161 let total_threads = bevy_tasks::available_parallelism()
162 .clamp(self.min_total_threads, self.max_total_threads);
163 trace!("Assigning {total_threads} cores to default task pools");
164
165 let mut remaining_threads = total_threads;
166
167 {
168 let io_threads = self
170 .io
171 .get_number_of_threads(remaining_threads, total_threads);
172
173 trace!("IO Threads: {io_threads}");
174 remaining_threads = remaining_threads.saturating_sub(io_threads);
175
176 IoTaskPool::get_or_init(|| {
177 let builder = TaskPoolBuilder::default()
178 .num_threads(io_threads)
179 .thread_name("IO Task Pool".to_string());
180
181 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
182 let builder = {
183 let mut builder = builder;
184 if let Some(f) = self.io.on_thread_spawn.clone() {
185 builder = builder.on_thread_spawn(move || f());
186 }
187 if let Some(f) = self.io.on_thread_destroy.clone() {
188 builder = builder.on_thread_destroy(move || f());
189 }
190 builder
191 };
192
193 builder.build()
194 });
195 }
196
197 {
198 let async_compute_threads = self
200 .async_compute
201 .get_number_of_threads(remaining_threads, total_threads);
202
203 trace!("Async Compute Threads: {async_compute_threads}");
204 remaining_threads = remaining_threads.saturating_sub(async_compute_threads);
205
206 AsyncComputeTaskPool::get_or_init(|| {
207 let builder = TaskPoolBuilder::default()
208 .num_threads(async_compute_threads)
209 .thread_name("Async Compute Task Pool".to_string());
210
211 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
212 let builder = {
213 let mut builder = builder;
214 if let Some(f) = self.async_compute.on_thread_spawn.clone() {
215 builder = builder.on_thread_spawn(move || f());
216 }
217 if let Some(f) = self.async_compute.on_thread_destroy.clone() {
218 builder = builder.on_thread_destroy(move || f());
219 }
220 builder
221 };
222
223 builder.build()
224 });
225 }
226
227 {
228 let compute_threads = self
231 .compute
232 .get_number_of_threads(remaining_threads, total_threads);
233
234 trace!("Compute Threads: {compute_threads}");
235
236 ComputeTaskPool::get_or_init(|| {
237 let builder = TaskPoolBuilder::default()
238 .num_threads(compute_threads)
239 .thread_name("Compute Task Pool".to_string());
240
241 #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
242 let builder = {
243 let mut builder = builder;
244 if let Some(f) = self.compute.on_thread_spawn.clone() {
245 builder = builder.on_thread_spawn(move || f());
246 }
247 if let Some(f) = self.compute.on_thread_destroy.clone() {
248 builder = builder.on_thread_destroy(move || f());
249 }
250 builder
251 };
252
253 builder.build()
254 });
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};

    /// Queues one `spawn_local` task on each global pool set up by `TaskPoolPlugin`. After
    /// `App::run` returns, checks that every task has sent its message.
    #[test]
    fn runs_spawn_local_tasks() {
        let mut app = App::new();
        app.add_plugins(TaskPoolPlugin::default());

        let (async_sender, async_receiver) = crossbeam_channel::unbounded();
        let async_task = AsyncComputeTaskPool::get().spawn_local(async move {
            async_sender.send(()).unwrap();
        });
        async_task.detach();

        let (compute_sender, compute_receiver) = crossbeam_channel::unbounded();
        let compute_task = ComputeTaskPool::get().spawn_local(async move {
            compute_sender.send(()).unwrap();
        });
        compute_task.detach();

        let (io_sender, io_receiver) = crossbeam_channel::unbounded();
        let io_task = IoTaskPool::get().spawn_local(async move {
            io_sender.send(()).unwrap();
        });
        io_task.detach();

        app.run();

        async_receiver.try_recv().unwrap();
        compute_receiver.try_recv().unwrap();
        io_receiver.try_recv().unwrap();
    }
}