1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10 #[error("can't find function {0} for redirection")]
11 FunctionNotFound(String),
12 #[error("{0} cannot override {1} due to argument mismatch")]
13 ArgumentMismatch(String, String),
14 #[error("{0} cannot override {1} due to return type mismatch")]
15 ReturnTypeMismatch(String, String),
16 #[error("circular reference; can't find an order for : {0}")]
17 CircularReference(String),
18}
19
20pub struct Redirector {
21 module: Module,
22}
23
24impl Redirector {
25 pub fn new(module: Module) -> Self {
26 Self { module }
27 }
28
29 fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30 for stmt in block.iter_mut() {
31 match stmt {
32 Statement::Call {
33 ref mut function, ..
34 } => {
35 if *function == original {
36 *function = new;
37 }
38 }
39 Statement::Block(b) => Self::redirect_block(b, original, new),
40 Statement::If {
41 condition: _,
42 accept,
43 reject,
44 } => {
45 Self::redirect_block(accept, original, new);
46 Self::redirect_block(reject, original, new);
47 }
48 Statement::Switch { selector: _, cases } => {
49 for case in cases.iter_mut() {
50 Self::redirect_block(&mut case.body, original, new);
51 }
52 }
53 Statement::Loop {
54 body,
55 continuing,
56 break_if: _,
57 } => {
58 Self::redirect_block(body, original, new);
59 Self::redirect_block(continuing, original, new);
60 }
61 Statement::Emit(_)
62 | Statement::Break
63 | Statement::Continue
64 | Statement::Return { .. }
65 | Statement::WorkGroupUniformLoad { .. }
66 | Statement::Kill
67 | Statement::MemoryBarrier(_)
68 | Statement::ControlBarrier(_)
69 | Statement::Store { .. }
70 | Statement::ImageStore { .. }
71 | Statement::Atomic { .. }
72 | Statement::RayQuery { .. }
73 | Statement::SubgroupBallot { .. }
74 | Statement::SubgroupGather { .. }
75 | Statement::SubgroupCollectiveOperation { .. }
76 | Statement::ImageAtomic { .. }
77 | Statement::RayPipelineFunction(..)
78 | Statement::CooperativeStore { .. } => (),
79 }
80 }
81 }
82
83 fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
84 if let Expression::CallResult(f) = expr {
85 if f == &original {
86 *expr = Expression::CallResult(new);
87 }
88 }
89 }
90
91 fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
92 Self::redirect_block(&mut func.body, original, new);
93 for (_, expr) in func.expressions.iter_mut() {
94 Self::redirect_expr(expr, original, new);
95 }
96 }
97
98 pub fn redirect_function(
103 &mut self,
104 original: &str,
105 replacement: &str,
106 omit: &HashSet<String>,
107 ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
108 let (h_original, f_original) = self
109 .module
110 .functions
111 .iter()
112 .find(|(_, f)| f.name.as_deref() == Some(original))
113 .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
114 let (h_replacement, f_replacement) = self
115 .module
116 .functions
117 .iter()
118 .find(|(_, f)| f.name.as_deref() == Some(replacement))
119 .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
120
121 for (arg1, arg2) in f_original
122 .arguments
123 .iter()
124 .zip(f_replacement.arguments.iter())
125 {
126 if arg1.ty != arg2.ty {
127 return Err(RedirectError::ArgumentMismatch(
128 original.to_owned(),
129 replacement.to_owned(),
130 ));
131 }
132 }
133
134 if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
135 return Err(RedirectError::ReturnTypeMismatch(
136 original.to_owned(),
137 replacement.to_owned(),
138 ));
139 }
140
141 for (h_f, f) in self.module.functions.iter_mut() {
142 if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
143 Self::redirect_fn(f, h_original, h_replacement);
144 }
145 }
146
147 for ep in &mut self.module.entry_points {
148 Self::redirect_fn(&mut ep.function, h_original, h_replacement);
149 }
150
151 Ok((h_original, h_replacement))
152 }
153
154 fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
155 let mut requirements = HashSet::default();
156
157 for stmt in block.iter() {
158 match stmt {
159 Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
160 Statement::If { accept, reject, .. } => {
161 requirements.extend(Self::gather_requirements(accept));
162 requirements.extend(Self::gather_requirements(reject));
163 }
164 Statement::Switch { cases, .. } => {
165 for case in cases {
166 requirements.extend(Self::gather_requirements(&case.body));
167 }
168 }
169 Statement::Loop {
170 body, continuing, ..
171 } => {
172 requirements.extend(Self::gather_requirements(body));
173 requirements.extend(Self::gather_requirements(continuing));
174 }
175 Statement::Call { function, .. } => {
176 requirements.insert(*function);
177 }
178 _ => (),
179 }
180 }
181
182 requirements
183 }
184
185 pub fn into_module(self) -> Result<naga::Module, RedirectError> {
186 let mut requirements: HashMap<_, _> = self
188 .module
189 .functions
190 .iter()
191 .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
192 .collect();
193
194 let mut derived = DerivedModule::default();
195 derived.set_shader_source(&self.module, 0);
196
197 while !requirements.is_empty() {
198 let start_len = requirements.len();
199
200 let mut added: HashSet<Handle<Function>> = HashSet::new();
201
202 requirements.retain(|h_f, reqs| {
204 if reqs.is_empty() {
205 let func = self.module.functions.try_get(*h_f).unwrap();
206 let span = self.module.functions.get_span(*h_f);
207 derived.import_function(func, span);
208 added.insert(*h_f);
209 false
210 } else {
211 true
212 }
213 });
214
215 for reqs in requirements.values_mut() {
217 reqs.retain(|req| !added.contains(req));
218 }
219
220 if requirements.len() == start_len {
221 return Err(RedirectError::CircularReference(format!(
222 "{:#?}",
223 requirements.keys()
224 )));
225 }
226 }
227
228 Ok(derived.into_module_with_entrypoints())
229 }
230}
231
232impl TryFrom<Redirector> for naga::Module {
233 type Error = RedirectError;
234
235 fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
236 redirector.into_module()
237 }
238}