use crate::arena::Handle;
use std::{fmt::Display, num::NonZeroU32, ops};
/// An alignment requirement, stored as a non-zero `u32`.
///
/// Values built through [`Alignment::new`], the associated constants and
/// `Alignment * Alignment` are always powers of two. NOTE(review): the serde
/// `Deserialize` derive does not re-check this, so a deserialized value could
/// in principle be any non-zero `u32`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    // SAFETY (all constants below): every literal passed to `new_unchecked` is non-zero.
    pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
    pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
    pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
    pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
    pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });

    /// Alias for [`Alignment::SIXTEEN`]. NOTE(review): presumably the minimum
    /// alignment required for uniform data — confirm against callers.
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Returns `Some` alignment of `n` if `n` is a power of two, else `None`.
    ///
    /// `0` is not a power of two, so `new(0)` is `None`.
    pub const fn new(n: u32) -> Option<Self> {
        if n.is_power_of_two() {
            // SAFETY: `is_power_of_two` is false for 0, so `n` is non-zero here.
            Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
        } else {
            None
        }
    }

    /// Builds an alignment equal to `width`.
    ///
    /// # Panics
    ///
    /// Panics if `width` is not a power of two (this includes `0`).
    pub fn from_width(width: u8) -> Self {
        Self::new(u32::from(width)).expect("alignment width must be a non-zero power of two")
    }

    /// Returns `true` if `n` is a multiple of this alignment.
    ///
    /// Uses a bit mask, so it relies on the alignment being a power of two.
    pub const fn is_aligned(&self, n: u32) -> bool {
        n & (self.0.get() - 1) == 0
    }

    /// Rounds `n` up to the next multiple of this alignment.
    ///
    /// Uses a bit mask, so it relies on the alignment being a power of two.
    /// `n + alignment - 1` must fit in a `u32`: on overflow, debug builds
    /// panic and release builds wrap.
    pub const fn round_up(&self, n: u32) -> u32 {
        let mask = self.0.get() - 1;
        (n + mask) & !mask
    }
}

impl Display for Alignment {
    /// Formats the alignment as its plain numeric value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.get().fmt(f)
    }
}

impl ops::Mul<u32> for Alignment {
    type Output = u32;

    /// Multiplies the alignment value by `rhs`, yielding a plain `u32`.
    fn mul(self, rhs: u32) -> Self::Output {
        self.0.get() * rhs
    }
}

impl ops::Mul for Alignment {
    type Output = Alignment;

    /// Multiplies two alignments; a product of powers of two is a power of two.
    ///
    /// # Panics
    ///
    /// Panics if the product does not fit in a `u32`. Previously a wrapped
    /// (zero) product was passed to `NonZeroU32::new_unchecked` in release
    /// builds, which is undefined behavior.
    fn mul(self, rhs: Alignment) -> Self::Output {
        self.0
            .get()
            .checked_mul(rhs.0.get())
            // Both factors are non-zero, so a non-overflowing product is non-zero too.
            .and_then(NonZeroU32::new)
            .map(Alignment)
            .expect("alignment product overflows u32")
    }
}
impl From<crate::VectorSize> for Alignment {
fn from(size: crate::VectorSize) -> Self {
match size {
crate::VectorSize::Bi => Alignment::TWO,
crate::VectorSize::Tri => Alignment::FOUR,
crate::VectorSize::Quad => Alignment::FOUR,
}
}
}
/// The size and alignment computed for a type by the layouter.
#[derive(Clone, Copy, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct TypeLayout {
    /// Size of the type (for structs, the declared `span`).
    pub size: u32,
    /// Alignment requirement of the type.
    pub alignment: Alignment,
}

impl TypeLayout {
    /// Returns `size` rounded up to the next multiple of `alignment`.
    pub const fn to_stride(&self) -> u32 {
        let Self { size, alignment } = *self;
        alignment.round_up(size)
    }
}
/// Incrementally-filled cache of [`TypeLayout`]s.
///
/// Entry `i` holds the layout of the type whose handle has index `i` in the
/// type arena passed to `Layouter::update`.
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Layouter {
    /// Layouts in arena order, indexed by `Handle::index()`.
    layouts: Vec<TypeLayout>,
}

impl ops::Index<Handle<crate::Type>> for Layouter {
    type Output = TypeLayout;

    /// Looks up the cached layout of `handle`.
    ///
    /// # Panics
    ///
    /// Panics if no layout has been computed for `handle` yet.
    fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
        let slot = handle.index();
        &self.layouts[slot]
    }
}
/// The specific reason a type could not be laid out by `Layouter::update`.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// An array's element type is not ordered before the array's own handle,
    /// so its layout is not available yet.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// A struct member's type is not ordered before the struct's own handle.
    /// Fields: the member's index, then the member's type.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// The scalar width of a scalar, atomic, vector or matrix type is not a
    /// power of two.
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
}
/// A [`LayoutErrorInner`] paired with the type that was being laid out.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// The type whose layout could not be computed.
    pub ty: Handle<crate::Type>,
    /// Why the layout failed.
    pub inner: LayoutErrorInner,
}
impl LayoutErrorInner {
    /// Wraps this reason into a [`LayoutError`] attributed to type `ty`.
    const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
        let inner = self;
        LayoutError { inner, ty }
    }
}
impl Layouter {
pub fn clear(&mut self) {
self.layouts.clear();
}
#[allow(clippy::or_fun_call)]
pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
use crate::TypeInner as Ti;
for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
let size = ty.inner.size(gctx);
let layout = match ty.inner {
Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
let alignment = Alignment::new(scalar.width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout { size, alignment }
}
Ti::Vector {
size: vec_size,
scalar,
} => {
let alignment = Alignment::new(scalar.width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout {
size,
alignment: Alignment::from(vec_size) * alignment,
}
}
Ti::Matrix {
columns: _,
rows,
scalar,
} => {
let alignment = Alignment::new(scalar.width as u32)
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
TypeLayout {
size,
alignment: Alignment::from(rows) * alignment,
}
}
Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
size,
alignment: Alignment::ONE,
},
Ti::Array {
base,
stride: _,
size: _,
} => TypeLayout {
size,
alignment: if base < ty_handle {
self[base].alignment
} else {
return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
},
},
Ti::Struct { span, ref members } => {
let mut alignment = Alignment::ONE;
for (index, member) in members.iter().enumerate() {
alignment = if member.ty < ty_handle {
alignment.max(self[member.ty].alignment)
} else {
return Err(LayoutErrorInner::InvalidStructMemberType(
index as u32,
member.ty,
)
.with(ty_handle));
};
}
TypeLayout {
size: span,
alignment,
}
}
Ti::Image { .. }
| Ti::Sampler { .. }
| Ti::AccelerationStructure
| Ti::RayQuery
| Ti::BindingArray { .. } => TypeLayout {
size,
alignment: Alignment::ONE,
},
};
debug_assert!(size <= layout.size);
self.layouts.push(layout);
}
Ok(())
}
}