1mod conv;
111mod help;
112mod keywords;
113mod ray;
114mod storage;
115mod writer;
116
117use alloc::{string::String, vec::Vec};
118use core::fmt::Error as FmtError;
119
120use thiserror::Error;
121
122use crate::{back, ir, proc};
123
/// Destination of a single resource binding in the generated HLSL.
///
/// The derived `Default` yields zero for the numeric fields, `None` for the
/// optional ones and `false` for the flag.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BindTarget {
    /// Space component of the target (presumably the HLSL register space — confirm).
    pub space: u8,
    /// Register component of the target.
    pub register: u32,
    /// Size of the binding array, when one is specified.
    pub binding_array_size: Option<u32>,
    /// Optional index for dynamic storage buffer offsets (presumably a key into
    /// `DynamicStorageBufferOffsetsTargets` — confirm against the writer).
    pub dynamic_storage_buffer_offsets_index: Option<u32>,
    /// Per-target indexing restriction flag; takes its `Default` value when the
    /// field is absent from serialized input.
    #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
    pub restrict_indexing: bool,
}
143
/// Bind target carrying a space, a register and a size.
///
/// NOTE(review): used as the value type of `DynamicStorageBufferOffsetsTargets`;
/// the exact meaning of `size` is not visible in this file — confirm.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct OffsetsBindTarget {
    /// Space component of the target.
    pub space: u8,
    /// Register component of the target.
    pub register: u32,
    /// Size associated with the target.
    pub size: u32,
}
153
/// Flat `(resource binding, bind target)` record.
///
/// `deserialize_binding_map` reads a sequence of these records and folds them
/// into a [`BindingMap`].
#[cfg(any(feature = "serialize", feature = "deserialize"))]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct BindingMapSerialization {
    /// Key of the map entry.
    resource_binding: crate::ResourceBinding,
    /// Value of the map entry.
    bind_target: BindTarget,
}
161
/// Deserializes a [`BindingMap`] from a sequence of [`BindingMapSerialization`] records.
///
/// Records are inserted in sequence order, so a later record with the same
/// resource binding replaces an earlier one.
///
/// # Errors
///
/// Propagates any error produced while deserializing the record sequence.
#[cfg(feature = "deserialize")]
fn deserialize_binding_map<'de, D>(deserializer: D) -> Result<BindingMap, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let entries =
        <Vec<BindingMapSerialization> as serde::Deserialize>::deserialize(deserializer)?;

    let bindings = entries
        .into_iter()
        .fold(BindingMap::default(), |mut acc, entry| {
            acc.insert(entry.resource_binding, entry.bind_target);
            acc
        });
    Ok(bindings)
}
176
177pub type BindingMap = alloc::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
179
/// Shader model versions this backend knows about.
///
/// Variants are declared in ascending version order, so the derived
/// `PartialOrd` compares versions (`V5_0 < V5_1 < V6_0 < ...`).
// The `V<major>_<minor>` variant names are not camel case, hence the `allow`.
#[allow(non_snake_case, non_camel_case_types)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum ShaderModel {
    /// Shader model 5.0.
    V5_0,
    /// Shader model 5.1.
    V5_1,
    /// Shader model 6.0.
    V6_0,
    /// Shader model 6.1.
    V6_1,
    /// Shader model 6.2.
    V6_2,
    /// Shader model 6.3.
    V6_3,
    /// Shader model 6.4.
    V6_4,
    /// Shader model 6.5.
    V6_5,
    /// Shader model 6.6.
    V6_6,
    /// Shader model 6.7.
    V6_7,
}
197
198impl ShaderModel {
199    pub const fn to_str(self) -> &'static str {
200        match self {
201            Self::V5_0 => "5_0",
202            Self::V5_1 => "5_1",
203            Self::V6_0 => "6_0",
204            Self::V6_1 => "6_1",
205            Self::V6_2 => "6_2",
206            Self::V6_3 => "6_3",
207            Self::V6_4 => "6_4",
208            Self::V6_5 => "6_5",
209            Self::V6_6 => "6_6",
210            Self::V6_7 => "6_7",
211        }
212    }
213}
214
215impl crate::ShaderStage {
216    pub const fn to_hlsl_str(self) -> &'static str {
217        match self {
218            Self::Vertex => "vs",
219            Self::Fragment => "ps",
220            Self::Compute => "cs",
221            Self::Task | Self::Mesh => unreachable!(),
222        }
223    }
224}
225
226impl crate::ImageDimension {
227    const fn to_hlsl_str(self) -> &'static str {
228        match self {
229            Self::D1 => "1D",
230            Self::D2 => "2D",
231            Self::D3 => "3D",
232            Self::Cube => "Cube",
233        }
234    }
235}
236
/// Key of a [`SamplerIndexBufferBindingMap`] entry, holding a single group number.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplerIndexBufferKey {
    /// Group number this key refers to (presumably the bind group index — confirm).
    pub group: u32,
}
243
244#[derive(Clone, Debug, Hash, PartialEq, Eq)]
245#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
246#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
247#[cfg_attr(feature = "deserialize", serde(default))]
248pub struct SamplerHeapBindTargets {
249    pub standard_samplers: BindTarget,
250    pub comparison_samplers: BindTarget,
251}
252
253impl Default for SamplerHeapBindTargets {
254    fn default() -> Self {
255        Self {
256            standard_samplers: BindTarget {
257                space: 0,
258                register: 0,
259                binding_array_size: None,
260                dynamic_storage_buffer_offsets_index: None,
261                restrict_indexing: false,
262            },
263            comparison_samplers: BindTarget {
264                space: 1,
265                register: 0,
266                binding_array_size: None,
267                dynamic_storage_buffer_offsets_index: None,
268                restrict_indexing: false,
269            },
270        }
271    }
272}
273
/// Flat `(group, bind target)` record.
///
/// `deserialize_sampler_index_buffer_bindings` reads a sequence of these
/// records and folds them into a [`SamplerIndexBufferBindingMap`].
#[cfg(any(feature = "serialize", feature = "deserialize"))]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct SamplerIndexBufferBindingSerialization {
    /// Group number; becomes the `group` of the map key.
    group: u32,
    /// Value of the map entry.
    bind_target: BindTarget,
}
281
/// Deserializes a [`SamplerIndexBufferBindingMap`] from a sequence of
/// [`SamplerIndexBufferBindingSerialization`] records.
///
/// Records are inserted in sequence order, so a later record with the same
/// group replaces an earlier one.
///
/// # Errors
///
/// Propagates any error produced while deserializing the record sequence.
#[cfg(feature = "deserialize")]
fn deserialize_sampler_index_buffer_bindings<'de, D>(
    deserializer: D,
) -> Result<SamplerIndexBufferBindingMap, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let entries = <Vec<SamplerIndexBufferBindingSerialization> as serde::Deserialize>::deserialize(
        deserializer,
    )?;

    let bindings = entries.into_iter().fold(
        SamplerIndexBufferBindingMap::default(),
        |mut acc, entry| {
            let key = SamplerIndexBufferKey { group: entry.group };
            acc.insert(key, entry.bind_target);
            acc
        },
    );
    Ok(bindings)
}
301
302pub type SamplerIndexBufferBindingMap =
304    alloc::collections::BTreeMap<SamplerIndexBufferKey, BindTarget>;
305
/// Flat `(index, bind target)` record.
///
/// `deserialize_storage_buffer_offsets` reads a sequence of these records and
/// folds them into a [`DynamicStorageBufferOffsetsTargets`] map.
#[cfg(any(feature = "serialize", feature = "deserialize"))]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct DynamicStorageBufferOffsetTargetSerialization {
    /// Key of the map entry.
    index: u32,
    /// Value of the map entry.
    bind_target: OffsetsBindTarget,
}
313
/// Deserializes a [`DynamicStorageBufferOffsetsTargets`] map from a sequence of
/// [`DynamicStorageBufferOffsetTargetSerialization`] records.
///
/// Records are inserted in sequence order, so a later record with the same
/// index replaces an earlier one.
///
/// # Errors
///
/// Propagates any error produced while deserializing the record sequence.
#[cfg(feature = "deserialize")]
fn deserialize_storage_buffer_offsets<'de, D>(
    deserializer: D,
) -> Result<DynamicStorageBufferOffsetsTargets, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let entries =
        <Vec<DynamicStorageBufferOffsetTargetSerialization> as serde::Deserialize>::deserialize(
            deserializer,
        )?;

    let targets = entries.into_iter().fold(
        DynamicStorageBufferOffsetsTargets::default(),
        |mut acc, entry| {
            acc.insert(entry.index, entry.bind_target);
            acc
        },
    );
    Ok(targets)
}
330
331pub type DynamicStorageBufferOffsetsTargets = alloc::collections::BTreeMap<u32, OffsetsBindTarget>;
332
333type BackendResult = Result<(), Error>;
335
336#[derive(Clone, Debug, PartialEq, thiserror::Error)]
337#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
338#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
339pub enum EntryPointError {
340    #[error("mapping of {0:?} is missing")]
341    MissingBinding(crate::ResourceBinding),
342}
343
344#[derive(Clone, Debug, Hash, PartialEq, Eq)]
346#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
347#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
348#[cfg_attr(feature = "deserialize", serde(default))]
349pub struct Options {
350    pub shader_model: ShaderModel,
352    #[cfg_attr(
354        feature = "deserialize",
355        serde(deserialize_with = "deserialize_binding_map")
356    )]
357    pub binding_map: BindingMap,
358    pub fake_missing_bindings: bool,
360    pub special_constants_binding: Option<BindTarget>,
363    pub push_constants_target: Option<BindTarget>,
365    pub sampler_heap_target: SamplerHeapBindTargets,
367    #[cfg_attr(
369        feature = "deserialize",
370        serde(deserialize_with = "deserialize_sampler_index_buffer_bindings")
371    )]
372    pub sampler_buffer_binding_map: SamplerIndexBufferBindingMap,
373    #[cfg_attr(
375        feature = "deserialize",
376        serde(deserialize_with = "deserialize_storage_buffer_offsets")
377    )]
378    pub dynamic_storage_buffer_offsets_targets: DynamicStorageBufferOffsetsTargets,
379    pub zero_initialize_workgroup_memory: bool,
381    pub restrict_indexing: bool,
383    pub force_loop_bounding: bool,
386}
387
388impl Default for Options {
389    fn default() -> Self {
390        Options {
391            shader_model: ShaderModel::V5_1,
392            binding_map: BindingMap::default(),
393            fake_missing_bindings: true,
394            special_constants_binding: None,
395            sampler_heap_target: SamplerHeapBindTargets::default(),
396            sampler_buffer_binding_map: alloc::collections::BTreeMap::default(),
397            push_constants_target: None,
398            dynamic_storage_buffer_offsets_targets: alloc::collections::BTreeMap::new(),
399            zero_initialize_workgroup_memory: true,
400            restrict_indexing: true,
401            force_loop_bounding: true,
402        }
403    }
404}
405
406impl Options {
407    fn resolve_resource_binding(
408        &self,
409        res_binding: &crate::ResourceBinding,
410    ) -> Result<BindTarget, EntryPointError> {
411        match self.binding_map.get(res_binding) {
412            Some(target) => Ok(*target),
413            None if self.fake_missing_bindings => Ok(BindTarget {
414                space: res_binding.group as u8,
415                register: res_binding.binding,
416                binding_array_size: None,
417                dynamic_storage_buffer_offsets_index: None,
418                restrict_indexing: false,
419            }),
420            None => Err(EntryPointError::MissingBinding(*res_binding)),
421        }
422    }
423}
424
425#[derive(Default)]
427pub struct ReflectionInfo {
428    pub entry_point_names: Vec<Result<String, EntryPointError>>,
435}
436
437#[derive(Debug, Default, Clone)]
439#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
440#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
441#[cfg_attr(feature = "deserialize", serde(default))]
442pub struct PipelineOptions {
443    pub entry_point: Option<(ir::ShaderStage, String)>,
451}
452
453#[derive(Error, Debug)]
454pub enum Error {
455    #[error(transparent)]
456    IoError(#[from] FmtError),
457    #[error("A scalar with an unsupported width was requested: {0:?}")]
458    UnsupportedScalar(crate::Scalar),
459    #[error("{0}")]
460    Unimplemented(String), #[error("{0}")]
462    Custom(String),
463    #[error("overrides should not be present at this stage")]
464    Override,
465    #[error(transparent)]
466    ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
467    #[error("entry point with stage {0:?} and name '{1}' not found")]
468    EntryPointNotFound(ir::ShaderStage, String),
469}
470
471#[derive(PartialEq, Eq, Hash)]
472enum WrappedType {
473    ZeroValue(help::WrappedZeroValue),
474    ArrayLength(help::WrappedArrayLength),
475    ImageSample(help::WrappedImageSample),
476    ImageQuery(help::WrappedImageQuery),
477    ImageLoadScalar(crate::Scalar),
478    Constructor(help::WrappedConstructor),
479    StructMatrixAccess(help::WrappedStructMatrixAccess),
480    MatCx2(help::WrappedMatCx2),
481    Math(help::WrappedMath),
482    UnaryOp(help::WrappedUnaryOp),
483    BinaryOp(help::WrappedBinaryOp),
484    Cast(help::WrappedCast),
485}
486
487#[derive(Default)]
488struct Wrapped {
489    types: crate::FastHashSet<WrappedType>,
490    sampler_heaps: bool,
492    sampler_index_buffers: crate::FastHashMap<SamplerIndexBufferKey, String>,
494}
495
496impl Wrapped {
497    fn insert(&mut self, r#type: WrappedType) -> bool {
498        self.types.insert(r#type)
499    }
500
501    fn clear(&mut self) {
502        self.types.clear();
503    }
504}
505
506pub struct FragmentEntryPoint<'a> {
515    module: &'a crate::Module,
516    func: &'a crate::Function,
517}
518
519impl<'a> FragmentEntryPoint<'a> {
520    pub fn new(module: &'a crate::Module, ep_name: &'a str) -> Option<Self> {
523        module
524            .entry_points
525            .iter()
526            .find(|ep| ep.name == ep_name)
527            .filter(|ep| ep.stage == crate::ShaderStage::Fragment)
528            .map(|ep| Self {
529                module,
530                func: &ep.function,
531            })
532    }
533}
534
535pub struct Writer<'a, W> {
536    out: W,
537    names: crate::FastHashMap<proc::NameKey, String>,
538    namer: proc::Namer,
539    options: &'a Options,
541    pipeline_options: &'a PipelineOptions,
543    entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
545    named_expressions: crate::NamedExpressions,
547    wrapped: Wrapped,
548    written_committed_intersection: bool,
549    written_candidate_intersection: bool,
550    continue_ctx: back::continue_forward::ContinueCtx,
551
552    temp_access_chain: Vec<storage::SubAccess>,
570    need_bake_expressions: back::NeedBakeExpressions,
571}