Skip to main content

naga_oil/
redirect.rs

1use std::collections::{HashMap, HashSet};
2
3use naga::{Block, Expression, Function, Handle, Module, Statement};
4use thiserror::Error;
5
6use crate::derive::DerivedModule;
7
8#[derive(Debug, Error)]
9pub enum RedirectError {
10    #[error("can't find function {0} for redirection")]
11    FunctionNotFound(String),
12    #[error("{0} cannot override {1} due to argument mismatch")]
13    ArgumentMismatch(String, String),
14    #[error("{0} cannot override {1} due to return type mismatch")]
15    ReturnTypeMismatch(String, String),
16    #[error("circular reference; can't find an order for : {0}")]
17    CircularReference(String),
18}
19
20pub struct Redirector {
21    module: Module,
22}
23
24impl Redirector {
25    pub fn new(module: Module) -> Self {
26        Self { module }
27    }
28
29    fn redirect_block(block: &mut Block, original: Handle<Function>, new: Handle<Function>) {
30        for stmt in block.iter_mut() {
31            match stmt {
32                Statement::Call {
33                    ref mut function, ..
34                } => {
35                    if *function == original {
36                        *function = new;
37                    }
38                }
39                Statement::Block(b) => Self::redirect_block(b, original, new),
40                Statement::If {
41                    condition: _,
42                    accept,
43                    reject,
44                } => {
45                    Self::redirect_block(accept, original, new);
46                    Self::redirect_block(reject, original, new);
47                }
48                Statement::Switch { selector: _, cases } => {
49                    for case in cases.iter_mut() {
50                        Self::redirect_block(&mut case.body, original, new);
51                    }
52                }
53                Statement::Loop {
54                    body,
55                    continuing,
56                    break_if: _,
57                } => {
58                    Self::redirect_block(body, original, new);
59                    Self::redirect_block(continuing, original, new);
60                }
61                Statement::Emit(_)
62                | Statement::Break
63                | Statement::Continue
64                | Statement::Return { .. }
65                | Statement::WorkGroupUniformLoad { .. }
66                | Statement::Kill
67                | Statement::MemoryBarrier(_)
68                | Statement::ControlBarrier(_)
69                | Statement::Store { .. }
70                | Statement::ImageStore { .. }
71                | Statement::Atomic { .. }
72                | Statement::RayQuery { .. }
73                | Statement::SubgroupBallot { .. }
74                | Statement::SubgroupGather { .. }
75                | Statement::SubgroupCollectiveOperation { .. }
76                | Statement::ImageAtomic { .. }
77                | Statement::RayPipelineFunction(..)
78                | Statement::CooperativeStore { .. } => (),
79            }
80        }
81    }
82
83    fn redirect_expr(expr: &mut Expression, original: Handle<Function>, new: Handle<Function>) {
84        if let Expression::CallResult(f) = expr {
85            if f == &original {
86                *expr = Expression::CallResult(new);
87            }
88        }
89    }
90
91    fn redirect_fn(func: &mut Function, original: Handle<Function>, new: Handle<Function>) {
92        Self::redirect_block(&mut func.body, original, new);
93        for (_, expr) in func.expressions.iter_mut() {
94            Self::redirect_expr(expr, original, new);
95        }
96    }
97
98    /// redirect all calls to the function named `original` with references to the function named `replacement`, except within the replacement function
99    /// or in any function contained in the `omit` set.
100    /// returns handles to the original and replacement functions.
101    /// NB: requires the replacement to be defined in the arena before any calls to the original, or validation will fail.
102    pub fn redirect_function(
103        &mut self,
104        original: &str,
105        replacement: &str,
106        omit: &HashSet<String>,
107    ) -> Result<(Handle<Function>, Handle<Function>), RedirectError> {
108        let (h_original, f_original) = self
109            .module
110            .functions
111            .iter()
112            .find(|(_, f)| f.name.as_deref() == Some(original))
113            .ok_or_else(|| RedirectError::FunctionNotFound(original.to_owned()))?;
114        let (h_replacement, f_replacement) = self
115            .module
116            .functions
117            .iter()
118            .find(|(_, f)| f.name.as_deref() == Some(replacement))
119            .ok_or_else(|| RedirectError::FunctionNotFound(replacement.to_owned()))?;
120
121        for (arg1, arg2) in f_original
122            .arguments
123            .iter()
124            .zip(f_replacement.arguments.iter())
125        {
126            if arg1.ty != arg2.ty {
127                return Err(RedirectError::ArgumentMismatch(
128                    original.to_owned(),
129                    replacement.to_owned(),
130                ));
131            }
132        }
133
134        if f_original.result.as_ref().map(|r| r.ty) != f_replacement.result.as_ref().map(|r| r.ty) {
135            return Err(RedirectError::ReturnTypeMismatch(
136                original.to_owned(),
137                replacement.to_owned(),
138            ));
139        }
140
141        for (h_f, f) in self.module.functions.iter_mut() {
142            if h_f != h_replacement && !omit.contains(f.name.as_ref().unwrap()) {
143                Self::redirect_fn(f, h_original, h_replacement);
144            }
145        }
146
147        for ep in &mut self.module.entry_points {
148            Self::redirect_fn(&mut ep.function, h_original, h_replacement);
149        }
150
151        Ok((h_original, h_replacement))
152    }
153
154    fn gather_requirements(block: &Block) -> HashSet<Handle<Function>> {
155        let mut requirements = HashSet::default();
156
157        for stmt in block.iter() {
158            match stmt {
159                Statement::Block(b) => requirements.extend(Self::gather_requirements(b)),
160                Statement::If { accept, reject, .. } => {
161                    requirements.extend(Self::gather_requirements(accept));
162                    requirements.extend(Self::gather_requirements(reject));
163                }
164                Statement::Switch { cases, .. } => {
165                    for case in cases {
166                        requirements.extend(Self::gather_requirements(&case.body));
167                    }
168                }
169                Statement::Loop {
170                    body, continuing, ..
171                } => {
172                    requirements.extend(Self::gather_requirements(body));
173                    requirements.extend(Self::gather_requirements(continuing));
174                }
175                Statement::Call { function, .. } => {
176                    requirements.insert(*function);
177                }
178                _ => (),
179            }
180        }
181
182        requirements
183    }
184
185    pub fn into_module(self) -> Result<naga::Module, RedirectError> {
186        // reorder functions so that dependents come first
187        let mut requirements: HashMap<_, _> = self
188            .module
189            .functions
190            .iter()
191            .map(|(h_f, f)| (h_f, Self::gather_requirements(&f.body)))
192            .collect();
193
194        let mut derived = DerivedModule::default();
195        derived.set_shader_source(&self.module, 0);
196
197        while !requirements.is_empty() {
198            let start_len = requirements.len();
199
200            let mut added: HashSet<Handle<Function>> = HashSet::new();
201
202            // add anything that has all requirements satisfied
203            requirements.retain(|h_f, reqs| {
204                if reqs.is_empty() {
205                    let func = self.module.functions.try_get(*h_f).unwrap();
206                    let span = self.module.functions.get_span(*h_f);
207                    derived.import_function(func, span);
208                    added.insert(*h_f);
209                    false
210                } else {
211                    true
212                }
213            });
214
215            // remove things we added from requirements
216            for reqs in requirements.values_mut() {
217                reqs.retain(|req| !added.contains(req));
218            }
219
220            if requirements.len() == start_len {
221                return Err(RedirectError::CircularReference(format!(
222                    "{:#?}",
223                    requirements.keys()
224                )));
225            }
226        }
227
228        Ok(derived.into_module_with_entrypoints())
229    }
230}
231
232impl TryFrom<Redirector> for naga::Module {
233    type Error = RedirectError;
234
235    fn try_from(redirector: Redirector) -> Result<Self, Self::Error> {
236        redirector.into_module()
237    }
238}