Skip to main content

bevy_render/
gpu_readback.rs

1use crate::{
2    extract_component::ExtractComponentPlugin,
3    render_asset::RenderAssets,
4    render_resource::{
5        Buffer, BufferUsages, CommandEncoder, Extent3d, TexelCopyBufferLayout, Texture,
6        TextureFormat,
7    },
8    renderer::RenderDevice,
9    storage::{GpuShaderBuffer, ShaderBuffer},
10    sync_world::MainEntity,
11    texture::GpuImage,
12    ExtractSchedule, MainWorld, Render, RenderApp, RenderSystems,
13};
14use async_channel::{Receiver, Sender};
15use bevy_app::{App, Plugin};
16use bevy_asset::Handle;
17use bevy_derive::{Deref, DerefMut};
18use bevy_ecs::{
19    change_detection::ResMut,
20    entity::Entity,
21    event::EntityEvent,
22    prelude::{Component, Resource, World},
23    system::{Query, Res},
24};
25use bevy_ecs::{schedule::IntoScheduleConfigs, template::FromTemplate};
26use bevy_image::{Image, TextureFormatPixelInfo};
27use bevy_log::warn;
28use bevy_platform::collections::HashMap;
29use bevy_reflect::Reflect;
30use bevy_render_macros::ExtractComponent;
31use encase::internal::ReadFrom;
32use encase::private::Reader;
33use encase::ShaderType;
34
/// A plugin that enables reading back gpu buffers and textures to the cpu.
///
/// Registers extraction of the [`Readback`] component and, in the render sub-app, the systems that
/// prepare staging buffers, map them, and deliver [`ReadbackComplete`] events back to the main world.
pub struct GpuReadbackPlugin {
    /// Number of frames a pooled staging buffer may stay unused before it is removed from the pool.
    /// Keeping idle buffers around for a while avoids reallocating them every frame.
    /// Defaults to 10.
    max_unused_frames: usize,
}
41
42impl Default for GpuReadbackPlugin {
43    fn default() -> Self {
44        Self {
45            max_unused_frames: 10,
46        }
47    }
48}
49
50impl Plugin for GpuReadbackPlugin {
51    fn build(&self, app: &mut App) {
52        app.add_plugins(ExtractComponentPlugin::<Readback>::default());
53
54        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
55            render_app
56                .init_resource::<GpuReadbackBufferPool>()
57                .init_resource::<GpuReadbacks>()
58                .insert_resource(GpuReadbackMaxUnusedFrames(self.max_unused_frames))
59                .add_systems(ExtractSchedule, sync_readbacks.ambiguous_with_all())
60                .add_systems(
61                    Render,
62                    (
63                        prepare_buffers.in_set(RenderSystems::PrepareResources),
64                        // TODO: this should be in the graph somehow
65                        map_buffers.in_set(RenderSystems::Cleanup),
66                    ),
67                );
68        }
69    }
70}
71
/// A component that registers the wrapped handle for gpu readback, either a texture or a buffer.
///
/// Data is read asynchronously and delivered by triggering a [`ReadbackComplete`] event on the
/// entity once the copy has finished. As long as this component is not removed, the readback is
/// attempted again every frame.
#[derive(Component, ExtractComponent, Clone, Debug, FromTemplate)]
pub enum Readback {
    /// Read back the whole texture of an [`Image`].
    ///
    /// Nothing is read while the image has no GPU representation yet, or when its texture format
    /// has no fixed per-pixel size (`pixel_size()` fails). Each row of the delivered data is padded
    /// to a multiple of [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`] bytes.
    #[default]
    Texture(Handle<Image>),
    /// Read back all or part of a [`ShaderBuffer`].
    Buffer {
        /// The buffer to read from.
        buffer: Handle<ShaderBuffer>,
        /// Optional `(start_offset, size)` range in bytes; `None` reads the whole buffer.
        ///
        /// The render world panics if the range extends past the end of the buffer.
        start_offset_and_size: Option<(u64, u64)>,
    },
}
85
86impl Readback {
87    /// Create a readback component for a texture using the given handle.
88    pub fn texture(image: Handle<Image>) -> Self {
89        Self::Texture(image)
90    }
91
92    /// Create a readback component for a full buffer using the given handle.
93    pub fn buffer(buffer: Handle<ShaderBuffer>) -> Self {
94        Self::Buffer {
95            buffer,
96            start_offset_and_size: None,
97        }
98    }
99
100    /// Create a readback component for a buffer range using the given handle, a start offset in bytes
101    /// and a number of bytes to read.
102    pub fn buffer_range(buffer: Handle<ShaderBuffer>, start_offset: u64, size: u64) -> Self {
103        Self::Buffer {
104            buffer,
105            start_offset_and_size: Some((start_offset, size)),
106        }
107    }
108}
109
/// An event that is triggered when a gpu readback is complete.
///
/// The event contains the data as a `Vec<u8>`, which can be interpreted as the raw bytes of the
/// requested buffer or texture. It dereferences to that `Vec<u8>`.
///
/// For texture readbacks, each row is padded to a multiple of
/// [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`] bytes. The data can therefore be longer than
/// `width * height * pixel_size`.
#[derive(EntityEvent, Deref, DerefMut, Reflect, Debug)]
#[reflect(Debug)]
pub struct ReadbackComplete {
    /// The main-world entity carrying the [`Readback`] component this data belongs to.
    pub entity: Entity,
    /// The bytes copied back from the GPU.
    #[deref]
    pub data: Vec<u8>,
}
121
122impl ReadbackComplete {
123    /// Convert the raw bytes of the event to a shader type.
124    pub fn to_shader_type<T: ShaderType + ReadFrom + Default>(&self) -> T {
125        let mut val = T::default();
126        let mut reader = Reader::new::<T>(&self.data, 0).expect("Failed to create Reader");
127        T::read_from(&mut val, &mut reader);
128        val
129    }
130}
131
/// Render-world copy of [`GpuReadbackPlugin`]'s `max_unused_frames` setting.
/// [`sync_readbacks`] uses it when aging the staging-buffer pool.
#[derive(Resource)]
struct GpuReadbackMaxUnusedFrames(usize);
134
/// A pooled staging buffer together with its bookkeeping state.
struct GpuReadbackBuffer {
    /// The staging buffer that readback data is copied into.
    buffer: Buffer,
    /// Whether the buffer is currently handed out to an in-flight readback.
    taken: bool,
    /// Number of pool updates the buffer has spent free since it was last handed out.
    frames_unused: usize,
}
140
/// Pool of staging buffers that readback copies land in.
/// The buffers are reused across frames to avoid reallocating them.
#[derive(Resource, Default)]
struct GpuReadbackBufferPool {
    // Map of buffer size to the list of buffers of exactly that size. Each entry tracks whether
    // the buffer is currently taken and how many frames it has been unused for.
    // TODO: We could ideally write all readback data to one big buffer per frame; the assumption
    // here is that very few entities will actually be read back at once, and their size is
    // unlikely to change.
    buffers: HashMap<u64, Vec<GpuReadbackBuffer>>,
}
150
151impl GpuReadbackBufferPool {
152    fn get(&mut self, render_device: &RenderDevice, size: u64) -> Buffer {
153        let buffers = self.buffers.entry(size).or_default();
154
155        // find an untaken buffer for this size
156        if let Some(buf) = buffers.iter_mut().find(|x| !x.taken) {
157            buf.taken = true;
158            buf.frames_unused = 0;
159            return buf.buffer.clone();
160        }
161
162        let buffer = render_device.create_buffer(&wgpu::BufferDescriptor {
163            label: Some("Readback Buffer"),
164            size,
165            usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
166            mapped_at_creation: false,
167        });
168        buffers.push(GpuReadbackBuffer {
169            buffer: buffer.clone(),
170            taken: true,
171            frames_unused: 0,
172        });
173        buffer
174    }
175
176    // Returns the buffer to the pool so it can be used in a future frame
177    fn return_buffer(&mut self, buffer: &Buffer) {
178        let size = buffer.size();
179        let buffers = self
180            .buffers
181            .get_mut(&size)
182            .expect("Returned buffer of untracked size");
183        if let Some(buf) = buffers.iter_mut().find(|x| x.buffer.id() == buffer.id()) {
184            buf.taken = false;
185        } else {
186            warn!("Returned buffer that was not allocated");
187        }
188    }
189
190    fn update(&mut self, max_unused_frames: usize) {
191        for (_, buffers) in &mut self.buffers {
192            // Tick all the buffers
193            for buf in &mut *buffers {
194                if !buf.taken {
195                    buf.frames_unused += 1;
196                }
197            }
198
199            // Remove buffers that haven't been used for MAX_UNUSED_FRAMES
200            buffers.retain(|x| x.frames_unused < max_unused_frames);
201        }
202
203        // Remove empty buffer sizes
204        self.buffers.retain(|_, buffers| !buffers.is_empty());
205    }
206}
207
/// The GPU resource a [`GpuReadback`] copies its data from.
enum ReadbackSource {
    /// Copy out of a texture.
    Texture {
        /// The source texture.
        texture: Texture,
        /// Layout of the copied data inside the staging buffer (see [`layout_data`]).
        layout: TexelCopyBufferLayout,
        /// Extent of the copy, which is the full size of the texture.
        size: Extent3d,
    },
    /// Copy out of a GPU buffer.
    Buffer {
        /// The source buffer.
        buffer: Buffer,
        /// Optional `(start_offset, size)` byte range; `None` copies the whole buffer.
        start_offset_and_size: Option<(u64, u64)>,
    },
}
219
/// Render-world queues of readback requests.
#[derive(Resource, Default)]
struct GpuReadbacks {
    /// Requests prepared this frame. Their copy commands are recorded by
    /// [`submit_readback_commands`], and their staging buffers have not been mapped yet.
    requested: Vec<GpuReadback>,
    /// Requests whose staging buffer is being mapped. They stay here until their data arrives.
    mapped: Vec<GpuReadback>,
}
225
/// A single in-flight readback request.
struct GpuReadback {
    /// The main-world entity that carries the [`Readback`] component.
    pub entity: Entity,
    /// The texture or buffer to copy from.
    pub src: ReadbackSource,
    /// The pooled staging buffer that the data is copied into and later mapped.
    pub buffer: Buffer,
    /// Receives `(entity, staging buffer, bytes)` once the staging buffer has been mapped and read.
    pub rx: Receiver<(Entity, Buffer, Vec<u8>)>,
    /// Sending half of the channel; a clone of it is moved into the buffer-mapping callback.
    pub tx: Sender<(Entity, Buffer, Vec<u8>)>,
}
233
234fn sync_readbacks(
235    mut main_world: ResMut<MainWorld>,
236    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
237    mut readbacks: ResMut<GpuReadbacks>,
238    max_unused_frames: Res<GpuReadbackMaxUnusedFrames>,
239) {
240    readbacks.mapped.retain(|readback| {
241        if let Ok((entity, buffer, data)) = readback.rx.try_recv() {
242            main_world.trigger(ReadbackComplete { data, entity });
243            buffer_pool.return_buffer(&buffer);
244            false
245        } else {
246            true
247        }
248    });
249
250    buffer_pool.update(max_unused_frames.0);
251}
252
253fn prepare_buffers(
254    render_device: Res<RenderDevice>,
255    mut readbacks: ResMut<GpuReadbacks>,
256    mut buffer_pool: ResMut<GpuReadbackBufferPool>,
257    gpu_images: Res<RenderAssets<GpuImage>>,
258    ssbos: Res<RenderAssets<GpuShaderBuffer>>,
259    handles: Query<(&MainEntity, &Readback)>,
260) {
261    for (entity, readback) in handles.iter() {
262        match readback {
263            Readback::Texture(image) => {
264                if let Some(gpu_image) = gpu_images.get(image)
265                    && let Ok(pixel_size) = gpu_image.texture_descriptor.format.pixel_size()
266                {
267                    let layout = layout_data(
268                        gpu_image.texture_descriptor.size,
269                        gpu_image.texture_descriptor.format,
270                    );
271                    let buffer = buffer_pool.get(
272                        &render_device,
273                        get_aligned_size(gpu_image.texture_descriptor.size, pixel_size as u32)
274                            as u64,
275                    );
276                    let (tx, rx) = async_channel::bounded(1);
277                    readbacks.requested.push(GpuReadback {
278                        entity: entity.id(),
279                        src: ReadbackSource::Texture {
280                            texture: gpu_image.texture.clone(),
281                            layout,
282                            size: gpu_image.texture_descriptor.size,
283                        },
284                        buffer,
285                        rx,
286                        tx,
287                    });
288                }
289            }
290            Readback::Buffer {
291                buffer,
292                start_offset_and_size,
293            } => {
294                if let Some(ssbo) = ssbos.get(buffer) {
295                    let full_size = ssbo.buffer.size();
296                    let size = start_offset_and_size
297                        .map(|(start, size)| {
298                            let end = start + size;
299                            if end > full_size {
300                                panic!(
301                                    "Tried to read past the end of the buffer (start: {start}, \
302                                    size: {size}, buffer size: {full_size})."
303                                );
304                            }
305                            size
306                        })
307                        .unwrap_or(full_size);
308                    let buffer = buffer_pool.get(&render_device, size);
309                    let (tx, rx) = async_channel::bounded(1);
310                    readbacks.requested.push(GpuReadback {
311                        entity: entity.id(),
312                        src: ReadbackSource::Buffer {
313                            start_offset_and_size: *start_offset_and_size,
314                            buffer: ssbo.buffer.clone(),
315                        },
316                        buffer,
317                        rx,
318                        tx,
319                    });
320                }
321            }
322        }
323    }
324}
325
326pub(crate) fn submit_readback_commands(world: &World, command_encoder: &mut CommandEncoder) {
327    let readbacks = world.resource::<GpuReadbacks>();
328    for readback in &readbacks.requested {
329        match &readback.src {
330            ReadbackSource::Texture {
331                texture,
332                layout,
333                size,
334            } => {
335                command_encoder.copy_texture_to_buffer(
336                    texture.as_image_copy(),
337                    wgpu::TexelCopyBufferInfo {
338                        buffer: &readback.buffer,
339                        layout: *layout,
340                    },
341                    *size,
342                );
343            }
344            ReadbackSource::Buffer {
345                buffer,
346                start_offset_and_size,
347            } => {
348                let (src_start, size) = start_offset_and_size.unwrap_or((0, buffer.size()));
349                command_encoder.copy_buffer_to_buffer(buffer, src_start, &readback.buffer, 0, size);
350            }
351        }
352    }
353}
354
355/// Move requested readbacks to mapped readbacks after commands have been submitted in render system
356fn map_buffers(mut readbacks: ResMut<GpuReadbacks>) {
357    let requested = readbacks.requested.drain(..).collect::<Vec<GpuReadback>>();
358    for readback in requested {
359        let slice = readback.buffer.slice(..);
360        let entity = readback.entity;
361        let buffer = readback.buffer.clone();
362        let tx = readback.tx.clone();
363        slice.map_async(wgpu::MapMode::Read, move |res| {
364            res.expect("Failed to map buffer");
365            let buffer_slice = buffer.slice(..);
366            let data = buffer_slice.get_mapped_range();
367            let result = Vec::from(&*data);
368            drop(data);
369            buffer.unmap();
370            if let Err(e) = tx.try_send((entity, buffer, result)) {
371                warn!("Failed to send readback result: {}", e);
372            }
373        });
374        readbacks.mapped.push(readback);
375    }
376}
377
378// Utils
379
380/// Round up a given value to be a multiple of [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
381pub(crate) const fn align_byte_size(value: u32) -> u32 {
382    RenderDevice::align_copy_bytes_per_row(value as usize) as u32
383}
384
385/// Get the size of a image when the size of each row has been rounded up to [`wgpu::COPY_BYTES_PER_ROW_ALIGNMENT`].
386pub(crate) const fn get_aligned_size(extent: Extent3d, pixel_size: u32) -> u32 {
387    extent.height * align_byte_size(extent.width * pixel_size) * extent.depth_or_array_layers
388}
389
390/// Get a [`TexelCopyBufferLayout`] aligned such that the image can be copied into a buffer.
391pub(crate) fn layout_data(extent: Extent3d, format: TextureFormat) -> TexelCopyBufferLayout {
392    TexelCopyBufferLayout {
393        bytes_per_row: if extent.height > 1 || extent.depth_or_array_layers > 1 {
394            if let Ok(pixel_size) = format.pixel_size() {
395                // 1 = 1 row
396                Some(get_aligned_size(
397                    Extent3d {
398                        width: extent.width,
399                        height: 1,
400                        depth_or_array_layers: 1,
401                    },
402                    pixel_size as u32,
403                ))
404            } else {
405                None
406            }
407        } else {
408            None
409        },
410        rows_per_image: if extent.depth_or_array_layers > 1 {
411            let (_, block_dimension_y) = format.block_dimensions();
412            Some(extent.height / block_dimension_y)
413        } else {
414            None
415        },
416        offset: 0,
417    }
418}