1use alloc::sync::Arc;
2
3use crate::{
4    binding_model::BindGroup,
5    id,
6    pipeline::ComputePipeline,
7    resource::{Buffer, QuerySet},
8};
9
/// A single command recorded for a compute pass, referring to resources by id.
///
/// This is the id-based counterpart of [`ArcComputeCommand`]. When the `serde`
/// or `replay` feature is enabled, `ComputeCommand::resolve_compute_command_ids`
/// turns a slice of these into `ArcComputeCommand`s by looking each id up in a
/// hub.
///
/// The type is `Copy` because every variant holds only ids and plain values.
/// With the `serde` feature it is (de)serializable through derived impls.
/// Renaming or reordering variants or fields may therefore change the
/// serialized representation.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
    /// Set the bind group at slot `index`.
    ///
    /// NOTE(review): a `None` id presumably means "unbind this slot". Confirm
    /// against the pass executor.
    SetBindGroup {
        /// Bind group slot index.
        index: u32,
        /// Number of dynamic offsets associated with this command.
        ///
        /// NOTE(review): the offsets themselves are not stored in this variant.
        /// Presumably they live in storage owned by the pass. Confirm.
        num_dynamic_offsets: usize,
        /// Id of the bind group to set. `None` is resolved to
        /// `ArcComputeCommand::SetBindGroup { bind_group: None, .. }` without
        /// any lookup.
        bind_group_id: Option<id::BindGroupId>,
    },

    /// Set the compute pipeline identified by this id.
    SetPipeline(id::ComputePipelineId),

    /// Write push constant data.
    ///
    /// The payload bytes are not stored in this variant. Only their location
    /// and size are recorded.
    SetPushConstant {
        // NOTE(review): presumably the byte offset within push constant
        // storage that the write starts at. Confirm.
        offset: u32,

        // Size of the write in bytes (per the field name).
        size_bytes: u32,

        // NOTE(review): presumably the position of the first value in the
        // pass's push constant data. The unit (bytes vs. `u32` elements) is
        // not visible here. TODO confirm.
        values_offset: u32,
    },

    /// Direct dispatch.
    ///
    /// NOTE(review): the three values are presumably the x/y/z workgroup
    /// counts. Confirm.
    Dispatch([u32; 3]),

    /// Indirect dispatch whose arguments are read from a buffer.
    DispatchIndirect {
        /// Buffer holding the indirect arguments.
        buffer_id: id::BufferId,
        /// Location of the arguments within `buffer_id`. This is a
        /// `wgt::BufferAddress`, presumably a byte offset.
        offset: wgt::BufferAddress,
    },

    /// Open a debug group.
    PushDebugGroup {
        /// Debug color value.
        color: u32,
        /// NOTE(review): presumably the length of the group's label, which is
        /// stored outside this command. Confirm.
        len: usize,
    },

    /// Close the most recently opened debug group.
    PopDebugGroup,

    /// Insert a single debug marker.
    InsertDebugMarker {
        /// Debug color value.
        color: u32,
        /// NOTE(review): presumably the length of the marker's label, which is
        /// stored outside this command. Confirm.
        len: usize,
    },

    /// Write a timestamp into query `query_index` of the given query set.
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },

    /// Begin a pipeline statistics query at `query_index` of the given query
    /// set.
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },

    /// End a pipeline statistics query. This variant carries no operands.
    EndPipelineStatisticsQuery,
}
69
70impl ComputeCommand {
71    #[cfg(any(feature = "serde", feature = "replay"))]
73    pub fn resolve_compute_command_ids(
74        hub: &crate::hub::Hub,
75        commands: &[ComputeCommand],
76    ) -> Result<alloc::vec::Vec<ArcComputeCommand>, super::ComputePassError> {
77        use super::{ComputePassError, PassErrorScope};
78        use alloc::vec::Vec;
79
80        let buffers_guard = hub.buffers.read();
81        let bind_group_guard = hub.bind_groups.read();
82        let query_set_guard = hub.query_sets.read();
83        let pipelines_guard = hub.compute_pipelines.read();
84
85        let resolved_commands: Vec<ArcComputeCommand> = commands
86            .iter()
87            .map(|c| -> Result<ArcComputeCommand, ComputePassError> {
88                Ok(match *c {
89                    ComputeCommand::SetBindGroup {
90                        index,
91                        num_dynamic_offsets,
92                        bind_group_id,
93                    } => {
94                        if bind_group_id.is_none() {
95                            return Ok(ArcComputeCommand::SetBindGroup {
96                                index,
97                                num_dynamic_offsets,
98                                bind_group: None,
99                            });
100                        }
101
102                        let bind_group_id = bind_group_id.unwrap();
103                        let bg = bind_group_guard.get(bind_group_id).get().map_err(|e| {
104                            ComputePassError {
105                                scope: PassErrorScope::SetBindGroup,
106                                inner: e.into(),
107                            }
108                        })?;
109
110                        ArcComputeCommand::SetBindGroup {
111                            index,
112                            num_dynamic_offsets,
113                            bind_group: Some(bg),
114                        }
115                    }
116                    ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
117                        pipelines_guard
118                            .get(pipeline_id)
119                            .get()
120                            .map_err(|e| ComputePassError {
121                                scope: PassErrorScope::SetPipelineCompute,
122                                inner: e.into(),
123                            })?,
124                    ),
125
126                    ComputeCommand::SetPushConstant {
127                        offset,
128                        size_bytes,
129                        values_offset,
130                    } => ArcComputeCommand::SetPushConstant {
131                        offset,
132                        size_bytes,
133                        values_offset,
134                    },
135
136                    ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
137
138                    ComputeCommand::DispatchIndirect { buffer_id, offset } => {
139                        ArcComputeCommand::DispatchIndirect {
140                            buffer: buffers_guard.get(buffer_id).get().map_err(|e| {
141                                ComputePassError {
142                                    scope: PassErrorScope::Dispatch { indirect: true },
143                                    inner: e.into(),
144                                }
145                            })?,
146                            offset,
147                        }
148                    }
149
150                    ComputeCommand::PushDebugGroup { color, len } => {
151                        ArcComputeCommand::PushDebugGroup { color, len }
152                    }
153
154                    ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
155
156                    ComputeCommand::InsertDebugMarker { color, len } => {
157                        ArcComputeCommand::InsertDebugMarker { color, len }
158                    }
159
160                    ComputeCommand::WriteTimestamp {
161                        query_set_id,
162                        query_index,
163                    } => ArcComputeCommand::WriteTimestamp {
164                        query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
165                            ComputePassError {
166                                scope: PassErrorScope::WriteTimestamp,
167                                inner: e.into(),
168                            }
169                        })?,
170                        query_index,
171                    },
172
173                    ComputeCommand::BeginPipelineStatisticsQuery {
174                        query_set_id,
175                        query_index,
176                    } => ArcComputeCommand::BeginPipelineStatisticsQuery {
177                        query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
178                            ComputePassError {
179                                scope: PassErrorScope::BeginPipelineStatisticsQuery,
180                                inner: e.into(),
181                            }
182                        })?,
183                        query_index,
184                    },
185
186                    ComputeCommand::EndPipelineStatisticsQuery => {
187                        ArcComputeCommand::EndPipelineStatisticsQuery
188                    }
189                })
190            })
191            .collect::<Result<Vec<_>, ComputePassError>>()?;
192        Ok(resolved_commands)
193    }
194}
195
/// A single compute pass command whose resources are held by `Arc` rather
/// than referenced by id.
///
/// This is the resolved counterpart of `ComputeCommand`. Every variant mirrors
/// the `ComputeCommand` variant of the same name, with each id replaced by an
/// `Arc` of the resource it named. Cloning a command clones those `Arc`s.
#[derive(Clone, Debug)]
pub enum ArcComputeCommand {
    /// Set the bind group at slot `index`. `bind_group: None` corresponds to a
    /// `ComputeCommand::SetBindGroup` with no bind group id.
    SetBindGroup {
        /// Bind group slot index.
        index: u32,
        /// Number of dynamic offsets associated with this command. The offsets
        /// themselves are not stored in this variant.
        num_dynamic_offsets: usize,
        bind_group: Option<Arc<BindGroup>>,
    },

