Skip to main content

naga_oil/compose/
mod.rs

1use indexmap::IndexMap;
2/// the compose module allows construction of shaders from modules (which are themselves shaders).
3///
4/// it does this by treating shaders as modules, and
5/// - building each module independently to naga IR
6/// - creating "header" files for each supported language, which are used to build dependent modules/shaders
7/// - making final shaders by combining the shader IR with the IR for imported modules
8///
9/// for multiple small shaders with large common imports, this can be faster than parsing the full source for each shader, and it allows for constructing shaders in a cleaner modular manner with better scope control.
10///
11/// ## imports
12///
/// shaders can be added to the composer as modules. this makes their types, constants, variables and functions available to modules/shaders that import them. note that importing a module will affect the final shader's global state if the module defines global variables with bindings.
14///
15/// modules must include a `#define_import_path` directive that names the module.
16///
17/// ```ignore
18/// #define_import_path my_module
19///
20/// fn my_func() -> f32 {
21///     return 1.0;
22/// }
23/// ```
24///
25/// shaders can then import the module with an `#import` directive (with an optional `as` name). at point of use, imported items must be qualified:
26///
27/// ```ignore
28/// #import my_module
29/// #import my_other_module as Mod2
30///
31/// fn main() -> f32 {
32///     let x = my_module::my_func();
33///     let y = Mod2::my_other_func();
34///     return x*y;
35/// }
36/// ```
37///
38/// or import a comma-separated list of individual items with a `#from` directive. at point of use, imported items must be prefixed with `::` :
39///
40/// ```ignore
41/// #from my_module import my_func, my_const
42///
43/// fn main() -> f32 {
44///     return ::my_func(::my_const);
45/// }
46/// ```
47///
48/// imports can be nested - modules may import other modules, but not recursively. when a new module is added, all its `#import`s must already have been added.
49/// the same module can be imported multiple times by different modules in the import tree.
50/// there is no overlap of namespaces, so the same function names (or type, constant, or variable names) may be used in different modules.
51///
/// note: when importing an item with the `#from` directive, the final shader will include the required dependencies (bindings, globals, consts, other functions) of the imported item, but will not include the rest of the imported module. it will, however, still include all of the modules imported by the imported module. this is probably not desired in general and may be fixed in a future version. currently, for more complete culling of unused dependencies, the `prune` module can be used.
53///
54/// ## overriding functions
55///
56/// virtual functions can be declared with the `virtual` keyword:
57/// ```ignore
58/// virtual fn point_light(world_position: vec3<f32>) -> vec3<f32> { ... }
59/// ```
60/// virtual functions defined in imported modules can then be overridden using the `override` keyword:
61///
62/// ```ignore
63/// #import bevy_pbr::lighting as Lighting
64///
65/// override fn Lighting::point_light (world_position: vec3<f32>) -> vec3<f32> {
66///     let original = Lighting::point_light(world_position);
67///     let quantized = vec3<u32>(original * 3.0);
68///     return vec3<f32>(quantized) / 3.0;
69/// }
70/// ```
71///
72/// override function definitions cause *all* calls to the original function in the entire shader scope to be replaced by calls to the new function, with the exception of calls within the override function itself.
73///
74/// the function signature of the override must match the base function.
75///
76/// overrides can be specified at any point in the final shader's import tree.
77///
78/// multiple overrides can be applied to the same function. for example, given :
79/// - a module `a` containing a function `f`,
80/// - a module `b` that imports `a`, and containing an `override a::f` function,
81/// - a module `c` that imports `a` and `b`, and containing an `override a::f` function,
82///
83/// then b and c both specify an override for `a::f`.
/// the `override fn a::f` declared in module `b` may call `a::f` within its body.
/// the `override fn a::f` declared in module `c` may call `a::f` within its body, but the call will be redirected to `b::f`.
/// any other calls to `a::f` (within modules `a` or `b`, or anywhere else) will end up redirected to `c::f`.
87/// in this way a chain or stack of overrides can be applied.
88///
89/// different overrides of the same function can be specified in different import branches. the final stack will be ordered based on the first occurrence of the override in the import tree (using a depth first search).
90///
91/// note that imports into a module/shader are processed in order, but are processed before the body of the current shader/module regardless of where they occur in that module, so there is no way to import a module containing an override and inject a call into the override stack prior to that imported override. you can instead create two modules each containing an override and import them into a parent module/shader to order them as required.
92/// override functions can currently only be defined in wgsl.
93///
94/// if the `override_any` crate feature is enabled, then the `virtual` keyword is not required for the function being overridden.
95///
96/// ## languages
97///
/// modules can be written in GLSL or WGSL. shaders with entry points can be imported as modules (provided they have a `#define_import_path` directive). entry points are available to call from imported modules either via their name (for WGSL) or via `module::main` (for GLSL).
99///
/// final shaders can also be written in GLSL or WGSL. for GLSL, users must specify whether the shader is a vertex shader or fragment shader via the `ShaderType` argument (GLSL compute shaders are not supported).
101///
102/// ## preprocessing
103///
104/// when generating a final shader or adding a composable module, a set of `shader_def` string/value pairs must be provided. The value can be a bool (`ShaderDefValue::Bool`), an i32 (`ShaderDefValue::Int`) or a u32 (`ShaderDefValue::UInt`).
105///
106/// these allow conditional compilation of parts of modules and the final shader. conditional compilation is performed with `#if` / `#ifdef` / `#ifndef`, `#else` and `#endif` preprocessor directives:
107///
108/// ```ignore
109/// fn get_number() -> f32 {
110///     #ifdef BIG_NUMBER
111///         return 999.0;
112///     #else
113///         return 0.999;
114///     #endif
115/// }
116/// ```
/// the `#ifdef` directive matches when the def name exists in the provided `shader_defs` (regardless of its value). the `#ifndef` directive matches when it does not.
118///
119/// the `#if` directive requires a def name, an operator, and a value for comparison:
120/// - the def name must be a provided `shader_def` name.
121/// - the operator must be one of `==`, `!=`, `>=`, `>`, `<`, `<=`
/// - the value must be an integer literal if comparing to a `ShaderDefValue::Int` or `ShaderDefValue::UInt`, or `true` or `false` if comparing to a `ShaderDefValue::Bool`.
123///
/// shader defs can also be used in the shader source with `#SHADER_DEF` or `#{SHADER_DEF}`, and will be replaced by their value.
125///
126/// ## error reporting
127///
128/// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting.
129///
130use naga::EntryPoint;
131use regex::Regex;
132use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
133use std::sync::LazyLock;
134use tracing::{debug, trace};
135
136use crate::{
137    compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
138    derive::DerivedModule,
139    redirect::Redirector,
140};
141
142pub use self::error::{ComposerError, ComposerErrorInner, ErrSource};
143use self::preprocess::Preprocessor;
144pub use self::wgsl_directives::{
145    DiagnosticDirective, EnableDirective, RequiresDirective, WgslDirectives,
146};
147
148pub mod comment_strip_iter;
149pub mod error;
150pub mod parse_imports;
151pub mod preprocess;
152mod test;
153pub mod tokenizer;
154pub mod wgsl_directives;
155
/// Source language of a module or final shader.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum ShaderLanguage {
    /// WGSL source (the default).
    #[default]
    Wgsl,
    /// GLSL source (requires the `glsl` feature).
    #[cfg(feature = "glsl")]
    Glsl,
}

/// Kind of final shader being composed.
///
/// GLSL final shaders must state their pipeline stage, so GLSL is split into vertex and
/// fragment variants (GLSL compute shaders are not supported).
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
pub enum ShaderType {
    /// WGSL shader (the default).
    #[default]
    Wgsl,
    /// GLSL vertex shader (requires the `glsl` feature).
    #[cfg(feature = "glsl")]
    GlslVertex,
    /// GLSL fragment shader (requires the `glsl` feature).
    #[cfg(feature = "glsl")]
    GlslFragment,
}

/// Every shader type maps onto the language it is written in.
impl From<ShaderType> for ShaderLanguage {
    fn from(shader_type: ShaderType) -> Self {
        match shader_type {
            ShaderType::Wgsl => Self::Wgsl,
            #[cfg(feature = "glsl")]
            ShaderType::GlslVertex | ShaderType::GlslFragment => Self::Glsl,
        }
    }
}
183
/// Value bound to a `shader_def` name: a bool, a signed or an unsigned integer.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefValue {
    Bool(bool),
    Int(i32),
    UInt(u32),
}

/// A def given without an explicit value counts as `Bool(true)`.
impl Default for ShaderDefValue {
    fn default() -> Self {
        Self::Bool(true)
    }
}

impl ShaderDefValue {
    /// Render the wrapped value as text (`true`/`false` or a decimal integer).
    fn value_as_string(&self) -> String {
        match *self {
            Self::Bool(flag) => flag.to_string(),
            Self::Int(signed) => signed.to_string(),
            Self::UInt(unsigned) => unsigned.to_string(),
        }
    }
}

