1use crate::shader::*;
2use alloc::sync::Arc;
3use bevy_asset::AssetId;
4use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
5use core::hash::Hash;
6use thiserror::Error;
7use tracing::debug;
8use wgpu_types::{DownlevelFlags, Features};
9
10#[cfg_attr(
20 not(feature = "decoupled_naga"),
21 expect(
22 clippy::large_enum_variant,
23 reason = "naga modules are the most common use, and are large"
24 )
25)]
26#[derive(Clone, Debug)]
27pub enum ShaderCacheSource<'a> {
28 SpirV(&'a [u8]),
30 Wgsl(String),
32 #[cfg(not(feature = "decoupled_naga"))]
34 Naga(naga::Module),
35}
36
/// Identifier of a cached pipeline, as tracked per shader by [`ShaderCache`].
pub type CachedPipelineId = usize;
40
41struct ShaderData<ShaderModule> {
42 pipelines: HashSet<CachedPipelineId>,
43 processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<ShaderModule>>,
44 resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
45 dependents: HashSet<AssetId<Shader>>,
46}
47
48impl<T> Default for ShaderData<T> {
49 fn default() -> Self {
50 Self {
51 pipelines: Default::default(),
52 processed_shaders: Default::default(),
53 resolved_imports: Default::default(),
54 dependents: Default::default(),
55 }
56 }
57}
58
59pub struct ShaderCache<ShaderModule, RenderDevice> {
67 device: RenderDevice,
68 data: HashMap<AssetId<Shader>, ShaderData<ShaderModule>>,
69 load_module: fn(
70 &RenderDevice,
71 ShaderCacheSource,
72 &ValidateShader,
73 ) -> Result<ShaderModule, ShaderCacheError>,
74 #[cfg(feature = "shader_format_wesl")]
75 module_path_to_asset_id: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
76 shaders: HashMap<AssetId<Shader>, Shader>,
77 import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
78 waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
79 #[doc(hidden)]
81 pub composer: naga_oil::compose::Composer,
82}
83
84#[expect(missing_docs, reason = "Enum variants are self-explanatory")]
87#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug, Hash)]
88pub enum ShaderDefVal {
89 Bool(String, bool),
90 Int(String, i32),
91 UInt(String, u32),
92}
93
94impl From<&str> for ShaderDefVal {
95 fn from(key: &str) -> Self {
96 ShaderDefVal::Bool(key.to_string(), true)
97 }
98}
99
100impl From<String> for ShaderDefVal {
101 fn from(key: String) -> Self {
102 ShaderDefVal::Bool(key, true)
103 }
104}
105
106impl ShaderDefVal {
107 pub fn value_as_string(&self) -> String {
109 match self {
110 ShaderDefVal::Bool(_, def) => def.to_string(),
111 ShaderDefVal::Int(_, def) => def.to_string(),
112 ShaderDefVal::UInt(_, def) => def.to_string(),
113 }
114 }
115}
116
117impl<ShaderModule, RenderDevice> ShaderCache<ShaderModule, RenderDevice> {
118 pub fn new(
122 device: RenderDevice,
123 features: Features,
124 downlevel: DownlevelFlags,
125 load_module: fn(
126 &RenderDevice,
127 ShaderCacheSource,
128 &ValidateShader,
129 ) -> Result<ShaderModule, ShaderCacheError>,
130 ) -> Self {
131 let capabilities = wgpu_naga_bridge::features_to_naga_capabilities(features, downlevel);
132 #[cfg(debug_assertions)]
133 let composer = naga_oil::compose::Composer::default();
134 #[cfg(not(debug_assertions))]
135 let composer = naga_oil::compose::Composer::non_validating();
136
137 let composer = composer.with_capabilities(capabilities);
138
139 Self {
140 device,
141 composer,
142 load_module,
143 data: Default::default(),
144 #[cfg(feature = "shader_format_wesl")]
145 module_path_to_asset_id: Default::default(),
146 shaders: Default::default(),
147 import_path_shaders: Default::default(),
148 waiting_on_import: Default::default(),
149 }
150 }
151
152 fn add_import_to_composer(
153 composer: &mut naga_oil::compose::Composer,
154 import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
155 shaders: &HashMap<AssetId<Shader>, Shader>,
156 import: &ShaderImport,
157 ) -> Result<(), ShaderCacheError> {
158 if composer.contains_module(&import.module_name()) {
160 return Ok(());
161 }
162
163 let shader = import_path_shaders
165 .get(import)
166 .and_then(|handle| shaders.get(handle))
167 .ok_or(ShaderCacheError::ShaderImportNotYetAvailable)?;
168
169 for import in &shader.imports {
171 Self::add_import_to_composer(composer, import_path_shaders, shaders, import)?;
172 }
173
174 composer
175 .add_composable_module(shader.into())
176 .map_err(Box::new)?;
177 Ok(())
180 }
181
182 pub fn get(
193 &mut self,
194 pipeline: CachedPipelineId,
195 id: AssetId<Shader>,
196 shader_defs: &[ShaderDefVal],
197 ) -> Result<Arc<ShaderModule>, ShaderCacheError> {
198 let shader = self
199 .shaders
200 .get(&id)
201 .ok_or(ShaderCacheError::ShaderNotLoaded(id))?;
202
203 let data = self.data.entry(id).or_default();
204 let n_asset_imports = shader
205 .imports
206 .iter()
207 .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
208 .count();
209 let n_resolved_asset_imports = data
210 .resolved_imports
211 .keys()
212 .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
213 .count();
214 if n_asset_imports != n_resolved_asset_imports {
215 return Err(ShaderCacheError::ShaderImportNotYetAvailable);
216 }
217
218 data.pipelines.insert(pipeline);
219
220 let module = match data.processed_shaders.entry_ref(shader_defs) {
221 EntryRef::Occupied(entry) => entry.into_mut(),
222 EntryRef::Vacant(entry) => {
223 debug!(
224 "processing shader {}, with shader defs {:?}",
225 id, shader_defs
226 );
227 let shader_source = match &shader.source {
228 Source::SpirV(data) => ShaderCacheSource::SpirV(data.as_ref()),
229 #[cfg(feature = "shader_format_wesl")]
230 Source::Wesl(_) => {
231 if let ShaderImport::AssetPath(path) = &shader.import_path {
232 let shader_resolver =
233 ShaderResolver::new(&self.module_path_to_asset_id, &self.shaders);
234 let module_path = wesl::syntax::ModulePath::from_path(path);
235 let mut compiler_options = wesl::CompileOptions {
236 imports: true,
237 condcomp: true,
238 lower: true,
239 ..Default::default()
240 };
241
242 for shader_def in shader_defs {
243 match shader_def {
244 ShaderDefVal::Bool(key, value) => {
245 compiler_options.features.flags.insert(key.clone(), (*value).into());
246 }
247 _ => debug!(
248 "ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
249 ),
250 }
251 }
252
253 let compiled = wesl::compile(
254 &module_path,
255 &shader_resolver,
256 &wesl::EscapeMangler,
257 &compiler_options,
258 )
259 .unwrap();
260
261 ShaderCacheSource::Wgsl(compiled.to_string())
262 } else {
263 panic!("Wesl shaders must be imported from a file");
264 }
265 }
266 _ => {
267 for import in shader.imports.iter() {
268 Self::add_import_to_composer(
269 &mut self.composer,
270 &self.import_path_shaders,
271 &self.shaders,
272 import,
273 )?;
274 }
275
276 let shader_defs = shader_defs
277 .iter()
278 .chain(shader.shader_defs.iter())
279 .map(|def| match def.clone() {
280 ShaderDefVal::Bool(k, v) => {
281 (k, naga_oil::compose::ShaderDefValue::Bool(v))
282 }
283 ShaderDefVal::Int(k, v) => {
284 (k, naga_oil::compose::ShaderDefValue::Int(v))
285 }
286 ShaderDefVal::UInt(k, v) => {
287 (k, naga_oil::compose::ShaderDefValue::UInt(v))
288 }
289 })
290 .collect::<std::collections::HashMap<_, _>>();
291
292 let naga = self
293 .composer
294 .make_naga_module(naga_oil::compose::NagaModuleDescriptor {
295 shader_defs,
296 ..shader.into()
297 })
298 .map_err(Box::new)?;
299
300 #[cfg(not(feature = "decoupled_naga"))]
301 {
302 ShaderCacheSource::Naga(naga)
303 }
304
305 #[cfg(feature = "decoupled_naga")]
306 {
307 let mut validator = naga::valid::Validator::new(
308 naga::valid::ValidationFlags::all(),
309 self.composer.capabilities,
310 );
311 let module_info = validator.validate(&naga).unwrap();
312 let wgsl = naga::back::wgsl::write_string(
313 &naga,
314 &module_info,
315 naga::back::wgsl::WriterFlags::empty(),
316 )
317 .unwrap();
318 ShaderCacheSource::Wgsl(wgsl)
319 }
320 }
321 };
322
323 let shader_module =
324 (self.load_module)(&self.device, shader_source, &shader.validate_shader)?;
325
326 entry.insert(Arc::new(shader_module))
327 }
328 };
329
330 Ok(module.clone())
331 }
332
333 fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
334 let mut shaders_to_clear = vec![id];
335 let mut pipelines_to_queue = Vec::new();
336 while let Some(handle) = shaders_to_clear.pop() {
337 if let Some(data) = self.data.get_mut(&handle) {
338 data.processed_shaders.clear();
339 pipelines_to_queue.extend(data.pipelines.iter().copied());
340 shaders_to_clear.extend(data.dependents.iter().copied());
341
342 if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
343 self.composer
344 .remove_composable_module(&import_path.module_name());
345 }
346 }
347 }
348
349 pipelines_to_queue
350 }
351
352 pub fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
357 let pipelines_to_queue = self.clear(id);
358 let path = &shader.import_path;
359 self.import_path_shaders.insert(path.clone(), id);
360 if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
361 for waiting_shader in waiting_shaders.drain(..) {
362 let data = self.data.entry(waiting_shader).or_default();
364 data.resolved_imports.insert(path.clone(), id);
365 let data = self.data.entry(id).or_default();
367 data.dependents.insert(waiting_shader);
368 }
369 }
370
371 for import in shader.imports.iter() {
372 if let Some(import_id) = self.import_path_shaders.get(import).copied() {
373 let data = self.data.entry(id).or_default();
375 data.resolved_imports.insert(import.clone(), import_id);
376 let data = self.data.entry(import_id).or_default();
378 data.dependents.insert(id);
379 } else {
380 let waiting = self.waiting_on_import.entry(import.clone()).or_default();
381 waiting.push(id);
382 }
383 }
384
385 #[cfg(feature = "shader_format_wesl")]
386 if let Source::Wesl(_) = shader.source
387 && let ShaderImport::AssetPath(path) = &shader.import_path
388 {
389 self.module_path_to_asset_id
390 .insert(wesl::syntax::ModulePath::from_path(path), id);
391 }
392 self.shaders.insert(id, shader);
393 pipelines_to_queue
394 }
395
396 pub fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
401 let pipelines_to_queue = self.clear(id);
402 if let Some(shader) = self.shaders.remove(&id) {
403 self.import_path_shaders.remove(&shader.import_path);
404 }
405
406 pipelines_to_queue
407 }
408}
409
/// Resolves WESL module paths to the source text of registered shaders.
#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    /// Asset id registered for each WESL module path.
    module_path_to_asset_id: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    /// Registered shaders, looked up by asset id.
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}
416
#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    /// Creates a resolver that borrows the module-path table and the registered shaders.
    pub fn new(
        module_path_to_asset_id: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        ShaderResolver {
            shaders,
            module_path_to_asset_id,
        }
    }
}
432
#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    /// Returns the borrowed source text of the shader registered for `module_path`.
    ///
    /// # Errors
    ///
    /// Returns [`wesl::ResolveError::ModuleNotFound`] when no asset id is known for
    /// `module_path`, or when the asset id is known but the shader is not currently loaded.
    /// The second case used to panic.
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<alloc::borrow::Cow<'_, str>, wesl::ResolveError> {
        let asset_id = self
            .module_path_to_asset_id
            .get(module_path)
            .ok_or_else(|| {
                wesl::ResolveError::ModuleNotFound(
                    module_path.clone(),
                    "Invalid asset id".to_string(),
                )
            })?;

        // The mapping can outlive the shader it points to, so a miss is a resolve error
        // rather than a broken invariant.
        let shader = self.shaders.get(asset_id).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(
                module_path.clone(),
                "Shader asset is not loaded".to_string(),
            )
        })?;

        Ok(alloc::borrow::Cow::Borrowed(shader.source.as_str()))
    }
}
453
454#[expect(missing_docs, reason = "Enum variants are self-explanatory")]
456#[derive(Error, Debug)]
457pub enum ShaderCacheError {
458 #[error(
459 "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
460 )]
461 ShaderNotLoaded(AssetId<Shader>),
462 #[error(transparent)]
463 ProcessShaderError(#[from] Box<naga_oil::compose::ComposerError>),
464 #[error("Shader import not yet available.")]
465 ShaderImportNotYetAvailable,
466 #[error("Could not create shader module: {0}")]
467 CreateShaderModule(String),
468}