    /// Set the compute pipeline.
    SetPipeline(Arc<ComputePipeline>),

    /// Write push constant data. The payload bytes are not stored in this
    /// variant, only their location and size.
    SetPushConstant {
        // NOTE(review): presumably the byte offset within push constant
        // storage that the write starts at. Confirm.
        offset: u32,

        // Size of the write in bytes (per the field name).
        size_bytes: u32,

        // NOTE(review): presumably the position of the first value in the
        // pass's push constant data. The unit (bytes vs. `u32` elements) is
        // not visible here. TODO confirm.
        values_offset: u32,
    },

    /// Direct dispatch.
    ///
    /// NOTE(review): the three values are presumably the x/y/z workgroup
    /// counts. Confirm.
    Dispatch([u32; 3]),

    /// Indirect dispatch whose arguments are read from `buffer` at `offset`.
    DispatchIndirect {
        buffer: Arc<Buffer>,
        offset: wgt::BufferAddress,
    },

    /// Open a debug group.
    PushDebugGroup {
        // Allowed to be dead when neither `serde` nor `replay` is enabled.
        // The field is apparently not read in such builds.
        #[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))]
        color: u32,
        // NOTE(review): presumably the length of the group's label, which is
        // stored outside this command. Confirm.
        len: usize,
    },

    /// Close the most recently opened debug group.
    PopDebugGroup,

    /// Insert a single debug marker.
    InsertDebugMarker {
        // Allowed to be dead when neither `serde` nor `replay` is enabled.
        // The field is apparently not read in such builds.
        #[cfg_attr(not(any(feature = "serde", feature = "replay")), allow(dead_code))]
        color: u32,
        // NOTE(review): presumably the length of the marker's label, which is
        // stored outside this command. Confirm.
        len: usize,
    },

    /// Write a timestamp into query `query_index` of `query_set`.
    WriteTimestamp {
        query_set: Arc<QuerySet>,
        query_index: u32,
    },

    /// Begin a pipeline statistics query at `query_index` of `query_set`.
    BeginPipelineStatisticsQuery {
        query_set: Arc<QuerySet>,
        query_index: u32,
    },

    /// End a pipeline statistics query. This variant carries no operands.
    EndPipelineStatisticsQuery,
}