1use super::ShaderDefVal;
2use alloc::borrow::Cow;
3use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
4use bevy_reflect::TypePath;
5use bevy_utils::define_atomic_id;
6use thiserror::Error;
7
// Declares the `ShaderId` identifier type via `bevy_utils::define_atomic_id!`.
// NOTE(review): presumably a process-unique, atomically allocated id; it is not used by the
// code in this file — confirm against its users.
define_atomic_id!(ShaderId);
9
/// Whether extra runtime checks should be requested when a [`Shader`] is compiled.
///
/// Defaults to [`ValidateShader::Disabled`].
///
/// NOTE(review): the code that acts on this flag lives outside this file; presumably
/// `Enabled` trades some performance for additional safety checks — confirm.
#[derive(Clone, Debug, Default)]
pub enum ValidateShader {
    /// No extra validation is requested. This is the default.
    #[default]
    Disabled,
    /// Extra validation / runtime checks are requested for this shader.
    Enabled,
}
32
/// A shader asset: its source plus the import metadata used to compose it with `naga_oil`.
#[derive(Asset, TypePath, Debug, Clone)]
pub struct Shader {
    /// Path the shader was created with. It is passed to `naga_oil` as the descriptor
    /// `file_path` and, when the source declares no import path of its own, it also becomes
    /// the shader's import name.
    pub path: String,
    /// The shader's source code (text for WGSL/WESL/GLSL, raw bytes for SPIR-V).
    pub source: Source,
    /// The name other shaders import this one by: either the import path reported by
    /// `naga_oil`'s preprocessor, or this shader's asset path.
    pub import_path: ShaderImport,
    /// Modules this shader imports, as found by preprocessing its source.
    pub imports: Vec<ShaderImport>,
    /// Extra imports handed to `naga_oil` as the composable descriptor's `additional_imports`.
    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
    /// Shader defs passed to `naga_oil` when this shader is converted into a
    /// `ComposableModuleDescriptor`.
    pub shader_defs: Vec<ShaderDefVal>,
    /// Handles to the shaders this one imports by asset path, filled in by `ShaderLoader`.
    /// NOTE(review): presumably held so imported files are tracked as dependencies of this
    /// asset — confirm.
    pub file_dependencies: Vec<Handle<Shader>>,
    /// Whether runtime validation is requested; the constructors in this file leave it
    /// [`ValidateShader::Disabled`].
    pub validate_shader: ValidateShader,
}
56
57impl Shader {
58 fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
59 let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
60
61 let import_path = import_path
62 .map(ShaderImport::Custom)
63 .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
64
65 let imports = imports
66 .into_iter()
67 .map(|import| {
68 if import.import.starts_with('\"') {
69 let import = import
70 .import
71 .chars()
72 .skip(1)
73 .take_while(|c| *c != '\"')
74 .collect();
75 ShaderImport::AssetPath(import)
76 } else {
77 ShaderImport::Custom(import.import)
78 }
79 })
80 .collect();
81
82 (import_path, imports)
83 }
84
85 pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
87 let source = source.into();
88 let path = path.into();
89 let (import_path, imports) = Shader::preprocess(&source, &path);
90 Shader {
91 path,
92 imports,
93 import_path,
94 source: Source::Wgsl(source),
95 additional_imports: Default::default(),
96 shader_defs: Default::default(),
97 file_dependencies: Default::default(),
98 validate_shader: ValidateShader::Disabled,
99 }
100 }
101
102 pub fn from_wgsl_with_defs(
104 source: impl Into<Cow<'static, str>>,
105 path: impl Into<String>,
106 shader_defs: Vec<ShaderDefVal>,
107 ) -> Shader {
108 Self {
109 shader_defs,
110 ..Self::from_wgsl(source, path)
111 }
112 }
113
114 pub fn from_glsl(
116 source: impl Into<Cow<'static, str>>,
117 stage: naga::ShaderStage,
118 path: impl Into<String>,
119 ) -> Shader {
120 let source = source.into();
121 let path = path.into();
122 let (import_path, imports) = Shader::preprocess(&source, &path);
123 Shader {
124 path,
125 imports,
126 import_path,
127 source: Source::Glsl(source, stage),
128 additional_imports: Default::default(),
129 shader_defs: Default::default(),
130 file_dependencies: Default::default(),
131 validate_shader: ValidateShader::Disabled,
132 }
133 }
134
135 pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
137 let path = path.into();
138 Shader {
139 path: path.clone(),
140 imports: Vec::new(),
141 import_path: ShaderImport::AssetPath(path),
142 source: Source::SpirV(source.into()),
143 additional_imports: Default::default(),
144 shader_defs: Default::default(),
145 file_dependencies: Default::default(),
146 validate_shader: ValidateShader::Disabled,
147 }
148 }
149
150 #[cfg(feature = "shader_format_wesl")]
152 pub fn from_wesl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
153 let source = source.into();
154 let path = path.into();
155 let (import_path, imports) = Shader::preprocess(&source, &path);
156
157 match import_path {
158 ShaderImport::AssetPath(asset_path) => {
159 let shader_path = std::path::Path::new("/").join(&asset_path);
161
162 let import_path_str = shader_path
164 .with_extension("")
165 .to_string_lossy()
166 .replace('\\', "/");
167
168 let import_path = ShaderImport::AssetPath(import_path_str.to_string());
169
170 Shader {
171 path,
172 imports,
173 import_path,
174 source: Source::Wesl(source),
175 additional_imports: Default::default(),
176 shader_defs: Default::default(),
177 file_dependencies: Default::default(),
178 validate_shader: ValidateShader::Disabled,
179 }
180 }
181 ShaderImport::Custom(_) => {
182 panic!("Wesl shaders must be imported from an asset path");
183 }
184 }
185 }
186}
187
188impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
189 fn from(shader: &'a Shader) -> Self {
190 let shader_defs = shader
191 .shader_defs
192 .iter()
193 .map(|def| match def {
194 ShaderDefVal::Bool(name, b) => {
195 (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
196 }
197 ShaderDefVal::Int(name, i) => {
198 (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
199 }
200 ShaderDefVal::UInt(name, i) => {
201 (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
202 }
203 })
204 .collect();
205
206 let as_name = match &shader.import_path {
208 ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
209 ShaderImport::Custom(_) => None,
210 };
211
212 naga_oil::compose::ComposableModuleDescriptor {
213 source: shader.source.as_str(),
214 file_path: &shader.path,
215 language: (&shader.source).into(),
216 additional_imports: &shader.additional_imports,
217 shader_defs,
218 as_name,
219 }
220 }
221}
222
223impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
224 fn from(shader: &'a Shader) -> Self {
225 naga_oil::compose::NagaModuleDescriptor {
226 source: shader.source.as_str(),
227 file_path: &shader.path,
228 shader_type: (&shader.source).into(),
229 ..Default::default()
230 }
231 }
232}
233
/// Shader source code, tagged with its language: text for WGSL, WESL and GLSL (GLSL also
/// carries its pipeline stage), or raw bytes for SPIR-V.
#[expect(missing_docs, reason = "The variants are self-explanatory.")]
#[derive(Debug, Clone)]
pub enum Source {
    Wgsl(Cow<'static, str>),
    Wesl(Cow<'static, str>),
    Glsl(Cow<'static, str>, naga::ShaderStage),
    SpirV(Cow<'static, [u8]>),
}
246
247impl Source {
248 pub fn as_str(&self) -> &str {
250 match self {
251 Source::Wgsl(s) | Source::Wesl(s) | Source::Glsl(s, _) => s,
252 Source::SpirV(_) => panic!("spirv not yet implemented"),
253 }
254 }
255}
256
257impl From<&Source> for naga_oil::compose::ShaderLanguage {
258 fn from(value: &Source) -> Self {
259 match value {
260 Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
261 #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
262 Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
263 #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
264 Source::Glsl(_, _) => panic!(
265 "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
266 ),
267 Source::SpirV(_) => panic!("spirv not yet implemented"),
268 Source::Wesl(_) => panic!("wesl not yet implemented"),
269 }
270 }
271}
272
273impl From<&Source> for naga_oil::compose::ShaderType {
274 fn from(value: &Source) -> Self {
275 match value {
276 Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
277 #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
278 Source::Glsl(_, shader_stage) => match shader_stage {
279 naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
280 naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
281 naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
282 naga::ShaderStage::Task => panic!("task shaders not yet implemented"),
283 naga::ShaderStage::Mesh => panic!("mesh shaders not yet implemented"),
284 naga::ShaderStage::RayGeneration => {
285 panic!("ray generation shader not yet implemented")
286 }
287 naga::ShaderStage::Miss => panic!("miss shader not yet implemented"),
288 naga::ShaderStage::AnyHit => panic!("any hit shader not yet implemented"),
289 naga::ShaderStage::ClosestHit => panic!("closest hit shader not yet implemented"),
290 },
291 #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
292 Source::Glsl(_, _) => panic!(
293 "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
294 ),
295 Source::SpirV(_) => panic!("spirv not yet implemented"),
296 Source::Wesl(_) => panic!("wesl not yet implemented"),
297 }
298 }
299}
300
/// [`AssetLoader`] for [`Shader`] assets stored as `.spv`, `.wgsl`, `.vert`, `.frag`, `.comp`
/// or `.wesl` files.
#[derive(Default, TypePath)]
pub struct ShaderLoader;
304
/// Errors produced by [`ShaderLoader`] while loading a [`Shader`].
#[non_exhaustive]
#[derive(Debug, Error)]
#[expect(missing_docs, reason = "The variants are self-explanatory.")]
pub enum ShaderLoaderError {
    // An I/O error, converted via `From<std::io::Error>` (e.g. from reading the asset bytes).
    #[error("Could not load shader: {0}")]
    Io(#[from] std::io::Error),
    // A text-based shader's bytes were not valid UTF-8.
    #[error("Could not parse shader: {0}")]
    Parse(#[from] alloc::string::FromUtf8Error),
}
315
/// Loader settings for [`ShaderLoader`].
#[derive(serde::Serialize, serde::Deserialize, Debug, Default)]
pub struct ShaderSettings {
    /// Shader defs attached to `.wgsl` shaders loaded with these settings. For every other
    /// shader format they are ignored and a warning is logged.
    pub shader_defs: Vec<ShaderDefVal>,
}
322
323impl AssetLoader for ShaderLoader {
324 type Asset = Shader;
325 type Settings = ShaderSettings;
326 type Error = ShaderLoaderError;
327 async fn load(
328 &self,
329 reader: &mut dyn Reader,
330 settings: &Self::Settings,
331 load_context: &mut LoadContext<'_>,
332 ) -> Result<Shader, Self::Error> {
333 let ext = load_context
334 .path()
335 .path()
336 .extension()
337 .unwrap()
338 .to_str()
339 .unwrap();
340 let path = load_context.path().to_string();
341 let path = path.replace(std::path::MAIN_SEPARATOR, "/");
344 let mut bytes = Vec::new();
345 reader.read_to_end(&mut bytes).await?;
346 if ext != "wgsl" && !settings.shader_defs.is_empty() {
347 tracing::warn!(
348 "Tried to load a non-wgsl shader with shader defs, this isn't supported: \
349 The shader defs will be ignored."
350 );
351 }
352 let mut shader = match ext {
353 "spv" => Shader::from_spirv(bytes, load_context.path().path().to_string_lossy()),
354 "wgsl" => Shader::from_wgsl_with_defs(
355 String::from_utf8(bytes)?,
356 path,
357 settings.shader_defs.clone(),
358 ),
359 "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
360 "frag" => {
361 Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
362 }
363 "comp" => {
364 Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
365 }
366 #[cfg(feature = "shader_format_wesl")]
367 "wesl" => Shader::from_wesl(String::from_utf8(bytes)?, path),
368 _ => panic!("unhandled extension: {ext}"),
369 };
370
371 for import in &shader.imports {
373 if let ShaderImport::AssetPath(asset_path) = import {
374 shader.file_dependencies.push(load_context.load(asset_path));
375 }
376 }
377 Ok(shader)
378 }
379
380 fn extensions(&self) -> &[&str] {
381 &["spv", "wgsl", "vert", "frag", "comp", "wesl"]
382 }
383}
384
/// Identifies a shader module that other shaders can import.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
    /// A module identified by its asset path (written in double quotes in import statements).
    AssetPath(String),
    /// A module identified by a custom import path declared in its source.
    Custom(String),
}

impl ShaderImport {
    /// Returns the module name used for this import.
    ///
    /// Asset paths are wrapped in double quotes, which allocates a new string. Custom import
    /// paths are borrowed unchanged.
    pub fn module_name(&self) -> Cow<'_, String> {
        match self {
            Self::Custom(name) => Cow::Borrowed(name),
            Self::AssetPath(asset_path) => Cow::Owned(format!("\"{asset_path}\"")),
        }
    }
}
403
404#[derive(Default)]
406pub enum ShaderRef {
407 #[default]
409 Default,
410 Handle(Handle<Shader>),
412 Path(AssetPath<'static>),
414}
415
416impl From<Handle<Shader>> for ShaderRef {
417 fn from(handle: Handle<Shader>) -> Self {
418 Self::Handle(handle)
419 }
420}
421
422impl From<AssetPath<'static>> for ShaderRef {
423 fn from(path: AssetPath<'static>) -> Self {
424 Self::Path(path)
425 }
426}
427
428impl From<&'static str> for ShaderRef {
429 fn from(path: &'static str) -> Self {
430 Self::Path(AssetPath::from(path))
431 }
432}