bevy_app/task_pool_plugin.rs

use crate::{App, Plugin};

use alloc::string::ToString;
use bevy_platform::sync::Arc;
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use core::fmt::Debug;
use log::trace;

cfg_if::cfg_if! {
    if #[cfg(not(all(target_arch = "wasm32", feature = "web")))] {
        use {crate::Last, bevy_tasks::tick_global_task_pools_on_main_thread};
        use bevy_ecs::system::NonSendMarker;

        /// A system used to check and advance our task pools.
        ///
        /// Calls [`tick_global_task_pools_on_main_thread`],
        /// and uses [`NonSendMarker`] to ensure that this system runs on the main thread.
        fn tick_global_task_pools(_main_thread_marker: NonSendMarker) {
            tick_global_task_pools_on_main_thread();
        }
    }
}

/// Setup of default task pools: [`AsyncComputeTaskPool`], [`ComputeTaskPool`], [`IoTaskPool`].
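///
/// A minimal usage sketch, assuming `TaskPoolPlugin` and `TaskPoolOptions` are
/// re-exported from the crate root (the thread count is illustrative):
///
/// ```
/// use bevy_app::{App, TaskPoolPlugin, TaskPoolOptions};
///
/// // Configure all default pools from a fixed total budget of 4 threads.
/// App::new().add_plugins(TaskPoolPlugin {
///     task_pool_options: TaskPoolOptions::with_num_threads(4),
/// });
/// ```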
#[derive(Default)]
pub struct TaskPoolPlugin {
    /// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start.
    pub task_pool_options: TaskPoolOptions,
}

impl Plugin for TaskPoolPlugin {
    fn build(&self, _app: &mut App) {
        // Setup the default bevy task pools
        self.task_pool_options.create_default_pools();

        #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
        _app.add_systems(Last, tick_global_task_pools);
    }
}

/// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and the number of total cores.
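///
/// A construction sketch, assuming this type is re-exported from the crate root;
/// the numbers are illustrative and mirror the crate's default IO policy below:
///
/// ```
/// use bevy_app::TaskPoolThreadAssignmentPolicy;
///
/// let policy = TaskPoolThreadAssignmentPolicy {
///     min_threads: 1,
///     max_threads: 4,
///     percent: 0.25, // aim for a quarter of the total cores
///     on_thread_spawn: None,
///     on_thread_destroy: None,
/// };
/// assert!(policy.percent <= 1.0);
/// ```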
#[derive(Clone)]
pub struct TaskPoolThreadAssignmentPolicy {
    /// Force using at least this many threads
    pub min_threads: usize,
    /// Under no circumstance use more than this many threads for this pool
    pub max_threads: usize,
    /// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`.
    /// It is permitted to use 1.0 to try to use all remaining threads.
    pub percent: f32,
    /// Callback that is invoked once for every created thread as it starts.
    /// This configuration is ignored on the wasm platform.
    pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Callback that is invoked once for every created thread as it terminates.
    /// This configuration is ignored on the wasm platform.
    pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl Debug for TaskPoolThreadAssignmentPolicy {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("TaskPoolThreadAssignmentPolicy")
            .field("min_threads", &self.min_threads)
            .field("max_threads", &self.max_threads)
            .field("percent", &self.percent)
            .finish()
    }
}

impl TaskPoolThreadAssignmentPolicy {
    /// Determine the number of threads to use for this task pool
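    ///
    /// Worked example (the numbers are illustrative): with `total_threads = 8`,
    /// `remaining_threads = 8`, `percent = 0.25`, `min_threads = 1` and `max_threads = 4`,
    /// the proportion is `8 * 0.25 = 2.0`, which does not round up, fits within the
    /// 8 remaining threads, and already lies inside `[1, 4]`, so 2 threads are assigned
    /// and the caller has 6 left for later pools.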
    fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
        assert!(self.percent >= 0.0);
        let proportion = total_threads as f32 * self.percent;
        let mut desired = proportion as usize;

        // Equivalent to round() for positive floats without libm requirement for
        // no_std compatibility
        if proportion - desired as f32 >= 0.5 {
            desired += 1;
        }

        // Limit ourselves to the number of cores available
        desired = desired.min(remaining_threads);

        // Clamp by min_threads, max_threads. (This may result in us using more threads than are
        // available; this is intended. An example case where this might happen is a device with
        // <= 2 threads.)
        desired.clamp(self.min_threads, self.max_threads)
    }
}

/// Helper for configuring and creating the default task pools. For end users who want full control,
/// set up [`TaskPoolPlugin`].
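///
/// A customization sketch, assuming `TaskPoolOptions` is re-exported from the crate
/// root (the cap of 6 IO threads is illustrative):
///
/// ```
/// use bevy_app::TaskPoolOptions;
///
/// let mut options = TaskPoolOptions::default();
/// // Allow the IO pool to grow beyond the default cap of 4 threads.
/// options.io.max_threads = 6;
/// assert_eq!(options.io.max_threads, 6);
/// ```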
#[derive(Clone, Debug)]
pub struct TaskPoolOptions {
    /// If the number of physical cores is less than `min_total_threads`, force using
    /// `min_total_threads`
    pub min_total_threads: usize,
    /// If the number of physical cores is greater than `max_total_threads`, force using
    /// `max_total_threads`
    pub max_total_threads: usize,

    /// Used to determine the number of IO threads to allocate
    pub io: TaskPoolThreadAssignmentPolicy,
    /// Used to determine the number of async compute threads to allocate
    pub async_compute: TaskPoolThreadAssignmentPolicy,
    /// Used to determine the number of compute threads to allocate
    pub compute: TaskPoolThreadAssignmentPolicy,
}

impl Default for TaskPoolOptions {
    fn default() -> Self {
        TaskPoolOptions {
            // By default, use however many cores are available on the system
            min_total_threads: 1,
            max_total_threads: usize::MAX,

            // Use 25% of cores for IO, at least 1, no more than 4
            io: TaskPoolThreadAssignmentPolicy {
                min_threads: 1,
                max_threads: 4,
                percent: 0.25,
                on_thread_spawn: None,
                on_thread_destroy: None,
            },

            // Use 25% of cores for async compute, at least 1, no more than 4
            async_compute: TaskPoolThreadAssignmentPolicy {
                min_threads: 1,
                max_threads: 4,
                percent: 0.25,
                on_thread_spawn: None,
                on_thread_destroy: None,
            },

            // Use all remaining cores for compute (at least 1)
            compute: TaskPoolThreadAssignmentPolicy {
                min_threads: 1,
                max_threads: usize::MAX,
                percent: 1.0, // This 1.0 here means "whatever is left over"
                on_thread_spawn: None,
                on_thread_destroy: None,
            },
        }
    }
}
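
// Worked example of the defaults above (the core count is illustrative): on a 12-core
// machine the IO and async compute pools each request 25% of 12 = 3 threads (within
// their [1, 4] bounds), leaving 12 - 3 - 3 = 6 threads, all of which go to the compute
// pool because its `percent` is 1.0 and its `max_threads` is unbounded.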

impl TaskPoolOptions {
    /// Create a configuration that forces using the given number of threads.
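    ///
    /// A minimal sketch, assuming `TaskPoolOptions` is re-exported from the crate root:
    ///
    /// ```
    /// use bevy_app::TaskPoolOptions;
    ///
    /// // Pin the total thread budget for all default pools to exactly 4 threads.
    /// let options = TaskPoolOptions::with_num_threads(4);
    /// assert_eq!(options.min_total_threads, 4);
    /// assert_eq!(options.max_total_threads, 4);
    /// ```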
    pub fn with_num_threads(thread_count: usize) -> Self {
        TaskPoolOptions {
            min_total_threads: thread_count,
            max_total_threads: thread_count,
            ..Default::default()
        }
    }

    /// Initializes the global task pools ([`IoTaskPool`], [`AsyncComputeTaskPool`],
    /// [`ComputeTaskPool`]) based on the configured values.
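    ///
    /// A minimal sketch, assuming `TaskPoolOptions` is re-exported from the crate root.
    /// The pools are global `get_or_init` statics, so repeated calls reuse pools that
    /// were already initialized:
    ///
    /// ```
    /// use bevy_app::TaskPoolOptions;
    ///
    /// // Initialize IoTaskPool, AsyncComputeTaskPool and ComputeTaskPool with the defaults.
    /// TaskPoolOptions::default().create_default_pools();
    /// ```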
    pub fn create_default_pools(&self) {
        let total_threads = bevy_tasks::available_parallelism()
            .clamp(self.min_total_threads, self.max_total_threads);
        trace!("Assigning {total_threads} cores to default task pools");

        let mut remaining_threads = total_threads;

        {
            // Determine the number of IO threads we will use
            let io_threads = self
                .io
                .get_number_of_threads(remaining_threads, total_threads);

            trace!("IO Threads: {io_threads}");
            remaining_threads = remaining_threads.saturating_sub(io_threads);

            IoTaskPool::get_or_init(|| {
                let builder = TaskPoolBuilder::default()
                    .num_threads(io_threads)
                    .thread_name("IO Task Pool".to_string());

                #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
                let builder = {
                    let mut builder = builder;
                    if let Some(f) = self.io.on_thread_spawn.clone() {
                        builder = builder.on_thread_spawn(move || f());
                    }
                    if let Some(f) = self.io.on_thread_destroy.clone() {
                        builder = builder.on_thread_destroy(move || f());
                    }
                    builder
                };

                builder.build()
            });
        }

        {
            // Determine the number of async compute threads we will use
            let async_compute_threads = self
                .async_compute
                .get_number_of_threads(remaining_threads, total_threads);

            trace!("Async Compute Threads: {async_compute_threads}");
            remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

            AsyncComputeTaskPool::get_or_init(|| {
                let builder = TaskPoolBuilder::default()
                    .num_threads(async_compute_threads)
                    .thread_name("Async Compute Task Pool".to_string());

                #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
                let builder = {
                    let mut builder = builder;
                    if let Some(f) = self.async_compute.on_thread_spawn.clone() {
                        builder = builder.on_thread_spawn(move || f());
                    }
                    if let Some(f) = self.async_compute.on_thread_destroy.clone() {
                        builder = builder.on_thread_destroy(move || f());
                    }
                    builder
                };

                builder.build()
            });
        }

        {
            // Determine the number of compute threads we will use
            // This is intentionally last so that an end user can specify 1.0 as the percent
            let compute_threads = self
                .compute
                .get_number_of_threads(remaining_threads, total_threads);

            trace!("Compute Threads: {compute_threads}");

            ComputeTaskPool::get_or_init(|| {
                let builder = TaskPoolBuilder::default()
                    .num_threads(compute_threads)
                    .thread_name("Compute Task Pool".to_string());

                #[cfg(not(all(target_arch = "wasm32", feature = "web")))]
                let builder = {
                    let mut builder = builder;
                    if let Some(f) = self.compute.on_thread_spawn.clone() {
                        builder = builder.on_thread_spawn(move || f());
                    }
                    if let Some(f) = self.compute.on_thread_destroy.clone() {
                        builder = builder.on_thread_destroy(move || f());
                    }
                    builder
                };

                builder.build()
            });
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};

    #[test]
    fn runs_spawn_local_tasks() {
        let mut app = App::new();
        app.add_plugins(TaskPoolPlugin::default());

        let (async_tx, async_rx) = crossbeam_channel::unbounded();
        AsyncComputeTaskPool::get()
            .spawn_local(async move {
                async_tx.send(()).unwrap();
            })
            .detach();

        let (compute_tx, compute_rx) = crossbeam_channel::unbounded();
        ComputeTaskPool::get()
            .spawn_local(async move {
                compute_tx.send(()).unwrap();
            })
            .detach();

        let (io_tx, io_rx) = crossbeam_channel::unbounded();
        IoTaskPool::get()
            .spawn_local(async move {
                io_tx.send(()).unwrap();
            })
            .detach();

        app.run();

        async_rx.try_recv().unwrap();
        compute_rx.try_recv().unwrap();
        io_rx.try_recv().unwrap();
    }
}