/// Owned set of shader defs stored in a `BTreeMap`.
///
/// Unlike `HashMap`, `BTreeMap` implements `Hash`, so this type can be used inside a hash key.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
pub struct OwnedShaderDefs(BTreeMap<String, ShaderDefValue>);

/// Cache key identifying one built variant of a module.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
struct ModuleKey(OwnedShaderDefs);

impl ModuleKey {
    /// Build a key from the entries of `key` whose names appear in `universe`.
    ///
    /// Names in `universe` that are missing from `key` are skipped, so defs that cannot affect
    /// the module never split the cache.
    fn from_members(key: &HashMap<String, ShaderDefValue>, universe: &[String]) -> Self {
        let selected = universe
            .iter()
            .filter_map(|name| key.get(name).map(|value| (name.clone(), *value)))
            .collect();
        ModuleKey(OwnedShaderDefs(selected))
    }
}
224
/// A module built with a specific set of shader_defs.
#[derive(Default, Debug)]
pub struct ComposableModule {
    /// module decoration for items from this module in the final source.
    /// note: `Composer::decorated_name` appends the decoration *after* the item name (a suffix).
    /// NOTE(review): presumably the output of `Composer::decorate(<module name>)` — confirm.
    pub decorated_name: String,
    /// module names required as imports, optionally with a list of items to import
    pub imports: Vec<ImportDefinition>,
    /// types exported
    pub owned_types: HashSet<String>,
    /// constants exported
    pub owned_constants: HashSet<String>,
    /// vars exported
    pub owned_vars: HashSet<String>,
    /// functions exported
    pub owned_functions: HashSet<String>,
    /// local functions that can be overridden
    pub virtual_functions: HashSet<String>,
    /// overriding functions defined in this module:
    /// target function -> replacement functions (the `IndexMap` keeps insertion order)
    pub override_functions: IndexMap<String, Vec<String>>,
    /// naga module, built against headers for any imports
    module_ir: naga::Module,
    /// header in IR form, used for building modules/shaders that import this module.
    /// headers contain types, constants, global vars and empty function definitions -
    /// just enough to convert source strings that want to import this module into naga IR.
    /// (per-language string headers are not cached; see the TODO in `naga_to_string`)
    header_ir: naga::Module,
    /// offset of the start of the owned module string.
    /// NOTE(review): historically described as a character offset, but `create_module_ir`
    /// computes its `start_offset` with `String::len` (bytes) — confirm which is stored here.
    start_offset: usize,
}
255
256// data used to build a ComposableModule
257#[derive(Debug)]
258pub struct ComposableModuleDefinition {
259    pub name: String,
260    // shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots)
261    pub sanitized_source: String,
262    // language
263    pub language: ShaderLanguage,
264    // source path for error display
265    pub file_path: String,
266    // shader def values bound to this module
267    pub shader_defs: HashMap<String, ShaderDefValue>,
268    // list of shader_defs that can affect this module
269    effective_defs: Vec<String>,
270    // full list of possible imports (regardless of shader_def configuration)
271    all_imports: HashSet<String>,
272    // additional imports to add (as though they were included in the source after any other imports)
273    additional_imports: Vec<ImportDefinition>,
274    // built composable modules for a given set of shader defs
275    modules: HashMap<ModuleKey, ComposableModule>,
276    // used in spans when this module is included
277    module_index: usize,
278    // any directives used in this module
279    wgsl_directives: WgslDirectives,
280}
281
282impl ComposableModuleDefinition {
283    fn get_module(
284        &self,
285        shader_defs: &HashMap<String, ShaderDefValue>,
286    ) -> Option<&ComposableModule> {
287        self.modules
288            .get(&ModuleKey::from_members(shader_defs, &self.effective_defs))
289    }
290
291    fn insert_module(
292        &mut self,
293        shader_defs: &HashMap<String, ShaderDefValue>,
294        module: ComposableModule,
295    ) -> &ComposableModule {
296        match self
297            .modules
298            .entry(ModuleKey::from_members(shader_defs, &self.effective_defs))
299        {
300            Entry::Occupied(_) => panic!("entry already populated"),
301            Entry::Vacant(v) => v.insert(module),
302        }
303    }
304}
305
/// A single import: the module to import and, optionally, individual items from it.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportDefinition {
    /// name of the imported module (looked up as a key of `Composer::module_sets`)
    pub import: String,
    /// items imported individually.
    /// NOTE(review): presumably empty for a whole-module `#import` — confirm in `parse_imports`.
    pub items: Vec<String>,
}

/// An [`ImportDefinition`] paired with an offset.
#[derive(Debug, Clone)]
pub struct ImportDefWithOffset {
    definition: ImportDefinition,
    /// NOTE(review): presumably the source offset of the import directive, for error spans — confirm.
    offset: usize,
}
317
/// module composer.
/// stores any modules that can be imported into a shader
/// and builds the final shader
#[derive(Debug)]
pub struct Composer {
    /// whether naga validation is enabled (`true` for `Composer::default()`)
    pub validate: bool,
    /// module definitions, keyed by module (import) name
    pub module_sets: HashMap<String, ComposableModuleDefinition>,
    /// module index -> module name.
    /// NOTE(review): presumably the index stored in `ComposableModuleDefinition::module_index`
    /// and used in spans — confirm.
    pub module_index: HashMap<usize, String>,
    /// naga capabilities passed to every validator created by this composer
    pub capabilities: naga::valid::Capabilities,
    preprocessor: Preprocessor,
    /// matches either decoration prefix (`DECORATION_PRE` or `DECORATION_OVERRIDE_PRE`)
    check_decoration_regex: Regex,
    /// matches `<item><DECORATION_PRE><encoded module><DECORATION_POST>` for `undecorate`
    undecorate_regex: Regex,
    /// matches the `virtual fn <name>(` declaration prefix
    virtual_fn_regex: Regex,
    /// matches `override fn <item><DECORATION_PRE><encoded module><DECORATION_POST>(`
    override_fn_regex: Regex,
    /// matches `<DECORATION_OVERRIDE_PRE><encoded module><DECORATION_POST>` for `undecorate`
    undecorate_override_regex: Regex,
    /// matches `@binding(auto)`
    auto_binding_regex: Regex,
    /// next index substituted for `@binding(auto)`; it keeps counting across every source
    /// passed through `sanitize_and_set_auto_bindings`
    auto_binding_index: u32,
}
336
/// shift for module index.
/// NOTE(review): presumably spans keep the position within a module in the low `SPAN_SHIFT`
/// bits and the module index in the bits above — confirm against span construction.
///
/// 21 gives
///   max size for shader of 2m characters (2^21 = 2,097,152)
///   max 2048 modules (2^11; assumes 32-bit span offsets — confirm)
const SPAN_SHIFT: usize = 21;
342
343impl Default for Composer {
344    fn default() -> Self {
345        Self {
346            validate: true,
347            capabilities: Default::default(),
348            module_sets: Default::default(),
349            module_index: Default::default(),
350            preprocessor: Preprocessor::default(),
351            check_decoration_regex: Regex::new(
352                format!(
353                    "({}|{})",
354                    regex::escape(DECORATION_PRE),
355                    regex::escape(DECORATION_OVERRIDE_PRE)
356                )
357                .as_str(),
358            )
359            .unwrap(),
360            undecorate_regex: Regex::new(
361                format!(
362                    r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}",
363                    regex::escape(DECORATION_PRE),
364                    regex::escape(DECORATION_POST)
365                )
366                .as_str(),
367            )
368            .unwrap(),
369            virtual_fn_regex: Regex::new(
370                r"(?P<lead>[\s]*virtual\s+fn\s+)(?P<function>[^\s]+)(?P<trail>\s*)\(",
371            )
372            .unwrap(),
373            override_fn_regex: Regex::new(
374                format!(
375                    r"(override\s+fn\s+)([^\s]+){}([\w\d]+){}(\s*)\(",
376                    regex::escape(DECORATION_PRE),
377                    regex::escape(DECORATION_POST)
378                )
379                .as_str(),
380            )
381            .unwrap(),
382            undecorate_override_regex: Regex::new(
383                format!(
384                    "{}([A-Z0-9]*){}",
385                    regex::escape(DECORATION_OVERRIDE_PRE),
386                    regex::escape(DECORATION_POST)
387                )
388                .as_str(),
389            )
390            .unwrap(),
391            auto_binding_regex: Regex::new(r"@binding\(auto\)").unwrap(),
392            auto_binding_index: 0,
393        }
394    }
395}
396
/// Marker placed between an item name and its base32-encoded module name when decorating.
const DECORATION_PRE: &str = "X_naga_oil_mod_X";
/// Marker terminating a decoration (after the encoded module name).
const DECORATION_POST: &str = "X";

/// Decoration marker matched by `undecorate_override_regex` in place of `DECORATION_PRE`.
///
/// must be same length as DECORATION_PRE for spans to work
const DECORATION_OVERRIDE_PRE: &str = "X_naga_oil_vrt_X";

