Skip to main content

bevy_shader/
shader_cache.rs

1use crate::shader::*;
2use alloc::sync::Arc;
3use bevy_asset::AssetId;
4use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
5use core::hash::Hash;
6use thiserror::Error;
7use tracing::debug;
8use wgpu_types::{DownlevelFlags, Features};
9
10/// Fully composed source code of a shader module, with all shader defs applied.
11///
12/// This is roughly equivalent to [`wgpu::ShaderSource`](https://docs.rs/wgpu/latest/wgpu/enum.ShaderSource.html),
13/// but with less variants and more concrete types instead of [`Cow`](alloc::borrow::Cow).
14///
15/// This source will be parsed and validated by the renderer.
16///
17/// Any necessary shader translation (e.g. from WGSL to SPIR-V or vice versa)
18/// must be done internally by the renderer.
19#[cfg_attr(
20    not(feature = "decoupled_naga"),
21    expect(
22        clippy::large_enum_variant,
23        reason = "naga modules are the most common use, and are large"
24    )
25)]
26#[derive(Clone, Debug)]
27pub enum ShaderCacheSource<'a> {
28    /// SPIR-V module represented as a slice of words.
29    SpirV(&'a [u8]),
30    /// WGSL module as a string slice.
31    Wgsl(String),
32    /// Naga module.
33    #[cfg(not(feature = "decoupled_naga"))]
34    Naga(naga::Module),
35}
36
/// An id of a pipeline, typically in the [`PipelineCache`](https://docs.rs/bevy/latest/bevy/render/render_resource/struct.PipelineCache.html).
///
/// Typically corresponds to a unique combination of [`Shader`] and [`ShaderDefVal`]s.
pub type CachedPipelineId = usize;
40
41struct ShaderData<ShaderModule> {
42    pipelines: HashSet<CachedPipelineId>,
43    processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<ShaderModule>>,
44    resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
45    dependents: HashSet<AssetId<Shader>>,
46}
47
48impl<T> Default for ShaderData<T> {
49    fn default() -> Self {
50        Self {
51            pipelines: Default::default(),
52            processed_shaders: Default::default(),
53            resolved_imports: Default::default(),
54            dependents: Default::default(),
55        }
56    }
57}
58
/// A cache for shaders and shader imports, with asset state-tracking for
/// waiting to load shaders until all imports are resolved.
///
/// Note that the `RenderDevice` generic parameter is a means by which
/// to avoid a cyclic dependency with `bevy_render`, while also permitting
/// alternative rendering implementations. The actual processing of the
/// shader source into a usable compiled module is left to the renderer.
pub struct ShaderCache<ShaderModule, RenderDevice> {
    /// Device passed to `load_module` each time a module is compiled.
    device: RenderDevice,
    /// Per-shader bookkeeping: requesting pipelines, compiled modules,
    /// resolved imports and dependents.
    data: HashMap<AssetId<Shader>, ShaderData<ShaderModule>>,
    /// Renderer-supplied function that turns composed source into a module.
    load_module: fn(
        &RenderDevice,
        ShaderCacheSource,
        &ValidateShader,
    ) -> Result<ShaderModule, ShaderCacheError>,
    /// Wesl module path of each registered wesl shader, mapped to its asset id;
    /// consulted when resolving wesl imports.
    #[cfg(feature = "shader_format_wesl")]
    module_path_to_asset_id: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    /// Every shader currently registered with the cache, keyed by asset id.
    shaders: HashMap<AssetId<Shader>, Shader>,
    /// Asset id of the shader registered under each import path.
    import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
    /// Shaders whose import has not been registered yet, keyed by that import.
    waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
    // The naga composer is only public for providing error messages and should not be touched.
    #[doc(hidden)]
    pub composer: naga_oil::compose::Composer,
}
83
84/// A compile time shader value definition to be inlined into the shader source.
85/// Variant tuples contain the name of the definition, and the value.
86#[expect(missing_docs, reason = "Enum variants are self-explanatory")]
87#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug, Hash)]
88pub enum ShaderDefVal {
89    Bool(String, bool),
90    Int(String, i32),
91    UInt(String, u32),
92}
93
94impl From<&str> for ShaderDefVal {
95    fn from(key: &str) -> Self {
96        ShaderDefVal::Bool(key.to_string(), true)
97    }
98}
99
100impl From<String> for ShaderDefVal {
101    fn from(key: String) -> Self {
102        ShaderDefVal::Bool(key, true)
103    }
104}
105
106impl ShaderDefVal {
107    /// Returns the value of the define as a string.
108    pub fn value_as_string(&self) -> String {
109        match self {
110            ShaderDefVal::Bool(_, def) => def.to_string(),
111            ShaderDefVal::Int(_, def) => def.to_string(),
112            ShaderDefVal::UInt(_, def) => def.to_string(),
113        }
114    }
115}
116
117impl<ShaderModule, RenderDevice> ShaderCache<ShaderModule, RenderDevice> {
118    /// Creates a new [`ShaderCache`] with the given features and shader
119    /// module loading function. `load_module` is responsible for actually
120    /// compiling shader source into a module usable by the render device.
121    pub fn new(
122        device: RenderDevice,
123        features: Features,
124        downlevel: DownlevelFlags,
125        load_module: fn(
126            &RenderDevice,
127            ShaderCacheSource,
128            &ValidateShader,
129        ) -> Result<ShaderModule, ShaderCacheError>,
130    ) -> Self {
131        let capabilities = wgpu_naga_bridge::features_to_naga_capabilities(features, downlevel);
132        #[cfg(debug_assertions)]
133        let composer = naga_oil::compose::Composer::default();
134        #[cfg(not(debug_assertions))]
135        let composer = naga_oil::compose::Composer::non_validating();
136
137        let composer = composer.with_capabilities(capabilities);
138
139        Self {
140            device,
141            composer,
142            load_module,
143            data: Default::default(),
144            #[cfg(feature = "shader_format_wesl")]
145            module_path_to_asset_id: Default::default(),
146            shaders: Default::default(),
147            import_path_shaders: Default::default(),
148            waiting_on_import: Default::default(),
149        }
150    }
151
152    fn add_import_to_composer(
153        composer: &mut naga_oil::compose::Composer,
154        import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
155        shaders: &HashMap<AssetId<Shader>, Shader>,
156        import: &ShaderImport,
157    ) -> Result<(), ShaderCacheError> {
158        // Early out if we've already imported this module
159        if composer.contains_module(&import.module_name()) {
160            return Ok(());
161        }
162
163        // Check if the import is available (this handles the recursive import case)
164        let shader = import_path_shaders
165            .get(import)
166            .and_then(|handle| shaders.get(handle))
167            .ok_or(ShaderCacheError::ShaderImportNotYetAvailable)?;
168
169        // Recurse down to ensure all import dependencies are met
170        for import in &shader.imports {
171            Self::add_import_to_composer(composer, import_path_shaders, shaders, import)?;
172        }
173
174        composer
175            .add_composable_module(shader.into())
176            .map_err(Box::new)?;
177        // if we fail to add a module the composer will tell us what is missing
178
179        Ok(())
180    }
181
182    /// Attempts to retrieve or create a compiled shader module for the given
183    /// shader id and shader definitions.
184    ///
185    /// The provided `pipeline` is tracked so it may later be reported "dirty"
186    /// when a shader is removed or replaced.
187    ///
188    /// Note that the cache is keyed by `id` and `shader_defs`, meaning providing
189    /// the same `shader_defs` in a different order, or with redundancies, will
190    /// not result in cache hits, and thus require re-composing the module and
191    /// calling `load_module` again.
192    pub fn get(
193        &mut self,
194        pipeline: CachedPipelineId,
195        id: AssetId<Shader>,
196        shader_defs: &[ShaderDefVal],
197    ) -> Result<Arc<ShaderModule>, ShaderCacheError> {
198        let shader = self
199            .shaders
200            .get(&id)
201            .ok_or(ShaderCacheError::ShaderNotLoaded(id))?;
202
203        let data = self.data.entry(id).or_default();
204        let n_asset_imports = shader
205            .imports
206            .iter()
207            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
208            .count();
209        let n_resolved_asset_imports = data
210            .resolved_imports
211            .keys()
212            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
213            .count();
214        if n_asset_imports != n_resolved_asset_imports {
215            return Err(ShaderCacheError::ShaderImportNotYetAvailable);
216        }
217
218        data.pipelines.insert(pipeline);
219
220        let module = match data.processed_shaders.entry_ref(shader_defs) {
221            EntryRef::Occupied(entry) => entry.into_mut(),
222            EntryRef::Vacant(entry) => {
223                debug!(
224                    "processing shader {}, with shader defs {:?}",
225                    id, shader_defs
226                );
227                let shader_source = match &shader.source {
228                    Source::SpirV(data) => ShaderCacheSource::SpirV(data.as_ref()),
229                    #[cfg(feature = "shader_format_wesl")]
230                    Source::Wesl(_) => {
231                        if let ShaderImport::AssetPath(path) = &shader.import_path {
232                            let shader_resolver =
233                                ShaderResolver::new(&self.module_path_to_asset_id, &self.shaders);
234                            let module_path = wesl::syntax::ModulePath::from_path(path);
235                            let mut compiler_options = wesl::CompileOptions {
236                                imports: true,
237                                condcomp: true,
238                                lower: true,
239                                ..Default::default()
240                            };
241
242                            for shader_def in shader_defs {
243                                match shader_def {
244                                    ShaderDefVal::Bool(key, value) => {
245                                        compiler_options.features.flags.insert(key.clone(), (*value).into());
246                                    }
247                                    _ => debug!(
248                                        "ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
249                                    ),
250                                }
251                            }
252
253                            let compiled = wesl::compile(
254                                &module_path,
255                                &shader_resolver,
256                                &wesl::EscapeMangler,
257                                &compiler_options,
258                            )
259                            .unwrap();
260
261                            ShaderCacheSource::Wgsl(compiled.to_string())
262                        } else {
263                            panic!("Wesl shaders must be imported from a file");
264                        }
265                    }
266                    _ => {
267                        for import in shader.imports.iter() {
268                            Self::add_import_to_composer(
269                                &mut self.composer,
270                                &self.import_path_shaders,
271                                &self.shaders,
272                                import,
273                            )?;
274                        }
275
276                        let shader_defs = shader_defs
277                            .iter()
278                            .chain(shader.shader_defs.iter())
279                            .map(|def| match def.clone() {
280                                ShaderDefVal::Bool(k, v) => {
281                                    (k, naga_oil::compose::ShaderDefValue::Bool(v))
282                                }
283                                ShaderDefVal::Int(k, v) => {
284                                    (k, naga_oil::compose::ShaderDefValue::Int(v))
285                                }
286                                ShaderDefVal::UInt(k, v) => {
287                                    (k, naga_oil::compose::ShaderDefValue::UInt(v))
288                                }
289                            })
290                            .collect::<std::collections::HashMap<_, _>>();
291
292                        let naga = self
293                            .composer
294                            .make_naga_module(naga_oil::compose::NagaModuleDescriptor {
295                                shader_defs,
296                                ..shader.into()
297                            })
298                            .map_err(Box::new)?;
299
300                        #[cfg(not(feature = "decoupled_naga"))]
301                        {
302                            ShaderCacheSource::Naga(naga)
303                        }
304
305                        #[cfg(feature = "decoupled_naga")]
306                        {
307                            let mut validator = naga::valid::Validator::new(
308                                naga::valid::ValidationFlags::all(),
309                                self.composer.capabilities,
310                            );
311                            let module_info = validator.validate(&naga).unwrap();
312                            let wgsl = naga::back::wgsl::write_string(
313                                &naga,
314                                &module_info,
315                                naga::back::wgsl::WriterFlags::empty(),
316                            )
317                            .unwrap();
318                            ShaderCacheSource::Wgsl(wgsl)
319                        }
320                    }
321                };
322
323                let shader_module =
324                    (self.load_module)(&self.device, shader_source, &shader.validate_shader)?;
325
326                entry.insert(Arc::new(shader_module))
327            }
328        };
329
330        Ok(module.clone())
331    }
332
333    fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
334        let mut shaders_to_clear = vec![id];
335        let mut pipelines_to_queue = Vec::new();
336        while let Some(handle) = shaders_to_clear.pop() {
337            if let Some(data) = self.data.get_mut(&handle) {
338                data.processed_shaders.clear();
339                pipelines_to_queue.extend(data.pipelines.iter().copied());
340                shaders_to_clear.extend(data.dependents.iter().copied());
341
342                if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
343                    self.composer
344                        .remove_composable_module(&import_path.module_name());
345                }
346            }
347        }
348
349        pipelines_to_queue
350    }
351
352    /// Inserts and possibly replaces a shader at the given asset id.
353    ///
354    /// Returns a vec of which cached pipelines depended on it
355    /// (directly or indirectly via a shader import) and thus must be recompiled.
356    pub fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
357        let pipelines_to_queue = self.clear(id);
358        let path = &shader.import_path;
359        self.import_path_shaders.insert(path.clone(), id);
360        if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
361            for waiting_shader in waiting_shaders.drain(..) {
362                // resolve waiting shader import
363                let data = self.data.entry(waiting_shader).or_default();
364                data.resolved_imports.insert(path.clone(), id);
365                // add waiting shader as dependent of this shader
366                let data = self.data.entry(id).or_default();
367                data.dependents.insert(waiting_shader);
368            }
369        }
370
371        for import in shader.imports.iter() {
372            if let Some(import_id) = self.import_path_shaders.get(import).copied() {
373                // resolve import because it is currently available
374                let data = self.data.entry(id).or_default();
375                data.resolved_imports.insert(import.clone(), import_id);
376                // add this shader as a dependent of the import
377                let data = self.data.entry(import_id).or_default();
378                data.dependents.insert(id);
379            } else {
380                let waiting = self.waiting_on_import.entry(import.clone()).or_default();
381                waiting.push(id);
382            }
383        }
384
385        #[cfg(feature = "shader_format_wesl")]
386        if let Source::Wesl(_) = shader.source
387            && let ShaderImport::AssetPath(path) = &shader.import_path
388        {
389            self.module_path_to_asset_id
390                .insert(wesl::syntax::ModulePath::from_path(path), id);
391        }
392        self.shaders.insert(id, shader);
393        pipelines_to_queue
394    }
395
396    /// Removes the shader with the given asset id.
397    ///
398    /// Returns a vec of which cached pipelines depended on it
399    /// (directly or indirectly via a shader import) and thus must be recompiled.
400    pub fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
401        let pipelines_to_queue = self.clear(id);
402        if let Some(shader) = self.shaders.remove(&id) {
403            self.import_path_shaders.remove(&shader.import_path);
404        }
405
406        pipelines_to_queue
407    }
408}
409
/// A Wesl import resolver. Maps module paths to actual Wesl shader source.
///
/// It borrows both lookup tables, so it is intended to be built immediately
/// before compiling a wesl shader and dropped right after.
#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    /// Wesl module path -> asset id of the shader registered for it.
    module_path_to_asset_id: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    /// Asset id -> registered shader.
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}

