Skip to main content

bevy_shader/
shader.rs

1use super::ShaderDefVal;
2use alloc::borrow::Cow;
3use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
4use bevy_reflect::TypePath;
5use bevy_utils::define_atomic_id;
6use thiserror::Error;
7
// Declares the `ShaderId` identifier type via `bevy_utils::define_atomic_id!`.
// NOTE(review): going by the macro's name, ids are presumably handed out from an atomic
// counter so each one is unique within the process — confirm against `bevy_utils`.
define_atomic_id!(ShaderId);
9
/// Describes whether or not to perform runtime checks on shaders.
///
/// Runtime checks can be enabled for safety at the cost of speed.
/// By default no runtime checks are performed.
///
/// # Panics
///
/// Because no runtime checks are performed for SPIR-V,
/// enabling `ValidateShader` for SPIR-V shaders will cause a panic.
#[derive(Clone, Debug, Default)]
pub enum ValidateShader {
    /// No runtime checks for soundness (e.g. bounds checking) are performed.
    ///
    /// This is suitable for trusted shaders, written by your program or dependencies you trust.
    #[default]
    Disabled,
    /// Enables runtime checks for soundness (e.g. bounds checking).
    ///
    /// While this can have a meaningful impact on performance,
    /// this setting should *always* be enabled when loading untrusted shaders.
    /// This might occur if you are creating a shader playground, running user-generated shaders
    /// (as in `VRChat`), or writing a web browser in Bevy.
    Enabled,
}
32
/// An "unprocessed" shader. It can contain preprocessor directives and imports.
///
/// Usually created with one of the `from_*` constructors (e.g. [`Shader::from_wgsl`]), which
/// extract `import_path` and `imports` from the source, or loaded from a file by
/// [`ShaderLoader`].
#[derive(Asset, TypePath, Debug, Clone)]
pub struct Shader {
    /// The asset path of the shader.
    ///
    /// This is also handed to naga_oil as the module's `file_path`.
    pub path: String,
    /// The raw source code of the shader.
    pub source: Source,
    /// The path from which this shader can be imported by other shaders.
    ///
    /// For preprocessed sources this is the import path the source declares, falling back to an
    /// asset path when none is declared.
    pub import_path: ShaderImport,
    /// The import paths this shader depends on.
    pub imports: Vec<ShaderImport>,
    /// Extra imports not specified in the source string.
    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
    /// Any shader defs that should be included when this module is used.
    pub shader_defs: Vec<ShaderDefVal>,
    /// Strong handles to this shader's dependencies, to prevent them
    /// from being immediately dropped if this shader is the only user.
    ///
    /// [`ShaderLoader`] fills this with one handle per asset-path entry in `imports`.
    pub file_dependencies: Vec<Handle<Shader>>,
    /// Enable or disable runtime shader validation, trading safety against speed.
    ///
    /// Please read the [`ValidateShader`] docs for a discussion of the tradeoffs involved.
    pub validate_shader: ValidateShader,
}
56
57impl Shader {
58    fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
59        let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
60
61        let import_path = import_path
62            .map(ShaderImport::Custom)
63            .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
64
65        let imports = imports
66            .into_iter()
67            .map(|import| {
68                if import.import.starts_with('\"') {
69                    let import = import
70                        .import
71                        .chars()
72                        .skip(1)
73                        .take_while(|c| *c != '\"')
74                        .collect();
75                    ShaderImport::AssetPath(import)
76                } else {
77                    ShaderImport::Custom(import.import)
78                }
79            })
80            .collect();
81
82        (import_path, imports)
83    }
84
85    /// Creates a new WGSL shader.
86    pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
87        let source = source.into();
88        let path = path.into();
89        let (import_path, imports) = Shader::preprocess(&source, &path);
90        Shader {
91            path,
92            imports,
93            import_path,
94            source: Source::Wgsl(source),
95            additional_imports: Default::default(),
96            shader_defs: Default::default(),
97            file_dependencies: Default::default(),
98            validate_shader: ValidateShader::Disabled,
99        }
100    }
101
102    /// Creates a new WGSL shader with some given shader defs.
103    pub fn from_wgsl_with_defs(
104        source: impl Into<Cow<'static, str>>,
105        path: impl Into<String>,
106        shader_defs: Vec<ShaderDefVal>,
107    ) -> Shader {
108        Self {
109            shader_defs,
110            ..Self::from_wgsl(source, path)
111        }
112    }
113
114    /// Creates a new GLSL shader.
115    pub fn from_glsl(
116        source: impl Into<Cow<'static, str>>,
117        stage: naga::ShaderStage,
118        path: impl Into<String>,
119    ) -> Shader {
120        let source = source.into();
121        let path = path.into();
122        let (import_path, imports) = Shader::preprocess(&source, &path);
123        Shader {
124            path,
125            imports,
126            import_path,
127            source: Source::Glsl(source, stage),
128            additional_imports: Default::default(),
129            shader_defs: Default::default(),
130            file_dependencies: Default::default(),
131            validate_shader: ValidateShader::Disabled,
132        }
133    }
134
135    /// Creates a new SPIR-V shader.
136    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
137        let path = path.into();
138        Shader {
139            path: path.clone(),
140            imports: Vec::new(),
141            import_path: ShaderImport::AssetPath(path),
142            source: Source::SpirV(source.into()),
143            additional_imports: Default::default(),
144            shader_defs: Default::default(),
145            file_dependencies: Default::default(),
146            validate_shader: ValidateShader::Disabled,
147        }
148    }
149
150    /// Creates a new Wesl shader.
151    #[cfg(feature = "shader_format_wesl")]
152    pub fn from_wesl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
153        let source = source.into();
154        let path = path.into();
155        let (import_path, imports) = Shader::preprocess(&source, &path);
156
157        match import_path {
158            ShaderImport::AssetPath(asset_path) => {
159                // Create the shader import path - always starting with "/"
160                let shader_path = std::path::Path::new("/").join(&asset_path);
161
162                // Convert to a string with forward slashes and without extension
163                let import_path_str = shader_path
164                    .with_extension("")
165                    .to_string_lossy()
166                    .replace('\\', "/");
167
168                let import_path = ShaderImport::AssetPath(import_path_str.to_string());
169
170                Shader {
171                    path,
172                    imports,
173                    import_path,
174                    source: Source::Wesl(source),
175                    additional_imports: Default::default(),
176                    shader_defs: Default::default(),
177                    file_dependencies: Default::default(),
178                    validate_shader: ValidateShader::Disabled,
179                }
180            }
181            ShaderImport::Custom(_) => {
182                panic!("Wesl shaders must be imported from an asset path");
183            }
184        }
185    }
186}
187
188impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
189    fn from(shader: &'a Shader) -> Self {
190        let shader_defs = shader
191            .shader_defs
192            .iter()
193            .map(|def| match def {
194                ShaderDefVal::Bool(name, b) => {
195                    (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
196                }
197                ShaderDefVal::Int(name, i) => {
198                    (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
199                }
200                ShaderDefVal::UInt(name, i) => {
201                    (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
202                }
203            })
204            .collect();
205
206        // It is beyond me why this doesn't just use `shader.import_path.module_name()`.
207        let as_name = match &shader.import_path {
208            ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
209            ShaderImport::Custom(_) => None,
210        };
211
212        naga_oil::compose::ComposableModuleDescriptor {
213            source: shader.source.as_str(),
214            file_path: &shader.path,
215            language: (&shader.source).into(),
216            additional_imports: &shader.additional_imports,
217            shader_defs,
218            as_name,
219        }
220    }
221}
222
223impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
224    fn from(shader: &'a Shader) -> Self {
225        naga_oil::compose::NagaModuleDescriptor {
226            source: shader.source.as_str(),
227            file_path: &shader.path,
228            shader_type: (&shader.source).into(),
229            ..Default::default()
230        }
231    }
232}
233
234/// Raw shader source code.
235#[expect(missing_docs, reason = "The variants are self-explanatory.")]
236#[derive(Debug, Clone)]
237pub enum Source {
238    Wgsl(Cow<'static, str>),
239    Wesl(Cow<'static, str>),
240    Glsl(Cow<'static, str>, naga::ShaderStage),
241    SpirV(Cow<'static, [u8]>),
242    // TODO: consider the following
243    // PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
244    // NagaModule(Module) ... Module impls Serialize/Deserialize
245}
246
247impl Source {
248    /// The underlying source code string, unless it is SPIR-V.
249    pub fn as_str(&self) -> &str {
250        match self {
251            Source::Wgsl(s) | Source::Wesl(s) | Source::Glsl(s, _) => s,
252            Source::SpirV(_) => panic!("spirv not yet implemented"),
253        }
254    }
255}
256
257impl From<&Source> for naga_oil::compose::ShaderLanguage {
258    fn from(value: &Source) -> Self {
259        match value {
260            Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
261            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
262            Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
263            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
264            Source::Glsl(_, _) => panic!(
265                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
266            ),
267            Source::SpirV(_) => panic!("spirv not yet implemented"),
268            Source::Wesl(_) => panic!("wesl not yet implemented"),
269        }
270    }
271}
272
273impl From<&Source> for naga_oil::compose::ShaderType {
274    fn from(value: &Source) -> Self {
275        match value {
276            Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
277            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
278            Source::Glsl(_, shader_stage) => match shader_stage {
279                naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
280                naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
281                naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
282                naga::ShaderStage::Task => panic!("task shaders not yet implemented"),
283                naga::ShaderStage::Mesh => panic!("mesh shaders not yet implemented"),
284                naga::ShaderStage::RayGeneration => {
285                    panic!("ray generation shader not yet implemented")
286                }
287                naga::ShaderStage::Miss => panic!("miss shader not yet implemented"),
288                naga::ShaderStage::AnyHit => panic!("any hit shader not yet implemented"),
289                naga::ShaderStage::ClosestHit => panic!("closest hit shader not yet implemented"),
290            },
291            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
292            Source::Glsl(_, _) => panic!(
293                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
294            ),
295            Source::SpirV(_) => panic!("spirv not yet implemented"),
296            Source::Wesl(_) => panic!("wesl not yet implemented"),
297        }
298    }
299}
300
/// The [`AssetLoader`] responsible for loading unprocessed shader assets.
///
/// The shader format is chosen from the asset's file extension. Every asset-path import the
/// loaded shader declares is also requested from the asset system, and its handle is kept in
/// [`Shader::file_dependencies`].
#[derive(Default, TypePath)]
pub struct ShaderLoader;
304
305/// An error encountered while loading a shader's source.
306#[non_exhaustive]
307#[derive(Debug, Error)]
308#[expect(missing_docs, reason = "The variants are self-explanatory.")]
309pub enum ShaderLoaderError {
310    #[error("Could not load shader: {0}")]
311    Io(#[from] std::io::Error),
312    #[error("Could not parse shader: {0}")]
313    Parse(#[from] alloc::string::FromUtf8Error),
314}
315
316/// Settings for loading shaders.
317#[derive(serde::Serialize, serde::Deserialize, Debug, Default)]
318pub struct ShaderSettings {
319    /// The `#define`s specified for this shader.
320    pub shader_defs: Vec<ShaderDefVal>,
321}
322
323impl AssetLoader for ShaderLoader {
324    type Asset = Shader;
325    type Settings = ShaderSettings;
326    type Error = ShaderLoaderError;
327    async fn load(
328        &self,
329        reader: &mut dyn Reader,
330        settings: &Self::Settings,
331        load_context: &mut LoadContext<'_>,
332    ) -> Result<Shader, Self::Error> {
333        let ext = load_context
334            .path()
335            .path()
336            .extension()
337            .unwrap()
338            .to_str()
339            .unwrap();
340        let path = load_context.path().to_string();
341        // On windows, the path will inconsistently use \ or /.
342        // TODO: remove this once AssetPath forces cross-platform "slash" consistency. See #10511
343        let path = path.replace(std::path::MAIN_SEPARATOR, "/");
344        let mut bytes = Vec::new();
345        reader.read_to_end(&mut bytes).await?;
346        if ext != "wgsl" && !settings.shader_defs.is_empty() {
347            tracing::warn!(
348                "Tried to load a non-wgsl shader with shader defs, this isn't supported: \
349                    The shader defs will be ignored."
350            );
351        }
352        let mut shader = match ext {
353            "spv" => Shader::from_spirv(bytes, load_context.path().path().to_string_lossy()),
354            "wgsl" => Shader::from_wgsl_with_defs(
355                String::from_utf8(bytes)?,
356                path,
357                settings.shader_defs.clone(),
358            ),
359            "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
360            "frag" => {
361                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
362            }
363            "comp" => {
364                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
365            }
366            #[cfg(feature = "shader_format_wesl")]
367            "wesl" => Shader::from_wesl(String::from_utf8(bytes)?, path),
368            _ => panic!("unhandled extension: {ext}"),
369        };
370
371        // collect and store file dependencies
372        for import in &shader.imports {
373            if let ShaderImport::AssetPath(asset_path) = import {
374                shader.file_dependencies.push(load_context.load(asset_path));
375            }
376        }
377        Ok(shader)
378    }
379
380    fn extensions(&self) -> &[&str] {
381        &["spv", "wgsl", "vert", "frag", "comp", "wesl"]
382    }
383}
384
/// A shader import, described as either an asset path or an import path.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
    /// An asset path to a shader.
    ///
    /// Imports parsed from shader source are stored without the surrounding double quotes;
    /// [`ShaderImport::module_name`] adds them back.
    AssetPath(String),
    /// An import path from which a shader may be imported.
    ///
    /// [`ShaderImport::module_name`] returns it unchanged.
    Custom(String),
}
393
394impl ShaderImport {
395    /// A name for a shader import.
396    pub fn module_name(&self) -> Cow<'_, String> {
397        match self {
398            ShaderImport::AssetPath(s) => Cow::Owned(format!("\"{s}\"")),
399            ShaderImport::Custom(s) => Cow::Borrowed(s),
400        }
401    }
402}
403
404/// A reference to a shader asset.
405#[derive(Default)]
406pub enum ShaderRef {
407    /// Use the "default" shader for the current context.
408    #[default]
409    Default,
410    /// A handle to a shader stored in the [`Assets<Shader>`](bevy_asset::Assets) resource.
411    Handle(Handle<Shader>),
412    /// An asset path leading to a shader.
413    Path(AssetPath<'static>),
414}
415
416impl From<Handle<Shader>> for ShaderRef {
417    fn from(handle: Handle<Shader>) -> Self {
418        Self::Handle(handle)
419    }
420}
421
422impl From<AssetPath<'static>> for ShaderRef {
423    fn from(path: AssetPath<'static>) -> Self {
424        Self::Path(path)
425    }
426}
427
428impl From<&'static str> for ShaderRef {
429    fn from(path: &'static str) -> Self {
430        Self::Path(AssetPath::from(path))
431    }
432}