// enforce the length invariant above at compile time instead of relying on the comment alone
const _: () = assert!(DECORATION_PRE.len() == DECORATION_OVERRIDE_PRE.len());
402
/// Output of `Composer::create_module_ir`.
struct IrBuildResult {
    /// the parsed naga module (generated header followed by the module's own source)
    module: naga::Module,
    /// byte offset in the parsed string at which the module's own source begins
    start_offset: usize,
    /// overrides gathered from the imported modules: target function -> replacement functions
    override_functions: IndexMap<String, Vec<String>>,
}
408
409impl Composer {
410    pub fn decorated_name(module_name: Option<&str>, item_name: &str) -> String {
411        match module_name {
412            Some(module_name) => format!("{}{}", item_name, Self::decorate(module_name)),
413            None => item_name.to_owned(),
414        }
415    }
416
417    fn decorate(module: &str) -> String {
418        let encoded = data_encoding::BASE32_NOPAD.encode(module.as_bytes());
419        format!("{DECORATION_PRE}{encoded}{DECORATION_POST}")
420    }
421
422    fn decode(from: &str) -> String {
423        String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap()
424    }
425
426    /// Shorthand for creating a naga validator.
427    fn create_validator(&self) -> naga::valid::Validator {
428        naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities)
429    }
430
431    fn undecorate(&self, string: &str) -> String {
432        let undecor = self
433            .undecorate_regex
434            .replace_all(string, |caps: &regex::Captures| {
435                format!(
436                    "{}{}::{}",
437                    caps.get(1).map(|cc| cc.as_str()).unwrap_or(""),
438                    Self::decode(caps.get(3).unwrap().as_str()),
439                    caps.get(2).unwrap().as_str()
440                )
441            });
442
443        let undecor =
444            self.undecorate_override_regex
445                .replace_all(&undecor, |caps: &regex::Captures| {
446                    format!(
447                        "override fn {}::",
448                        Self::decode(caps.get(1).unwrap().as_str())
449                    )
450                });
451
452        undecor.to_string()
453    }
454
455    fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String {
456        let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n");
457        if !substituted_source.ends_with('\n') {
458            substituted_source.push('\n');
459        }
460
461        // replace @binding(auto) with an incrementing index
462        struct AutoBindingReplacer<'a> {
463            auto: &'a mut u32,
464        }
465
466        impl regex::Replacer for AutoBindingReplacer<'_> {
467            fn replace_append(&mut self, _: &regex::Captures<'_>, dst: &mut String) {
468                dst.push_str(&format!("@binding({})", self.auto));
469                *self.auto += 1;
470            }
471        }
472
473        let substituted_source = self.auto_binding_regex.replace_all(
474            &substituted_source,
475            AutoBindingReplacer {
476                auto: &mut self.auto_binding_index,
477            },
478        );
479
480        substituted_source.into_owned()
481    }
482
    /// Validate `naga_module` and write it back out as source text in `language`.
    ///
    /// `create_module_ir` uses the returned string as the header that a dependent module's
    /// source is appended to before parsing.
    ///
    /// - WGSL: written with `WriterFlags::EXPLICIT_TYPES`.
    /// - GLSL (feature `glsl`):
    ///   - a `vec4<f32>` type and a dummy vertex entry point are pushed into `naga_module`,
    ///     so the caller's module is mutated;
    ///   - the module is validated again and written as desktop GLSL 450 with unused items
    ///     included;
    ///   - the first line and the last three lines of the output are dropped.
    ///
    /// `header_for` is only used for trace logging in the GLSL branch.
    ///
    /// # Errors
    /// Returns `HeaderValidationError` if validation fails. Returns `WgslBackError` or
    /// `GlslBackError` if the backend fails to write the module.
    fn naga_to_string(
        &self,
        naga_module: &mut naga::Module,
        language: ShaderLanguage,
        #[allow(unused)] header_for: &str, // Only used when GLSL is enabled
    ) -> Result<String, ComposerErrorInner> {
        // TODO: cache headers again
        // validate the incoming module; the GLSL branch shadows this `info` with a second
        // validation run after the dummy entry point has been added
        let info = self
            .create_validator()
            .validate(naga_module)
            .map_err(ComposerErrorInner::HeaderValidationError)?;

        match language {
            ShaderLanguage::Wgsl => naga::back::wgsl::write_string(
                naga_module,
                &info,
                naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
            )
            .map_err(ComposerErrorInner::WgslBackError),
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => {
                // return type for the dummy entry point (a `gl_Position`-style vec4 output)
                let vec4 = naga_module.types.insert(
                    naga::Type {
                        name: None,
                        inner: naga::TypeInner::Vector {
                            size: naga::VectorSize::Quad,
                            scalar: naga::Scalar::F32,
                        },
                    },
                    naga::Span::UNDEFINED,
                );
                // add a dummy entry point for glsl headers
                let dummy_entry_point = "dummy_module_entry_point".to_owned();
                let func = naga::Function {
                    name: Some(dummy_entry_point.clone()),
                    arguments: Default::default(),
                    result: Some(naga::FunctionResult {
                        ty: vec4,
                        binding: Some(naga::Binding::BuiltIn(naga::BuiltIn::Position {
                            invariant: false,
                        })),
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                    diagnostic_filter_leaf: Default::default(),
                };
                let ep = EntryPoint {
                    name: dummy_entry_point.clone(),
                    stage: naga::ShaderStage::Vertex,
                    function: func,
                    early_depth_test: None,
                    workgroup_size: [0, 0, 0],
                    workgroup_size_overrides: None,
                    mesh_info: None,
                    task_payload: None,
                    incoming_ray_payload: None,
                };

                naga_module.entry_points.push(ep);

                // re-validate: the module has changed since the first validation
                let info = self
                    .create_validator()
                    .validate(naga_module)
                    .map_err(ComposerErrorInner::HeaderValidationError)?;

                let mut string = String::new();
                let options = naga::back::glsl::Options {
                    version: naga::back::glsl::Version::Desktop(450),
                    writer_flags: naga::back::glsl::WriterFlags::INCLUDE_UNUSED_ITEMS,
                    ..Default::default()
                };
                let pipeline_options = naga::back::glsl::PipelineOptions {
                    shader_stage: naga::ShaderStage::Vertex,
                    entry_point: dummy_entry_point,
                    multiview: None,
                };
                let mut writer = naga::back::glsl::Writer::new(
                    &mut string,
                    naga_module,
                    &info,
                    &options,
                    &pipeline_options,
                    naga::proc::BoundsCheckPolicies::default(),
                )
                .map_err(ComposerErrorInner::GlslBackError)?;

                writer.write().map_err(ComposerErrorInner::GlslBackError)?;

                // strip version decl and main() impl
                // NOTE(review): assumes the GLSL writer puts `#version` on the first line and the
                // dummy `main()` in the last three lines; the slice panics if there are fewer
                // than four lines of output
                let lines: Vec<_> = string.lines().collect();
                let string = lines[1..lines.len() - 3].join("\n");
                trace!("glsl header for {}:\n\"\n{:?}\n\"", header_for, string);

                Ok(string)
            }
        }
    }
582
    /// Build naga IR for module `name` under one `shader_defs` configuration.
    ///
    /// The result is a minimal self-contained module, built against headers for its imports.
    ///
    /// The string that gets parsed is assembled in this order:
    /// 1. `wgsl_directives` (WGSL) or `#version 450` (GLSL).
    /// 2. A header holding the items of every import. Each import is added to a `DerivedModule`
    ///    via `add_import`, and the result is written out with `naga_to_string`.
    /// 3. `source` itself.
    ///
    /// Override functions declared by the imported modules are merged into the result's
    /// `override_functions`. Merging follows import order, and a replacement already listed for
    /// a target is not repeated.
    ///
    /// `start_offset` in the result is the byte length of steps 1-2, i.e. where `source` starts.
    ///
    /// # Panics
    /// Panics if an import is missing from `module_sets` or has no module built for
    /// `shader_defs`. Callers must ensure imports exist first.
    ///
    /// # Errors
    /// Returns header generation errors and WGSL/GLSL parse errors, wrapped with this module's
    /// name and defs.
    fn create_module_ir(
        &self,
        name: &str,
        source: String,
        language: ShaderLanguage,
        imports: &[ImportDefinition],
        shader_defs: &HashMap<String, ShaderDefValue>,
        wgsl_directives: &WgslDirectives,
    ) -> Result<IrBuildResult, ComposerError> {
        debug!("creating IR for {} with defs: {:?}", name, shader_defs);

        let mut module_string = match language {
            ShaderLanguage::Wgsl => wgsl_directives.to_wgsl_string(),
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => String::from("#version 450\n"),
        };

        let mut override_functions: IndexMap<String, Vec<String>> = IndexMap::default();
        let mut added_imports: HashSet<String> = HashSet::new();
        let mut header_module = DerivedModule::default();

        for import in imports {
            // NOTE(review): an import already present in `added_imports` (possibly inserted by
            // an earlier `add_import` call) is skipped entirely, including override gathering;
            // presumably its overrides are covered elsewhere — confirm
            if added_imports.contains(&import.import) {
                continue;
            }
            // add to header module
            self.add_import(
                &mut header_module,
                import,
                shader_defs,
                true,
                &mut added_imports,
            );

            // we must have ensured these exist with Composer::ensure_imports();
            // a missing module set or built module panics on the unwraps below
            trace!("looking for {}", import.import);
            let import_module_set = self.module_sets.get(&import.import).unwrap();
            trace!("with defs {:?}", shader_defs);
            let module = import_module_set.get_module(shader_defs).unwrap();
            trace!("ok");

            // gather overrides: append replacements not already listed for the same target
            if !module.override_functions.is_empty() {
                for (original, replacements) in &module.override_functions {
                    match override_functions.entry(original.clone()) {
                        indexmap::map::Entry::Occupied(o) => {
                            let existing = o.into_mut();
                            let new_replacements: Vec<_> = replacements
                                .iter()
                                .filter(|rep| !existing.contains(rep))
                                .cloned()
                                .collect();
                            existing.extend(new_replacements);
                        }
                        indexmap::map::Entry::Vacant(v) => {
                            v.insert(replacements.clone());
                        }
                    }
                }
            }
        }

        let composed_header = self
            .naga_to_string(&mut header_module.into(), language, name)
            .map_err(|inner| ComposerError {
                inner,
                source: ErrSource::Module {
                    name: name.to_owned(),
                    offset: 0,
                    defs: shader_defs.clone(),
                },
            })?;
        module_string.push_str(&composed_header);

        // byte offset where this module's own source begins (used for error spans)
        let start_offset = module_string.len();

        module_string.push_str(&source);

        trace!(
            "parsing {}: {}, header len {}, total len {}",
            name,
            module_string,
            start_offset,
            module_string.len()
        );
        let module = match language {
            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(&module_string).map_err(|e| {
                debug!("full err'd source file: \n---\n{}\n---", module_string);
                ComposerError {
                    inner: ComposerErrorInner::WgslParseError(e),
                    source: ErrSource::Module {
                        name: name.to_owned(),
                        offset: start_offset,
                        defs: shader_defs.clone(),
                    },
                }
            })?,
            #[cfg(feature = "glsl")]
            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
                .parse(
                    &naga::front::glsl::Options {
                        stage: naga::ShaderStage::Vertex,
                        defines: Default::default(),
                    },
                    &module_string,
                )
                .map_err(|e| {
                    debug!("full err'd source file: \n---\n{}\n---", module_string);
                    ComposerError {
                        inner: ComposerErrorInner::GlslParseError(e),
                        source: ErrSource::Module {
                            name: name.to_owned(),
                            offset: start_offset,
                            defs: shader_defs.clone(),
                        },
                    }
                })?,
        };

        Ok(IrBuildResult {
            module,
            start_offset,
            override_functions,
        })
    }
709
710    // check that identifiers exported by a module do not get modified in string export
711    fn validate_identifiers(
712        source_ir: &naga::Module,
713        lang: ShaderLanguage,
714        header: &str,
715        module_decoration: &str,
716        owned_types: &HashSet<String>,
717    ) -> Result<(), ComposerErrorInner> {
718        // TODO: remove this once glsl front support is complete
719        #[cfg(feature = "glsl")]
720        if lang == ShaderLanguage::Glsl {
721            return Ok(());
722        }
723
724        let recompiled = match lang {
725            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(header).unwrap(),
726            #[cfg(feature = "glsl")]
727            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
728                .parse(
729                    &naga::front::glsl::Options {
730                        stage: naga::ShaderStage::Vertex,
731                        defines: Default::default(),
732                    },
733                    &format!("{}\n{}", header, "void main() {}"),
734                )
735                .map_err(|e| {
736                    debug!("full err'd source file: \n---\n{header}\n---");
737                    ComposerErrorInner::GlslParseError(e)
738                })?,
739        };
740
741        let recompiled_types: IndexMap<_, _> = recompiled
742            .types
743            .iter()
744            .flat_map(|(h, ty)| ty.name.as_deref().map(|name| (name, h)))
745            .collect();
746        for (h, ty) in source_ir.types.iter() {
747            if let Some(name) = &ty.name {
748                let decorated_type_name = format!("{name}{module_decoration}");
749                if !owned_types.contains(&decorated_type_name) {
750                    continue;
751                }
752                match recompiled_types.get(decorated_type_name.as_str()) {
753                    Some(recompiled_h) => {
754                        if let naga::TypeInner::Struct { members, .. } = &ty.inner {
755                            let recompiled_ty = recompiled.types.get_handle(*recompiled_h).unwrap();
756                            let naga::TypeInner::Struct {
757                                members: recompiled_members,
758                                ..
759                            } = &recompiled_ty.inner
760                            else {
761                                panic!();
762                            };
763                            for (member, recompiled_member) in
764                                members.iter().zip(recompiled_members)
765                            {
766                                if member.name != recompiled_member.name {
767                                    return Err(ComposerErrorInner::InvalidIdentifier {
768                                        original: member.name.clone().unwrap_or_default(),
769                                        at: source_ir.types.get_span(h),
770                                    });
771                                }
772                            }
773                        }
774                    }
775                    None => {
776                        return Err(ComposerErrorInner::InvalidIdentifier {
777                            original: name.clone(),
778                            at: source_ir.types.get_span(h),
779                        })
780                    }
781                }
782            }
783        }
784
785        let recompiled_consts: HashSet<_> = recompiled
786            .constants
787            .iter()
788            .flat_map(|(_, c)| c.name.as_deref())
789            .filter(|name| name.ends_with(module_decoration))
790            .collect();
791        for (h, c) in source_ir.constants.iter() {
792            if let Some(name) = &c.name {
793                if name.ends_with(module_decoration) && !recompiled_consts.contains(name.as_str()) {
794                    return Err(ComposerErrorInner::InvalidIdentifier {
795                        original: name.clone(),
796                        at: source_ir.constants.get_span(h),
797                    });
798                }
799            }
800        }
801
802        let recompiled_globals: HashSet<_> = recompiled
803            .global_variables
804            .iter()
805            .flat_map(|(_, c)| c.name.as_deref())
806            .filter(|name| name.ends_with(module_decoration))
807            .collect();
808        for (h, gv) in source_ir.global_variables.iter() {
809            if let Some(name) = &gv.name {
810                if name.ends_with(module_decoration) && !recompiled_globals.contains(name.as_str())
811                {
812                    return Err(ComposerErrorInner::InvalidIdentifier {
813                        original: name.clone(),
814                        at: source_ir.global_variables.get_span(h),
815                    });
816                }
817            }
818        }
819
820        let recompiled_fns: HashSet<_> = recompiled
821            .functions
822            .iter()
823            .flat_map(|(_, c)| c.name.as_deref())
824            .filter(|name| name.ends_with(module_decoration))
825            .collect();
826        for (h, f) in source_ir.functions.iter() {
827            if let Some(name) = &f.name {
828                if name.ends_with(module_decoration) && !recompiled_fns.contains(name.as_str()) {
829                    return Err(ComposerErrorInner::InvalidIdentifier {
830                        original: name.clone(),
831                        at: source_ir.functions.get_span(h),
832                    });
833                }
834            }
835        }
836
837        Ok(())
838    }
839
    // build a ComposableModule from a ComposableModuleDefinition, for a given set of shader defs
    // - rewrite `virtual fn` / `override fn` declarations, recording virtual and override functions
    // - build the naga IR (against headers)
    // - record any types/vars/constants/pipeline overrides/functions that are defined within this
    //   module, renaming them with `module_decoration`
    // - build the header IR (owned items plus stub functions)
    // - optionally check that the header round-trips through each supported language
    //
    // params:
    // - `module_decoration`: suffix appended to the name of every item owned by this module; it is
    //   stored as the result's `decorated_name`
    // - `create_headers`: when false the identifier round-trip check below is skipped (the header
    //   IR itself is still built)
    // - `demote_entrypoints`: turn the source's entry points into ordinary owned functions
    // - `source`: the preprocessed source; `imports`: the imports found in it (the definition's
    //   `additional_imports` are appended)
    //
    // errors are attributed to the module by name, at offset 0 before the IR is built and at
    // `start_offset` afterwards
    #[allow(clippy::too_many_arguments)]
    fn create_composable_module(
        &mut self,
        module_definition: &ComposableModuleDefinition,
        module_decoration: String,
        shader_defs: &HashMap<String, ShaderDefValue>,
        create_headers: bool,
        demote_entrypoints: bool,
        source: &str,
        imports: Vec<ImportDefWithOffset>,
    ) -> Result<ComposableModule, ComposerError> {
        // drop the source offsets and append the definition's extra imports
        let mut imports: Vec<_> = imports
            .into_iter()
            .map(|import_with_offset| import_with_offset.definition)
            .collect();
        imports.extend(module_definition.additional_imports.to_vec());

        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // record virtual/overridable functions
        // each match is rewritten to a plain `fn name(`; groups 1 and 3 are replaced by spaces
        // (group 1 minus the 3 chars of the literal `fn `), presumably so the text keeps its length
        // and offsets stay valid - verify against `virtual_fn_regex`
        let mut virtual_functions: HashSet<String> = Default::default();
        let source = self
            .virtual_fn_regex
            .replace_all(source, |cap: &regex::Captures| {
                let target_function = cap.get(2).unwrap().as_str().to_owned();

                let replacement_str = format!(
                    "{}fn {}{}(",
                    " ".repeat(cap.get(1).unwrap().range().len() - 3),
                    target_function,
                    " ".repeat(cap.get(3).unwrap().range().len()),
                );

                virtual_functions.insert(target_function);

                replacement_str
            });

        // record and rename override functions
        // maps the local override's name (`{func}{DECORATION_OVERRIDE_PRE}{module}{DECORATION_POST}`)
        // to the name of the function it overrides (`{func}{DECORATION_PRE}{module}{DECORATION_POST}`)
        let mut local_override_functions: IndexMap<String, String> = Default::default();

        // the replace closure cannot return an error, so the last failure seen is stashed here
        // and reported once the replacement is complete
        #[cfg(not(feature = "override_any"))]
        let mut override_error = None;

        let source =
            self.override_fn_regex
                .replace_all(&source, |cap: &regex::Captures| {
                    let target_module = cap.get(3).unwrap().as_str().to_owned();
                    let target_function = cap.get(2).unwrap().as_str().to_owned();

                    // without `override_any`, an override may only target a function that the
                    // target module declared `virtual`
                    #[cfg(not(feature = "override_any"))]
                    {
                        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
                            ComposerError {
                                inner,
                                source: ErrSource::Module {
                                    name: module_definition.name.to_owned(),
                                    offset: 0,
                                    defs: shader_defs.clone(),
                                },
                            }
                        };

                        // ensure overrides are applied to virtual functions
                        let raw_module_name = Self::decode(&target_module);
                        let module_set = self.module_sets.get(&raw_module_name);

                        match module_set {
                            None => {
                                // TODO this should be unreachable?
                                let pos = cap.get(3).unwrap().start();
                                override_error = Some(wrap_err(
                                    ComposerErrorInner::ImportNotFound(raw_module_name, pos),
                                ));
                            }
                            Some(module_set) => {
                                let module = module_set.get_module(shader_defs).unwrap();
                                if !module.virtual_functions.contains(&target_function) {
                                    let pos = cap.get(2).unwrap().start();
                                    override_error =
                                        Some(wrap_err(ComposerErrorInner::OverrideNotVirtual {
                                            name: target_function.clone(),
                                            pos,
                                        }));
                                }
                            }
                        }
                    }

                    let base_name = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );
                    let rename = format!(
                        "{}{}{}{}",
                        target_function.as_str(),
                        DECORATION_OVERRIDE_PRE,
                        target_module.as_str(),
                        DECORATION_POST,
                    );

                    let replacement_str = format!(
                        "{}fn {}{}(",
                        " ".repeat(cap.get(1).unwrap().range().len() - 3),
                        rename,
                        " ".repeat(cap.get(4).unwrap().range().len()),
                    );

                    local_override_functions.insert(rename, base_name);

                    replacement_str
                })
                .to_string();

        #[cfg(not(feature = "override_any"))]
        if let Some(err) = override_error {
            return Err(err);
        }

        trace!("local overrides: {:?}", local_override_functions);
        trace!(
            "create composable module {}: source len {}",
            module_definition.name,
            source.len()
        );

        // compile the rewritten source against the imports' headers
        let IrBuildResult {
            module: mut source_ir,
            start_offset,
            mut override_functions,
        } = self.create_module_ir(
            &module_definition.name,
            source,
            module_definition.language,
            &imports,
            shader_defs,
            &module_definition.wgsl_directives,
        )?;

        // from here on errors need to be reported using the modified source with start_offset
        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
            ComposerError {
                inner,
                source: ErrSource::Module {
                    name: module_definition.name.to_owned(),
                    offset: start_offset,
                    defs: shader_defs.clone(),
                },
            }
        };

        // add our local override to the total set of overrides for the given function
        for (rename, base_name) in &local_override_functions {
            override_functions
                .entry(base_name.clone())
                .or_default()
                .push(format!("{rename}{module_decoration}"));
        }

        // rename and record owned items (except types which can't be mutably accessed)
        // an item is owned when its name does not already contain DECORATION_PRE, i.e. it was
        // not brought in from an imported module's header
        let mut owned_constants = IndexMap::new();
        for (h, c) in source_ir.constants.iter_mut() {
            if let Some(name) = c.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_constants.insert(name.clone(), h);
                }
            }
        }

        // These are naga/wgpu's pipeline override constants, not naga_oil's overrides
        let mut owned_pipeline_overrides = IndexMap::new();
        for (h, po) in source_ir.overrides.iter_mut() {
            if let Some(name) = po.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");
                    owned_pipeline_overrides.insert(name.clone(), h);
                }
            }
        }

        let mut owned_vars = IndexMap::new();
        for (h, gv) in source_ir.global_variables.iter_mut() {
            if let Some(name) = gv.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    owned_vars.insert(name.clone(), h);
                }
            }
        }

        // owned functions are recorded with a body-less copy (signature only) for the header
        let mut owned_functions = IndexMap::new();
        for (h_f, f) in source_ir.functions.iter_mut() {
            if let Some(name) = f.name.as_mut() {
                if !name.contains(DECORATION_PRE) {
                    *name = format!("{name}{module_decoration}");

                    // create dummy header function
                    let header_function = naga::Function {
                        name: Some(name.clone()),
                        arguments: f.arguments.to_vec(),
                        result: f.result.clone(),
                        local_variables: Default::default(),
                        expressions: Default::default(),
                        named_expressions: Default::default(),
                        body: Default::default(),
                        diagnostic_filter_leaf: None,
                    };

                    // record owned function
                    owned_functions.insert(name.clone(), (Some(h_f), header_function));
                }
            }
        }

        if demote_entrypoints {
            // make normal functions out of the source entry points
            // (unnamed entry points are called "main"; argument and result bindings are dropped)
            for ep in &mut source_ir.entry_points {
                ep.function.name = Some(format!(
                    "{}{}",
                    ep.function.name.as_deref().unwrap_or("main"),
                    module_decoration,
                ));
                let header_function = naga::Function {
                    name: ep.function.name.clone(),
                    arguments: ep
                        .function
                        .arguments
                        .iter()
                        .cloned()
                        .map(|arg| naga::FunctionArgument {
                            name: arg.name,
                            ty: arg.ty,
                            binding: None,
                        })
                        .collect(),
                    result: ep.function.result.clone().map(|res| naga::FunctionResult {
                        ty: res.ty,
                        binding: None,
                    }),
                    local_variables: Default::default(),
                    expressions: Default::default(),
                    named_expressions: Default::default(),
                    body: Default::default(),
                    diagnostic_filter_leaf: None,
                };

                // no source function handle for entry points, so the header stub gets an
                // undefined span below
                owned_functions.insert(ep.function.name.clone().unwrap(), (None, header_function));
            }
        };

        // `module_builder` receives the full module, `header_builder` only owned items and stubs
        let mut module_builder = DerivedModule::default();
        let mut header_builder = DerivedModule::default();
        module_builder.set_shader_source(&source_ir, 0);
        header_builder.set_shader_source(&source_ir, 0);

        // gather special types to exclude from owned types
        let mut special_types: HashSet<&naga::Handle<naga::Type>> = HashSet::new();
        special_types.extend(source_ir.special_types.predeclared_types.values());
        special_types.extend(
            [
                source_ir.special_types.ray_desc.as_ref(),
                source_ir.special_types.ray_intersection.as_ref(),
                source_ir.special_types.ray_vertex_return.as_ref(),
            ]
            .iter()
            .flatten(),
        );

        // as the header of imports that use special types includes the special type definitions explicitly,
        // we also exclude anything with a name matching the known special type names
        let special_type_names = special_types
            .iter()
            .flat_map(|h| source_ir.types.get_handle(**h).unwrap().name.clone())
            .collect::<HashSet<_>>();

        let mut owned_types = HashSet::new();
        for (h, ty) in source_ir.types.iter() {
            if let Some(name) = &ty.name {
                // we exclude any special types, these are added back later
                if special_types.contains(&h) || special_type_names.contains(name) {
                    continue;
                }

                if !name.contains(DECORATION_PRE) {
                    let name = format!("{name}{module_decoration}");
                    owned_types.insert(name.clone());
                    // copy and rename types
                    module_builder.rename_type(&h, Some(name.clone()));
                    header_builder.rename_type(&h, Some(name));
                    continue;
                }
            }

            // copy all required types
            module_builder.import_type(&h);
        }

        // copy owned constants, pipeline overrides and globals into header and module
        for h in owned_constants.values() {
            header_builder.import_const(h);
            module_builder.import_const(h);
        }

        for h in owned_pipeline_overrides.values() {
            header_builder.import_pipeline_override(h);
            module_builder.import_pipeline_override(h);
        }

        for h in owned_vars.values() {
            header_builder.import_global(h);
            module_builder.import_global(h);
        }

        // only stubs of owned functions into the header
        for (h_f, f) in owned_functions.values() {
            let span = h_f
                .map(|h_f| source_ir.functions.get_span(h_f))
                .unwrap_or(naga::Span::UNDEFINED);
            header_builder.import_function(f, span); // header stub function
        }
        // all functions into the module (note source_ir only contains stubs for imported functions)
        for (h_f, f) in source_ir.functions.iter() {
            let span = source_ir.functions.get_span(h_f);
            module_builder.import_function(f, span);
        }
        // including entry points as vanilla functions if required (bindings stripped)
        if demote_entrypoints {
            for ep in &source_ir.entry_points {
                let mut f = ep.function.clone();
                f.arguments = f
                    .arguments
                    .into_iter()
                    .map(|arg| naga::FunctionArgument {
                        name: arg.name,
                        ty: arg.ty,
                        binding: None,
                    })
                    .collect();
                f.result = f.result.map(|res| naga::FunctionResult {
                    ty: res.ty,
                    binding: None,
                });

                module_builder.import_function(&f, naga::Span::UNDEFINED);
                // todo figure out how to get span info for entrypoints
            }
        }

        let has_special_types = module_builder.has_required_special_types();
        let module_ir = module_builder.into_module_with_entrypoints();
        let mut header_ir: naga::Module = header_builder.into();

        // note: we cannot validate when special types are used, as writeback isn't supported
        if self.validate && create_headers && !has_special_types {
            // check that identifiers haven't been renamed
            // (the header is written out in each language and compared by validate_identifiers)
            #[allow(clippy::single_element_loop)]
            for language in [
                ShaderLanguage::Wgsl,
                #[cfg(feature = "glsl")]
                ShaderLanguage::Glsl,
            ] {
                let header = self
                    .naga_to_string(&mut header_ir, language, &module_definition.name)
                    .map_err(wrap_err)?;
                Self::validate_identifiers(
                    &source_ir,
                    language,
                    &header,
                    &module_decoration,
                    &owned_types,
                )
                .map_err(wrap_err)?;
            }
        }

        let composable_module = ComposableModule {
            decorated_name: module_decoration,
            imports,
            owned_types,
            owned_constants: owned_constants.into_keys().collect(),
            owned_vars: owned_vars.into_keys().collect(),
            owned_functions: owned_functions.into_keys().collect(),
            virtual_functions,
            override_functions,
            module_ir,
            header_ir,
            start_offset,
        };

        Ok(composable_module)
    }
