use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use std::fmt::{Error as FmtError, Write};
mod keywords;
pub mod sampler;
mod writer;
pub use writer::Writer;
/// Numeric slot written into the MSL `[[buffer(N)]]`, `[[texture(N)]]` or
/// `[[sampler(N)]]` attribute of a bound resource.
pub type Slot = u8;
/// Index into [`Options::inline_samplers`].
pub type InlineSamplerIndex = u8;
/// Where a sampler binding is taken from.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BindSamplerTarget {
/// A regular sampler slot, emitted as `[[sampler(N)]]`.
Resource(Slot),
/// An inline sampler, looked up in [`Options::inline_samplers`] by index.
Inline(InlineSamplerIndex),
}
/// Metal bind target configured for one resource binding.
///
/// When formatted as an attribute, `buffer` takes priority over `texture`,
/// which takes priority over a `Resource` sampler. A target with none of
/// these set is reported as `Error::UnimplementedBindTarget`.
/// Fields missing from serialized input fall back to their defaults.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct BindTarget {
/// Buffer slot, emitted as `[[buffer(N)]]`.
pub buffer: Option<Slot>,
/// Texture slot, emitted as `[[texture(N)]]`.
pub texture: Option<Slot>,
/// Sampler source (slot or inline sampler).
pub sampler: Option<BindSamplerTarget>,
// NOTE(review): not read in this file; presumably the element count when
// the resource is a binding array — confirm in writer.
pub binding_array_size: Option<u32>,
// NOTE(review): not read in this file; presumably marks the resource as
// writable — confirm in writer.
pub mutable: bool,
}
/// Bind targets of an entry point, keyed by the shader's resource binding.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
/// All resource mappings configured for a single entry point.
/// Fields missing from serialized input fall back to their defaults.
#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct EntryPointResources {
/// Bind target of each global resource used by the entry point.
pub resources: BindingMap,
/// Buffer slot used for push constants (see `Options::resolve_push_constants`).
pub push_constant_buffer: Option<Slot>,
/// Buffer slot of the sizes buffer (see `Options::resolve_sizes_buffer`).
// NOTE(review): presumably carries runtime-sized array lengths — confirm in writer.
pub sizes_buffer: Option<Slot>,
}
/// Per-entry-point resource mappings, keyed by entry point name.
pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
/// A binding resolved to the MSL attribute form it is emitted with
/// (the text inside `[[...]]`).
enum ResolvedBinding {
/// A built-in such as `position` or `vertex_id`.
BuiltIn(crate::BuiltIn),
/// A vertex input location, emitted as `attribute(N)`.
Attribute(u32),
/// A fragment output, emitted as `color(N)`, plus `index(1)` for the
/// second blend source.
Color {
location: u32,
second_blend_source: bool,
},
/// A user varying `user(<prefix><index>)`, optionally followed by an
/// interpolation qualifier. Also used as the `fake` placeholder for
/// resources without a configured bind target.
User {
prefix: &'static str,
index: u32,
interpolation: Option<ResolvedInterpolation>,
},
/// A bound resource (buffer, texture or sampler slot).
Resource(BindTarget),
}
/// MSL interpolation qualifier: the combination of a naga interpolation
/// mode and sampling, e.g. `centroid_no_perspective` or `flat`.
#[derive(Copy, Clone)]
enum ResolvedInterpolation {
CenterPerspective,
CenterNoPerspective,
CentroidPerspective,
CentroidNoPerspective,
SamplePerspective,
SampleNoPerspective,
Flat,
}
/// Errors produced while generating MSL source.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Failure of the underlying `fmt::Write` sink.
#[error(transparent)]
Format(#[from] FmtError),
/// A resource bind target with no buffer, texture or sampler slot set.
#[error("bind target {0:?} is empty")]
UnimplementedBindTarget(BindTarget),
#[error("composing of {0:?} is not implemented yet")]
UnsupportedCompose(Handle<crate::Type>),
#[error("operation {0:?} is not implemented yet")]
UnsupportedBinaryOp(crate::BinaryOperator),
#[error("standard function '{0}' is not implemented yet")]
UnsupportedCall(String),
/// A construct this backend cannot emit yet (e.g. atomic compare-exchange).
#[error("feature '{0}' is not implemented yet")]
FeatureNotImplemented(String),
/// The module breaks an invariant that validation should have enforced.
#[error("internal naga error: module should not have validated: {0}")]
GenericValidation(String),
/// A built-in with no MSL attribute here (`CullDistance`, `ViewIndex`).
#[error("BuiltIn {0:?} is not supported")]
UnsupportedBuiltIn(crate::BuiltIn),
#[error("capability {0:?} is not supported")]
CapabilityNotSupported(crate::valid::Capabilities),
/// An attribute that requires a newer `Options::lang_version`.
#[error("attribute '{0}' is not supported for target MSL version")]
UnsupportedAttribute(String),
#[error("function '{0}' is not supported for target MSL version")]
UnsupportedFunction(String),
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
UnsupportedWriteableStorageBuffer,
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
UnsupportedWriteableStorageTexture(crate::ShaderStage),
#[error("can not use read-write storage textures prior to MSL 1.2")]
UnsupportedRWStorageTexture,
#[error("array of '{0}' is not supported for target MSL version")]
UnsupportedArrayOf(String),
#[error("array of type '{0:?}' is not supported")]
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error("overrides should not be present at this stage")]
Override,
}
/// Per-entry-point failures. These are recorded in
/// `TranslationInfo::entry_point_names` instead of aborting the whole module.
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum EntryPointError {
#[error("global '{0}' doesn't have a binding")]
MissingBinding(String),
/// No bind target configured and `Options::fake_missing_bindings` is off.
#[error("mapping of {0:?} is missing")]
MissingBindTarget(crate::ResourceBinding),
/// No push-constant slot configured and faking is off.
#[error("mapping for push constants is missing")]
MissingPushConstants,
/// No sizes-buffer slot configured and faking is off.
#[error("mapping for sizes buffer is missing")]
MissingSizesBuffer,
}
/// Where a `crate::Binding` appears. This decides how
/// `Options::resolve_local_binding` translates a location binding.
#[derive(Clone, Copy, Debug)]
enum LocationMode {
/// Vertex stage input: location becomes `attribute(N)`.
VertexInput,
/// Vertex stage output: location becomes a `user(...)` varying; the
/// `invariant` flag on `Position` is kept only in this mode.
VertexOutput,
/// Fragment stage input: location becomes a `user(...)` varying.
FragmentInput,
/// Fragment stage output: location becomes `color(N)`.
FragmentOutput,
/// Location bindings are rejected in this mode.
Uniform,
}
/// Configuration of the MSL backend.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Options {
/// Target MSL version as `(major, minor)`. It gates attributes such as
/// `invariant` (2.1), `base_instance`/`instance_id` (1.2) and
/// `primitive_id` (2.2).
pub lang_version: (u8, u8),
/// Resource mappings, keyed by entry point name.
pub per_entry_point_map: EntryPointResourceMap,
/// Samplers referenced by `BindSamplerTarget::Inline` indices.
pub inline_samplers: Vec<sampler::InlineSampler>,
/// When set, user varyings use the `locn` prefix instead of `loc`.
// NOTE(review): presumably to match SPIRV-Cross naming.
pub spirv_cross_compatibility: bool,
/// When set, missing bind targets resolve to a `user(fake0)` placeholder
/// instead of producing an `EntryPointError`.
pub fake_missing_bindings: bool,
/// Index bounds-checking policies (not read in this file).
/// Defaults when missing from deserialized input.
#[cfg_attr(feature = "deserialize", serde(default))]
pub bounds_check_policies: index::BoundsCheckPolicies,
// NOTE(review): not read in this file; presumably controls zero-filling of
// workgroup variables — confirm in writer.
pub zero_initialize_workgroup_memory: bool,
}
impl Default for Options {
fn default() -> Self {
Options {
lang_version: (1, 0),
per_entry_point_map: EntryPointResourceMap::default(),
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
bounds_check_policies: index::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: true,
}
}
}
/// Pipeline-level settings that accompany [`Options`] when writing a module.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
// NOTE(review): not read in this file; presumably makes the writer emit a
// point size output — confirm in writer.
pub allow_and_force_point_size: bool,
}
impl Options {
fn resolve_local_binding(
&self,
binding: &crate::Binding,
mode: LocationMode,
) -> Result<ResolvedBinding, Error> {
match *binding {
crate::Binding::BuiltIn(mut built_in) => {
match built_in {
crate::BuiltIn::Position { ref mut invariant } => {
if *invariant && self.lang_version < (2, 1) {
return Err(Error::UnsupportedAttribute("invariant".to_string()));
}
if !matches!(mode, LocationMode::VertexOutput) {
*invariant = false;
}
}
crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => {
return Err(Error::UnsupportedAttribute("base_instance".to_string()));
}
crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => {
return Err(Error::UnsupportedAttribute("instance_id".to_string()));
}
crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => {
return Err(Error::UnsupportedAttribute("primitive_id".to_string()));
}
_ => {}
}
Ok(ResolvedBinding::BuiltIn(built_in))
}
crate::Binding::Location {
location,
interpolation,
sampling,
second_blend_source,
} => match mode {
LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
LocationMode::FragmentOutput => {
if second_blend_source && self.lang_version < (1, 2) {
return Err(Error::UnsupportedAttribute(
"second_blend_source".to_string(),
));
}
Ok(ResolvedBinding::Color {
location,
second_blend_source,
})
}
LocationMode::VertexOutput | LocationMode::FragmentInput => {
Ok(ResolvedBinding::User {
prefix: if self.spirv_cross_compatibility {
"locn"
} else {
"loc"
},
index: location,
interpolation: {
let interpolation = interpolation.unwrap();
let sampling = sampling.unwrap_or(crate::Sampling::Center);
Some(ResolvedInterpolation::from_binding(interpolation, sampling))
},
})
}
LocationMode::Uniform => Err(Error::GenericValidation(format!(
"Unexpected Binding::Location({}) for the Uniform mode",
location
))),
},
}
}
fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
self.per_entry_point_map.get(&ep.name)
}
fn get_resource_binding_target(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Option<&BindTarget> {
self.get_entry_point_resources(ep)
.and_then(|res| res.resources.get(res_binding))
}
fn resolve_resource_binding(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Result<ResolvedBinding, EntryPointError> {
let target = self.get_resource_binding_target(ep, res_binding);
match target {
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
}
}
fn resolve_push_constants(
&self,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.push_constant_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
..Default::default()
})),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingPushConstants),
}
}
fn resolve_sizes_buffer(
&self,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.sizes_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
..Default::default()
})),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingSizesBuffer),
}
}
}
impl ResolvedBinding {
fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
match *self {
Self::Resource(BindTarget {
sampler: Some(BindSamplerTarget::Inline(index)),
..
}) => Some(&options.inline_samplers[index as usize]),
_ => None,
}
}
const fn as_bind_target(&self) -> Option<&BindTarget> {
match *self {
Self::Resource(ref target) => Some(target),
_ => None,
}
}
fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
write!(out, " [[")?;
match *self {
Self::BuiltIn(built_in) => {
use crate::BuiltIn as Bi;
let name = match built_in {
Bi::Position { invariant: false } => "position",
Bi::Position { invariant: true } => "position, invariant",
Bi::BaseInstance => "base_instance",
Bi::BaseVertex => "base_vertex",
Bi::ClipDistance => "clip_distance",
Bi::InstanceIndex => "instance_id",
Bi::PointSize => "point_size",
Bi::VertexIndex => "vertex_id",
Bi::FragDepth => "depth(any)",
Bi::PointCoord => "point_coord",
Bi::FrontFacing => "front_facing",
Bi::PrimitiveIndex => "primitive_id",
Bi::SampleIndex => "sample_id",
Bi::SampleMask => "sample_mask",
Bi::GlobalInvocationId => "thread_position_in_grid",
Bi::LocalInvocationId => "thread_position_in_threadgroup",
Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
Bi::NumWorkGroups => "threadgroups_per_grid",
Bi::NumSubgroups => "simdgroups_per_threadgroup",
Bi::SubgroupId => "simdgroup_index_in_threadgroup",
Bi::SubgroupSize => "threads_per_simdgroup",
Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
Bi::CullDistance | Bi::ViewIndex => {
return Err(Error::UnsupportedBuiltIn(built_in))
}
};
write!(out, "{name}")?;
}
Self::Attribute(index) => write!(out, "attribute({index})")?,
Self::Color {
location,
second_blend_source,
} => {
if second_blend_source {
write!(out, "color({location}) index(1)")?
} else {
write!(out, "color({location})")?
}
}
Self::User {
prefix,
index,
interpolation,
} => {
write!(out, "user({prefix}{index})")?;
if let Some(interpolation) = interpolation {
write!(out, ", ")?;
interpolation.try_fmt(out)?;
}
}
Self::Resource(ref target) => {
if let Some(id) = target.buffer {
write!(out, "buffer({id})")?;
} else if let Some(id) = target.texture {
write!(out, "texture({id})")?;
} else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
write!(out, "sampler({id})")?;
} else {
return Err(Error::UnimplementedBindTarget(target.clone()));
}
}
}
write!(out, "]]")?;
Ok(())
}
}
impl ResolvedInterpolation {
    /// Combines a naga interpolation mode and sampling into one MSL qualifier.
    ///
    /// `Flat` ignores `sampling`. `Linear` maps onto the `*NoPerspective`
    /// variants.
    const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
        use crate::Interpolation as I;
        use crate::Sampling as S;
        match interpolation {
            I::Flat => Self::Flat,
            I::Perspective => match sampling {
                S::Center => Self::CenterPerspective,
                S::Centroid => Self::CentroidPerspective,
                S::Sample => Self::SamplePerspective,
            },
            I::Linear => match sampling {
                S::Center => Self::CenterNoPerspective,
                S::Centroid => Self::CentroidNoPerspective,
                S::Sample => Self::SampleNoPerspective,
            },
        }
    }

    /// MSL keyword for this qualifier.
    const fn msl_keyword(self) -> &'static str {
        match self {
            Self::CenterPerspective => "center_perspective",
            Self::CenterNoPerspective => "center_no_perspective",
            Self::CentroidPerspective => "centroid_perspective",
            Self::CentroidNoPerspective => "centroid_no_perspective",
            Self::SamplePerspective => "sample_perspective",
            Self::SampleNoPerspective => "sample_no_perspective",
            Self::Flat => "flat",
        }
    }

    /// Writes the qualifier keyword to `out`.
    ///
    /// # Errors
    ///
    /// [`Error::Format`] if `out` fails.
    fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
        out.write_str(self.msl_keyword())?;
        Ok(())
    }
}
/// Information returned alongside the generated MSL source.
pub struct TranslationInfo {
/// Per entry point: the generated name, or the error that prevented it.
// NOTE(review): presumably in the order of `module.entry_points` — confirm in writer.
pub entry_point_names: Vec<Result<String, EntryPointError>>,
}
pub fn write_string(
module: &crate::Module,
info: &ModuleInfo,
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<(String, TranslationInfo), Error> {
let mut w = writer::Writer::new(String::new());
let info = w.write(module, info, options, pipeline_options)?;
Ok((w.finish(), info))
}
/// Pins the in-memory size of [`Error`].
// NOTE(review): presumably guards against `Error` growing unnoticed, since it
// is the error type of most writer results.
#[test]
fn test_error_size() {
    let error_size = std::mem::size_of::<Error>();
    assert_eq!(error_size, 32);
}
impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
Self::InclusiveOr => "fetch_or",
Self::ExclusiveOr => "fetch_xor",
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
}
}