1use bevy_ecs::{prelude::Entity, world::World};
2use bevy_platform::collections::HashMap;
3#[cfg(feature = "trace")]
4use tracing::info_span;
5
6use alloc::{borrow::Cow, collections::VecDeque};
7use smallvec::{smallvec, SmallVec};
8use thiserror::Error;
9
10use crate::{
11 diagnostic::internal::{DiagnosticsRecorder, RenderDiagnosticsMutex},
12 render_graph::{
13 Edge, InternedRenderLabel, InternedRenderSubGraph, NodeRunError, NodeState, RenderGraph,
14 RenderGraphContext, SlotLabel, SlotType, SlotValue,
15 },
16 renderer::{RenderContext, RenderDevice},
17};
18
/// Executes a [`RenderGraph`]: runs its nodes (and any sub graphs those nodes queue) once their
/// dependencies have run, then submits the recorded commands to the `wgpu` queue.
pub(crate) struct RenderGraphRunner;
31
32#[derive(Error, Debug)]
33pub enum RenderGraphRunnerError {
34 #[error(transparent)]
35 NodeRunError(#[from] NodeRunError),
36 #[error("node output slot not set (index {slot_index}, name {slot_name})")]
37 EmptyNodeOutputSlot {
38 type_name: &'static str,
39 slot_index: usize,
40 slot_name: Cow<'static, str>,
41 },
42 #[error("graph '{sub_graph:?}' could not be run because slot '{slot_name}' at index {slot_index} has no value")]
43 MissingInput {
44 slot_index: usize,
45 slot_name: Cow<'static, str>,
46 sub_graph: Option<InternedRenderSubGraph>,
47 },
48 #[error("attempted to use the wrong type for input slot")]
49 MismatchedInputSlotType {
50 slot_index: usize,
51 label: SlotLabel,
52 expected: SlotType,
53 actual: SlotType,
54 },
55 #[error(
56 "node (name: '{node_name:?}') has {slot_count} input slots, but was provided {value_count} values"
57 )]
58 MismatchedInputCount {
59 node_name: InternedRenderLabel,
60 slot_count: usize,
61 value_count: usize,
62 },
63}
64
65impl RenderGraphRunner {
66 pub fn run(
67 graph: &RenderGraph,
68 render_device: RenderDevice,
69 mut diagnostics_recorder: Option<DiagnosticsRecorder>,
70 queue: &wgpu::Queue,
71 world: &World,
72 finalizer: impl FnOnce(&mut wgpu::CommandEncoder),
73 ) -> Result<Option<DiagnosticsRecorder>, RenderGraphRunnerError> {
74 if let Some(recorder) = &mut diagnostics_recorder {
75 recorder.begin_frame();
76 }
77
78 let mut render_context = RenderContext::new(render_device, diagnostics_recorder);
79 Self::run_graph(graph, None, &mut render_context, world, &[], None, None)?;
80 finalizer(render_context.command_encoder());
81
82 let (render_device, mut diagnostics_recorder) = {
83 let (commands, render_device, diagnostics_recorder) = render_context.finish();
84
85 #[cfg(feature = "trace")]
86 let _span = info_span!("submit_graph_commands").entered();
87 queue.submit(commands);
88
89 (render_device, diagnostics_recorder)
90 };
91
92 if let Some(recorder) = &mut diagnostics_recorder {
93 let render_diagnostics_mutex = world.resource::<RenderDiagnosticsMutex>().0.clone();
94 recorder.finish_frame(&render_device, move |diagnostics| {
95 *render_diagnostics_mutex.lock().expect("lock poisoned") = Some(diagnostics);
96 });
97 }
98
99 Ok(diagnostics_recorder)
100 }
101
102 fn run_graph<'w>(
105 graph: &RenderGraph,
106 sub_graph: Option<InternedRenderSubGraph>,
107 render_context: &mut RenderContext<'w>,
108 world: &'w World,
109 inputs: &[SlotValue],
110 view_entity: Option<Entity>,
111 debug_group: Option<String>,
112 ) -> Result<(), RenderGraphRunnerError> {
113 let mut node_outputs: HashMap<InternedRenderLabel, SmallVec<[SlotValue; 4]>> =
114 HashMap::default();
115 #[cfg(feature = "trace")]
116 let span = if let Some(render_label) = &sub_graph {
117 let name = format!("{render_label:?}");
118 if let Some(debug_group) = debug_group.as_ref() {
119 info_span!("run_graph", name = name, debug_group = debug_group)
120 } else {
121 info_span!("run_graph", name = name)
122 }
123 } else {
124 info_span!("run_graph", name = "main_graph")
125 };
126 #[cfg(feature = "trace")]
127 let _guard = span.enter();
128
129 if let Some(debug_group) = debug_group.as_ref() {
130 render_context
137 .command_encoder()
138 .insert_debug_marker(&format!("Start {debug_group}"));
139 }
140
141 let mut node_queue: VecDeque<&NodeState> = graph
143 .iter_nodes()
144 .filter(|node| node.input_slots.is_empty())
145 .collect();
146
147 if let Some(input_node) = graph.get_input_node() {
149 let mut input_values: SmallVec<[SlotValue; 4]> = SmallVec::new();
150 for (i, input_slot) in input_node.input_slots.iter().enumerate() {
151 if let Some(input_value) = inputs.get(i) {
152 if input_slot.slot_type != input_value.slot_type() {
153 return Err(RenderGraphRunnerError::MismatchedInputSlotType {
154 slot_index: i,
155 actual: input_value.slot_type(),
156 expected: input_slot.slot_type,
157 label: input_slot.name.clone().into(),
158 });
159 }
160 input_values.push(input_value.clone());
161 } else {
162 return Err(RenderGraphRunnerError::MissingInput {
163 slot_index: i,
164 slot_name: input_slot.name.clone(),
165 sub_graph,
166 });
167 }
168 }
169
170 node_outputs.insert(input_node.label, input_values);
171
172 for (_, node_state) in graph
173 .iter_node_outputs(input_node.label)
174 .expect("node exists")
175 {
176 node_queue.push_front(node_state);
177 }
178 }
179
180 'handle_node: while let Some(node_state) = node_queue.pop_back() {
181 if node_outputs.contains_key(&node_state.label) {
183 continue;
184 }
185
186 let mut slot_indices_and_inputs: SmallVec<[(usize, SlotValue); 4]> = SmallVec::new();
187 for (edge, input_node) in graph
189 .iter_node_inputs(node_state.label)
190 .expect("node is in graph")
191 {
192 match edge {
193 Edge::SlotEdge {
194 output_index,
195 input_index,
196 ..
197 } => {
198 if let Some(outputs) = node_outputs.get(&input_node.label) {
199 slot_indices_and_inputs
200 .push((*input_index, outputs[*output_index].clone()));
201 } else {
202 node_queue.push_front(node_state);
203 continue 'handle_node;
204 }
205 }
206 Edge::NodeEdge { .. } => {
207 if !node_outputs.contains_key(&input_node.label) {
208 node_queue.push_front(node_state);
209 continue 'handle_node;
210 }
211 }
212 }
213 }
214
215 slot_indices_and_inputs.sort_by_key(|(index, _)| *index);
217 let inputs: SmallVec<[SlotValue; 4]> = slot_indices_and_inputs
218 .into_iter()
219 .map(|(_, value)| value)
220 .collect();
221
222 if inputs.len() != node_state.input_slots.len() {
223 return Err(RenderGraphRunnerError::MismatchedInputCount {
224 node_name: node_state.label,
225 slot_count: node_state.input_slots.len(),
226 value_count: inputs.len(),
227 });
228 }
229
230 let mut outputs: SmallVec<[Option<SlotValue>; 4]> =
231 smallvec![None; node_state.output_slots.len()];
232 {
233 let mut context = RenderGraphContext::new(graph, node_state, &inputs, &mut outputs);
234 if let Some(view_entity) = view_entity {
235 context.set_view_entity(view_entity);
236 }
237
238 {
239 #[cfg(feature = "trace")]
240 let _span = info_span!("node", name = node_state.type_name).entered();
241
242 node_state.node.run(&mut context, render_context, world)?;
243 }
244
245 for run_sub_graph in context.finish() {
246 let sub_graph = graph
247 .get_sub_graph(run_sub_graph.sub_graph)
248 .expect("sub graph exists because it was validated when queued.");
249 Self::run_graph(
250 sub_graph,
251 Some(run_sub_graph.sub_graph),
252 render_context,
253 world,
254 &run_sub_graph.inputs,
255 run_sub_graph.view_entity,
256 run_sub_graph.debug_group,
257 )?;
258 }
259 }
260
261 let mut values: SmallVec<[SlotValue; 4]> = SmallVec::new();
262 for (i, output) in outputs.into_iter().enumerate() {
263 if let Some(value) = output {
264 values.push(value);
265 } else {
266 let empty_slot = node_state.output_slots.get_slot(i).unwrap();
267 return Err(RenderGraphRunnerError::EmptyNodeOutputSlot {
268 type_name: node_state.type_name,
269 slot_index: i,
270 slot_name: empty_slot.name.clone(),
271 });
272 }
273 }
274 node_outputs.insert(node_state.label, values);
275
276 for (_, node_state) in graph
277 .iter_node_outputs(node_state.label)
278 .expect("node exists")
279 {
280 node_queue.push_front(node_state);
281 }
282 }
283
284 if let Some(debug_group) = debug_group {
285 render_context
286 .command_encoder()
287 .insert_debug_marker(&format!("End {debug_group}"));
288 }
289
290 Ok(())
291 }
292}