1244
1245    // shunt all data owned by a composable into a derived module
1246    fn add_composable_data<'a>(
1247        derived: &mut DerivedModule<'a>,
1248        composable: &'a ComposableModule,
1249        items: Option<&Vec<String>>,
1250        span_offset: usize,
1251        header: bool,
1252    ) {
1253        let items: Option<HashSet<String>> = items.map(|items| {
1254            items
1255                .iter()
1256                .map(|item| format!("{}{}", item, composable.decorated_name))
1257                .collect()
1258        });
1259        let items = items.as_ref();
1260
1261        let source_ir = match header {
1262            true => &composable.header_ir,
1263            false => &composable.module_ir,
1264        };
1265
1266        derived.set_shader_source(source_ir, span_offset);
1267
1268        for (h, ty) in source_ir.types.iter() {
1269            if let Some(name) = &ty.name {
1270                if composable.owned_types.contains(name)
1271                    && items.is_none_or(|items| items.contains(name))
1272                {
1273                    derived.import_type(&h);
1274                }
1275            }
1276        }
1277
1278        for (h, c) in source_ir.constants.iter() {
1279            if let Some(name) = &c.name {
1280                if composable.owned_constants.contains(name)
1281                    && items.is_none_or(|items| items.contains(name))
1282                {
1283                    derived.import_const(&h);
1284                }
1285            }
1286        }
1287
1288        for (h, po) in source_ir.overrides.iter() {
1289            if let Some(name) = &po.name {
1290                if composable.owned_functions.contains(name)
1291                    && items.is_none_or(|items| items.contains(name))
1292                {
1293                    derived.import_pipeline_override(&h);
1294                }
1295            }
1296        }
1297
1298        for (h, v) in source_ir.global_variables.iter() {
1299            if let Some(name) = &v.name {
1300                if composable.owned_vars.contains(name)
1301                    && items.is_none_or(|items| items.contains(name))
1302                {
1303                    derived.import_global(&h);
1304                }
1305            }
1306        }
1307
1308        for (h_f, f) in source_ir.functions.iter() {
1309            if let Some(name) = &f.name {
1310                if composable.owned_functions.contains(name)
1311                    && (items.is_none_or(|items| items.contains(name))
1312                        || composable
1313                            .override_functions
1314                            .values()
1315                            .any(|v| v.contains(name)))
1316                {
1317                    let span = composable.module_ir.functions.get_span(h_f);
1318                    derived.import_function_if_new(f, span);
1319                }
1320            }
1321        }
1322
1323        derived.clear_shader_source();
1324    }
1325
1326    // add an import (and recursive imports) into a derived module
1327    fn add_import<'a>(
1328        &'a self,
1329        derived: &mut DerivedModule<'a>,
1330        import: &ImportDefinition,
1331        shader_defs: &HashMap<String, ShaderDefValue>,
1332        header: bool,
1333        already_added: &mut HashSet<String>,
1334    ) {
1335        if already_added.contains(&import.import) {
1336            trace!("skipping {}, already added", import.import);
1337            return;
1338        }
1339
1340        let import_module_set = self.module_sets.get(&import.import).unwrap();
1341        let module = import_module_set.get_module(shader_defs).unwrap();
1342
1343        for import in &module.imports {
1344            self.add_import(derived, import, shader_defs, header, already_added);
1345        }
1346
1347        Self::add_composable_data(
1348            derived,
1349            module,
1350            Some(&import.items),
1351            import_module_set.module_index << SPAN_SHIFT,
1352            header,
1353        );
1354    }
1355
1356    fn ensure_import(
1357        &mut self,
1358        module_set: &ComposableModuleDefinition,
1359        shader_defs: &HashMap<String, ShaderDefValue>,
1360    ) -> Result<ComposableModule, EnsureImportsError> {
1361        let PreprocessOutput {
1362            preprocessed_source,
1363            imports,
1364        } = self
1365            .preprocessor
1366            .preprocess(&module_set.sanitized_source, shader_defs)
1367            .map_err(|inner| {
1368                EnsureImportsError::from(ComposerError {
1369                    inner,
1370                    source: ErrSource::Module {
1371                        name: module_set.name.to_owned(),
1372                        offset: 0,
1373                        defs: shader_defs.clone(),
1374                    },
1375                })
1376            })?;
1377
1378        self.ensure_imports(imports.iter().map(|import| &import.definition), shader_defs)?;
1379        self.ensure_imports(&module_set.additional_imports, shader_defs)?;
1380
1381        self.create_composable_module(
1382            module_set,
1383            Self::decorate(&module_set.name),
1384            shader_defs,
1385            true,
1386            true,
1387            &preprocessed_source,
1388            imports,
1389        )
1390        .map_err(|err| err.into())
1391    }
1392
1393    // build required ComposableModules for a given set of shader_defs
1394    fn ensure_imports<'a>(
1395        &mut self,
1396        imports: impl IntoIterator<Item = &'a ImportDefinition>,
1397        shader_defs: &HashMap<String, ShaderDefValue>,
1398    ) -> Result<(), EnsureImportsError> {
1399        for ImportDefinition { import, .. } in imports.into_iter() {
1400            let Some(module_set) = self.module_sets.get(import) else {
1401                return Err(EnsureImportsError::MissingImport(import.to_owned()));
1402            };
1403            if module_set.get_module(shader_defs).is_some() {
1404                continue;
1405            }
1406
1407            // we need to build the module
1408            // take the set so we can recurse without borrowing
1409            let (set_key, mut module_set) = self.module_sets.remove_entry(import).unwrap();
1410
1411            match self.ensure_import(&module_set, shader_defs) {
1412                Ok(module) => {
1413                    module_set.insert_module(shader_defs, module);
1414                    self.module_sets.insert(set_key, module_set);
1415                }
1416                Err(e) => {
1417                    self.module_sets.insert(set_key, module_set);
1418                    return Err(e);
1419                }
1420            }
1421        }
1422
1423        Ok(())
1424    }
1425}
1426
1427pub enum EnsureImportsError {
1428    MissingImport(String),
1429    ComposerError(ComposerError),
1430}
1431
1432impl EnsureImportsError {
1433    fn into_composer_error(self, err_source: ErrSource) -> ComposerError {
1434        match self {
1435            EnsureImportsError::MissingImport(import) => ComposerError {
1436                inner: ComposerErrorInner::ImportNotFound(import.to_owned(), 0),
1437                source: err_source,
1438            },
1439            EnsureImportsError::ComposerError(err) => err,
1440        }
1441    }
1442}
1443
1444impl From<ComposerError> for EnsureImportsError {
1445    fn from(value: ComposerError) -> Self {
1446        EnsureImportsError::ComposerError(value)
1447    }
1448}
1449
/// Input to `Composer::add_composable_module`, describing a module to register.
#[derive(Default)]
pub struct ComposableModuleDescriptor<'a> {
    /// The module's source text.
    pub source: &'a str,
    /// Path reported in errors raised while the module is being added.
    pub file_path: &'a str,
    /// Language the source is written in.
    pub language: ShaderLanguage,
    /// Optional module name. When set, it takes precedence over the name found in the source
    /// by the preprocessor (the `#define_import_path` directive).
    pub as_name: Option<String>,
    /// Imports added to the module in addition to those declared in its source.
    pub additional_imports: &'a [ImportDefinition],
    /// Shader defs supplied with the module definition.
    /// NOTE(review): how these interact with per-shader defs is decided later in
    /// `add_composable_module` (not fully visible here); confirm before documenting further.
    pub shader_defs: HashMap<String, ShaderDefValue>,
}
1459
/// Describes a top-level shader to be composed into a naga module.
/// NOTE(review): presumably consumed by `Composer::make_naga_module`, which is not visible
/// here; confirm the field semantics below against that function.
#[derive(Default)]
pub struct NagaModuleDescriptor<'a> {
    /// The shader's source text.
    pub source: &'a str,
    /// Path of the shader, presumably used for error reporting (verify).
    pub file_path: &'a str,
    /// Kind/language of the shader source.
    pub shader_type: ShaderType,
    /// Shader defs to apply when building this shader.
    pub shader_defs: HashMap<String, ShaderDefValue>,
    /// Imports added in addition to those declared in the source.
    pub additional_imports: &'a [ImportDefinition],
}
1468
1469// public api
1470impl Composer {
1471    /// create a non-validating composer.
1472    /// validation errors in the final shader will not be caught, and errors resulting from their
1473    /// use will have bad span data, so codespan reporting will fail.
1474    /// use default() to create a validating composer.
1475    pub fn non_validating() -> Self {
1476        Self {
1477            validate: false,
1478            ..Default::default()
1479        }
1480    }
1481
1482    /// specify capabilities to be used for naga module generation.
1483    /// purges any existing modules
1484    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
1485    /// for how to set the subgroup_stages value.
1486    pub fn with_capabilities(self, capabilities: naga::valid::Capabilities) -> Self {
1487        Self {
1488            capabilities,
1489            validate: self.validate,
1490            ..Default::default()
1491        }
1492    }
1493
1494    /// check if a module with the given name has been added
1495    pub fn contains_module(&self, module_name: &str) -> bool {
1496        self.module_sets.contains_key(module_name)
1497    }
1498
1499    /// add a composable module to the composer.
1500    /// all modules imported by this module must already have been added
1501    pub fn add_composable_module(
1502        &mut self,
1503        desc: ComposableModuleDescriptor,
1504    ) -> Result<&ComposableModuleDefinition, ComposerError> {
1505        let ComposableModuleDescriptor {
1506            source,
1507            file_path,
1508            language,
1509            as_name,
1510            additional_imports,
1511            mut shader_defs,
1512        } = desc;
1513
1514        // reject a module containing the DECORATION strings
1515        if let Some(decor) = self.check_decoration_regex.find(source) {
1516            return Err(ComposerError {
1517                inner: ComposerErrorInner::DecorationInSource(decor.range()),
1518                source: ErrSource::Constructing {
1519                    path: file_path.to_owned(),
1520                    source: source.to_owned(),
1521                    offset: 0,
1522                },
1523            });
1524        }
1525
1526        let substituted_source = self.sanitize_and_set_auto_bindings(source);
1527
1528        let PreprocessorMetaData {
1529            name: module_name,
1530            mut imports,
1531            mut effective_defs,
1532            cleaned_source,
1533            wgsl_directives,
1534            ..
1535        } = self
1536            .preprocessor
1537            .get_preprocessor_metadata(&substituted_source, false)
1538            .map_err(|inner| ComposerError {
1539                inner,
1540                source: ErrSource::Constructing {
1541                    path: file_path.to_owned(),
1542                    source: source.to_owned(),
1543                    offset: 0,
1544                },
1545            })?;
1546        let module_name = as_name.or(module_name);
1547        if module_name.is_none() {
1548            return Err(ComposerError {
1549                inner: ComposerErrorInner::NoModuleName,
1550                source: ErrSource::Constructing {
1551                    path: file_path.to_owned(),
1552                    source: source.to_owned(),
1553                    offset: 0,
1554                },
1555            });
1556        }
1557        let module_name = module_name.unwrap();
1558
1559        debug!(
1560            "adding module definition for {} with defs: {:?}",
1561            module_name, shader_defs
1562        );
1563
1564        // add custom imports
1565        let additional_imports = additional_imports.to_vec();
1566        imports.extend(
1567            additional_imports
1568                .iter()
1569                .cloned()
1570                .map(|def| ImportDefWithOffset {
1571                    definition: def,
1572                    offset: 0,
1573                }),
1574        );
1575
1576        for import in &imports {
1577            // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies
1578            let module_set = self
1579                .module_sets
1580                .get(&import.definition.import)
1581                .ok_or_else(|| ComposerError {
1582                    inner: ComposerErrorInner::ImportNotFound(
1583                        import.definition.import.clone(),
1584                        import.offset,
1585                    ),
1586                    source: ErrSource::Constructing {
1587                        path: file_path.to_owned(),
1588                        source: substituted_source.to_owned(),
1589                        offset: 0,
1590                    },
1591                })?;
1592            effective_defs.extend(module_set.effective_defs.iter().cloned());
1593            shader_defs.extend(
1594                module_set
1595                    .shader_defs
1596                    .iter()
1597                    .map(|def| (def.0.clone(), *def.1)),
1598            );
1599        }
1600
1601        // remove defs that are already specified through our imports
1602        effective_defs.retain(|name| !shader_defs.contains_key(name));
1603
1604        // can't gracefully report errors for more modules. perhaps this should be a warning
1605        assert!((self.module_sets.len() as u32) < u32::MAX >> SPAN_SHIFT);
1606        let module_index = self.module_sets.len() + 1;
1607
1608        let module_set = ComposableModuleDefinition {
1609            name: module_name.clone(),
1610            sanitized_source: cleaned_source,
1611            file_path: file_path.to_owned(),
1612            language,
1613            effective_defs: effective_defs.into_iter().collect(),
1614            all_imports: imports.into_iter().map(|id| id.definition.import).collect(),
1615            additional_imports,
1616            shader_defs,
1617            module_index,
1618            modules: Default::default(),
1619            wgsl_directives,
1620        };
1621
1622        // invalidate dependent modules if this module already exists
1623        self.remove_composable_module(&module_name);
1624
1625        self.module_sets.insert(module_name.clone(), module_set);
1626        self.module_index.insert(module_index, module_name.clone());
1627        Ok(self.module_sets.get(&module_name).unwrap())
1628    }
1629
1630    /// remove a composable module. also removes modules that depend on this module, as we cannot be sure about
1631    /// the completeness of their effective shader defs any more...
1632    pub fn remove_composable_module(&mut self, module_name: &str) {
1633        // todo this could be improved by making effective defs an Option<HashSet> and populating on demand?
1634        let mut dependent_sets = Vec::new();
1635
1636        if self.module_sets.remove(module_name).is_some() {
1637            dependent_sets.extend(self.module_sets.iter().filter_map(|(dependent_name, set)| {
1638                if set.all_imports.contains(module_name) {
1639                    Some(dependent_name.clone())
1640                } else {
1641                    None
1642                }
1643            }));
1644        }
1645
1646        for dependent_set in dependent_sets {
1647            self.remove_composable_module(&dependent_set);
1648        }
1649    }
1650
1651    /// build a naga shader module
1652    pub fn make_naga_module(
1653        &mut self,
1654        desc: NagaModuleDescriptor,
1655    ) -> Result<naga::Module, ComposerError> {
1656        let NagaModuleDescriptor {
1657            source,
1658            file_path,
1659            shader_type,
1660            mut shader_defs,
1661            additional_imports,
1662        } = desc;
1663
1664        let sanitized_source = self.sanitize_and_set_auto_bindings(source);
1665
1666        let PreprocessorMetaData {
1667            name,
1668            defines,
1669            wgsl_directives,
1670            cleaned_source,
1671            ..
1672        } = self
1673            .preprocessor
1674            .get_preprocessor_metadata(&sanitized_source, true)
1675            .map_err(|inner| ComposerError {
1676                inner,
1677                source: ErrSource::Constructing {
1678                    path: file_path.to_owned(),
1679                    source: sanitized_source.to_owned(),
1680                    offset: 0,
1681                },
1682            })?;
1683        shader_defs.extend(defines);
1684
1685        let name = name.unwrap_or_default();
1686
1687        let PreprocessOutput { imports, .. } = self
1688            .preprocessor
1689            .preprocess(&cleaned_source, &shader_defs)
1690            .map_err(|inner| ComposerError {
1691                inner,
1692                source: ErrSource::Constructing {
1693                    path: file_path.to_owned(),
1694                    source: sanitized_source.to_owned(),
1695                    offset: 0,
1696                },
1697            })?;
1698
1699        // make sure imports have been added
1700        // and gather additional defs specified at module level
1701        for (import_name, offset) in imports
1702            .iter()
1703            .map(|id| (&id.definition.import, id.offset))
1704            .chain(additional_imports.iter().map(|ai| (&ai.import, 0)))
1705        {
1706            if let Some(module_set) = self.module_sets.get(import_name) {
1707                for (def, value) in &module_set.shader_defs {
1708                    if let Some(prior_value) = shader_defs.insert(def.clone(), *value) {
1709                        if prior_value != *value {
1710                            return Err(ComposerError {
1711                                inner: ComposerErrorInner::InconsistentShaderDefValue {
1712                                    def: def.clone(),
1713                                },
1714                                source: ErrSource::Constructing {
1715                                    path: file_path.to_owned(),
1716                                    source: sanitized_source.to_owned(),
1717                                    offset: 0,
1718                                },
1719                            });
1720                        }
1721                    }
1722                }
1723            } else {
1724                return Err(ComposerError {
1725                    inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset),
1726                    source: ErrSource::Constructing {
1727                        path: file_path.to_owned(),
1728                        source: sanitized_source.to_owned(),
1729                        offset: 0,
1730                    },
1731                });
1732            }
1733        }
1734        self.ensure_imports(
1735            imports.iter().map(|import| &import.definition),
1736            &shader_defs,
1737        )
1738        .map_err(|err| {
1739            err.into_composer_error(ErrSource::Constructing {
1740                path: file_path.to_owned(),
1741                source: sanitized_source.to_owned(),
1742                offset: 0,
1743            })
1744        })?;
1745        self.ensure_imports(additional_imports, &shader_defs)
1746            .map_err(|err| {
1747                err.into_composer_error(ErrSource::Constructing {
1748                    path: file_path.to_owned(),
1749                    source: sanitized_source.to_owned(),
1750                    offset: 0,
1751                })
1752            })?;
1753
1754        let definition = ComposableModuleDefinition {
1755            name,
1756            sanitized_source: cleaned_source.clone(),
1757            language: shader_type.into(),
1758            file_path: file_path.to_owned(),
1759            module_index: 0,
1760            additional_imports: additional_imports.to_vec(),
1761            // we don't care about these for creating a top-level module
1762            effective_defs: Default::default(),
1763            all_imports: Default::default(),
1764            shader_defs: Default::default(),
1765            modules: Default::default(),
1766            wgsl_directives,
1767        };
1768
1769        let PreprocessOutput {
1770            preprocessed_source,
1771            imports,
1772        } = self
1773            .preprocessor
1774            .preprocess(&cleaned_source, &shader_defs)
1775            .map_err(|inner| ComposerError {
1776                inner,
1777                source: ErrSource::Constructing {
1778                    path: file_path.to_owned(),
1779                    source: sanitized_source,
1780                    offset: 0,
1781                },
1782            })?;
1783
1784        let composable = self
1785            .create_composable_module(
1786                &definition,
1787                String::from(""),
1788                &shader_defs,
1789                false,
1790                false,
1791                &preprocessed_source,
1792                imports,
1793            )
1794            .map_err(|e| ComposerError {
1795                inner: e.inner,
1796                source: ErrSource::Constructing {
1797                    path: definition.file_path.to_owned(),
1798                    source: preprocessed_source.clone(),
1799                    offset: e.source.offset(),
1800                },
1801            })?;
1802
1803        let mut derived = DerivedModule::default();
1804
1805        let mut already_added = Default::default();
1806        for import in &composable.imports {
1807            self.add_import(
1808                &mut derived,
1809                import,
1810                &shader_defs,
1811                false,
1812                &mut already_added,
1813            );
1814        }
1815
1816        Self::add_composable_data(&mut derived, &composable, None, 0, false);
1817
1818        let stage = match shader_type {
1819            #[cfg(feature = "glsl")]
1820            ShaderType::GlslVertex => Some(naga::ShaderStage::Vertex),
1821            #[cfg(feature = "glsl")]
1822            ShaderType::GlslFragment => Some(naga::ShaderStage::Fragment),
1823            _ => None,
1824        };
1825
1826        let mut entry_points = Vec::default();
1827        derived.set_shader_source(&composable.module_ir, 0);
1828        for ep in &composable.module_ir.entry_points {
1829            let mapped_func = derived.localize_function(&ep.function);
1830            entry_points.push(EntryPoint {
1831                name: ep.name.clone(),
1832                function: mapped_func,
1833                stage: stage.unwrap_or(ep.stage),
1834                early_depth_test: ep.early_depth_test,
1835                workgroup_size: ep.workgroup_size,
1836                workgroup_size_overrides: ep.workgroup_size_overrides,
1837                mesh_info: ep.mesh_info.clone(),
1838                task_payload: ep.task_payload,
1839                incoming_ray_payload: ep.incoming_ray_payload,
1840            });
1841        }
1842        let mut naga_module = naga::Module {
1843            entry_points,
1844            ..derived.into()
1845        };
1846
1847        // apply overrides
1848        if !composable.override_functions.is_empty() {
1849            let mut redirect = Redirector::new(naga_module);
1850
1851            for (base_function, overrides) in composable.override_functions {
1852                let mut omit = HashSet::default();
1853
1854                let mut original = base_function;
1855                for replacement in overrides {
1856                    let (_h_orig, _h_replace) = redirect
1857                        .redirect_function(&original, &replacement, &omit)
1858                        .map_err(|e| ComposerError {
1859                            inner: e.into(),
1860                            source: ErrSource::Constructing {
1861                                path: file_path.to_owned(),
1862                                source: preprocessed_source.clone(),
1863                                offset: composable.start_offset,
1864                            },
1865                        })?;
1866                    omit.insert(replacement.clone());
1867                    original = replacement;
1868                }
1869            }
1870
1871            naga_module = redirect.into_module().map_err(|e| ComposerError {
1872                inner: e.into(),
1873                source: ErrSource::Constructing {
1874                    path: file_path.to_owned(),
1875                    source: preprocessed_source.clone(),
1876                    offset: composable.start_offset,
1877                },
1878            })?;
1879        }
1880
1881        // validation
1882        if self.validate {
1883            let info = self.create_validator().validate(&naga_module);
1884            match info {
1885                Ok(_) => Ok(naga_module),
1886                Err(e) => {
1887                    let original_span = e.spans().last();
1888                    let err_source = match original_span.and_then(|(span, _)| span.to_range()) {
1889                        Some(rng) => {
1890                            let module_index = rng.start >> SPAN_SHIFT;
1891                            match module_index {
1892                                0 => ErrSource::Constructing {
1893                                    path: file_path.to_owned(),
1894                                    source: preprocessed_source.clone(),
1895                                    offset: composable.start_offset,
1896                                },
1897                                _ => {
1898                                    let module_name =
1899                                        self.module_index.get(&module_index).unwrap().clone();
1900                                    let offset = self
1901                                        .module_sets
1902                                        .get(&module_name)
1903                                        .unwrap()
1904                                        .get_module(&shader_defs)
1905                                        .unwrap()
1906                                        .start_offset;
1907                                    ErrSource::Module {
1908                                        name: module_name,
1909                                        offset,
1910                                        defs: shader_defs.clone(),
1911                                    }
1912                                }
1913                            }
1914                        }
1915                        None => ErrSource::Constructing {
1916                            path: file_path.to_owned(),
1917                            source: preprocessed_source.clone(),
1918                            offset: composable.start_offset,
1919                        },
1920                    };
1921
1922                    Err(ComposerError {
1923                        inner: ComposerErrorInner::ShaderValidationError(e),
1924                        source: err_source,
1925                    })
1926                }
1927            }
1928        } else {
1929            Ok(naga_module)
1930        }
1931    }
1932}
1933
/// Shared preprocessor for [`get_preprocessor_data`], built with `Preprocessor::default` on
/// first access. that function is a free function, so it has no `self.preprocessor` to use.
static PREPROCESSOR: LazyLock<Preprocessor> = LazyLock::new(Preprocessor::default);
1935
1936/// Get module name and all required imports (ignoring shader_defs) from a shader string
1937pub fn get_preprocessor_data(
1938    source: &str,
1939) -> (
1940    Option<String>,
1941    Vec<ImportDefinition>,
1942    HashMap<String, ShaderDefValue>,
1943) {
1944    if let Ok(PreprocessorMetaData {
1945        name,
1946        imports,
1947        defines,
1948        ..
1949    }) = PREPROCESSOR.get_preprocessor_metadata(source, true)
1950    {
1951        (
1952            name,
1953            imports
1954                .into_iter()
1955                .map(|import_with_offset| import_with_offset.definition)
1956                .collect(),
1957            defines,
1958        )
1959    } else {
1960        // if errors occur we return nothing; the actual error will be displayed when the caller attempts to use the shader
1961        Default::default()
1962    }
1963}