#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    /// Builds a resolver over a module-path-to-asset-id table and an
    /// asset-id-to-shader table. The resolver is not meant to be long living.
    pub fn new(
        module_path_to_asset_id: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        ShaderResolver {
            shaders,
            module_path_to_asset_id,
        }
    }
}
432
#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    /// Returns the source of the shader registered for `module_path`.
    ///
    /// # Errors
    ///
    /// Returns [`wesl::ResolveError::ModuleNotFound`] when no asset id is
    /// registered for `module_path`, or when the registered asset id has no
    /// shader in the shader table.
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<alloc::borrow::Cow<'_, str>, wesl::ResolveError> {
        let not_found = |reason: &str| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), reason.to_string())
        };

        let asset_id = self
            .module_path_to_asset_id
            .get(module_path)
            .ok_or_else(|| not_found("Invalid asset id"))?;

        // The path table and the shader table are supplied separately and can
        // disagree (e.g. a stale mapping), so report a missing module rather
        // than panicking.
        let shader = self
            .shaders
            .get(asset_id)
            .ok_or_else(|| not_found("Shader asset not loaded"))?;
        Ok(alloc::borrow::Cow::Borrowed(shader.source.as_str()))
    }
}
453
454/// Type of error returned by a `PipelineCache` when the creation of a GPU pipeline object failed.
455#[expect(missing_docs, reason = "Enum variants are self-explanatory")]
456#[derive(Error, Debug)]
457pub enum ShaderCacheError {
458    #[error(
459        "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
460    )]
461    ShaderNotLoaded(AssetId<Shader>),
462    #[error(transparent)]
463    ProcessShaderError(#[from] Box<naga_oil::compose::ComposerError>),
464    #[error("Shader import not yet available.")]
465    ShaderImportNotYetAvailable,
466    #[error("Could not create shader module: {0}")]
467    CreateShaderModule(String),
468}