use crate::{
attributes::{impl_custom_attribute_methods, CustomAttributes},
NamedField, UnnamedField,
};
use bevy_utils::HashMap;
use core::slice::Iter;
use alloc::sync::Arc;
use derive_more::derive::{Display, Error};
/// The kind of an enum variant: struct-like, tuple-like, or unit.
///
/// Returned by [`VariantInfo::variant_type`] and carried by
/// [`VariantInfoError::TypeMismatch`] to report which kind was requested and which was found.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum VariantType {
    /// A variant with named fields (described by [`StructVariantInfo`]).
    Struct,
    /// A variant with unnamed, positional fields (described by [`TupleVariantInfo`]).
    Tuple,
    /// A variant with no fields (described by [`UnitVariantInfo`]).
    Unit,
}
#[derive(Debug, Error, Display)]
pub enum VariantInfoError {
#[display("variant type mismatch: expected {expected:?}, received {received:?}")]
TypeMismatch {
expected: VariantType,
received: VariantType,
},
}
/// Type information for a single enum variant.
///
/// Each variant wraps the info type matching the variant's shape:
/// [`StructVariantInfo`] (named fields), [`TupleVariantInfo`] (unnamed fields),
/// or [`UnitVariantInfo`] (no fields).
#[derive(Clone, Debug)]
pub enum VariantInfo {
    /// Info for a variant with named fields.
    Struct(StructVariantInfo),
    /// Info for a variant with unnamed, positional fields.
    Tuple(TupleVariantInfo),
    /// Info for a variant with no fields.
    Unit(UnitVariantInfo),
}

impl VariantInfo {
    /// Returns the name of the underlying variant.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Struct(info) => info.name(),
            Self::Tuple(info) => info.name(),
            Self::Unit(info) => info.name(),
        }
    }

    /// Returns the docstring of the underlying variant, if any.
    ///
    /// The string is `'static`, matching the `docs` accessor of every concrete
    /// variant info type, so it may outlive the borrow of `self`.
    #[cfg(feature = "documentation")]
    pub fn docs(&self) -> Option<&'static str> {
        match self {
            Self::Struct(info) => info.docs(),
            Self::Tuple(info) => info.docs(),
            Self::Unit(info) => info.docs(),
        }
    }

    /// Returns the [`VariantType`] matching the shape of this variant.
    pub fn variant_type(&self) -> VariantType {
        match self {
            Self::Struct(_) => VariantType::Struct,
            Self::Tuple(_) => VariantType::Tuple,
            Self::Unit(_) => VariantType::Unit,
        }
    }

    // Custom-attribute accessors delegate to whichever concrete info type is stored.
    impl_custom_attribute_methods!(
        self,
        match self {
            Self::Struct(info) => info.custom_attributes(),
            Self::Tuple(info) => info.custom_attributes(),
            Self::Unit(info) => info.custom_attributes(),
        },
        "variant"
    );
}
/// Generates a fallible downcast method on [`VariantInfo`].
///
/// `impl_cast_method!(method: Kind => Info)` expands to `pub fn method(&self)`. The method
/// returns `Ok(&Info)` when `self` is `VariantInfo::Kind`. Otherwise it returns
/// [`VariantInfoError::TypeMismatch`] holding both the requested and the actual [`VariantType`].
macro_rules! impl_cast_method {
    ($name:ident : $kind:ident => $info:ident) => {
        #[doc = concat!("Attempts a cast to [`", stringify!($info), "`].")]
        #[doc = concat!("\n\nReturns an error if `self` is not [`VariantInfo::", stringify!($kind), "`].")]
        pub fn $name(&self) -> Result<&$info, VariantInfoError> {
            if let Self::$kind(info) = self {
                Ok(info)
            } else {
                let received = self.variant_type();
                Err(VariantInfoError::TypeMismatch {
                    expected: VariantType::$kind,
                    received,
                })
            }
        }
    };
}

// Fallible accessors for the concrete info type behind each `VariantInfo` variant.
impl VariantInfo {
    impl_cast_method!(as_struct_variant: Struct => StructVariantInfo);
    impl_cast_method!(as_tuple_variant: Tuple => TupleVariantInfo);
    impl_cast_method!(as_unit_variant: Unit => UnitVariantInfo);
}
/// Type info for struct-like enum variants, i.e. variants whose fields are named.
#[derive(Clone, Debug)]
pub struct StructVariantInfo {
    /// The variant's name.
    name: &'static str,
    /// Field info, in the order the fields were passed to [`StructVariantInfo::new`].
    fields: Box<[NamedField]>,
    /// The name of each field, parallel to `fields`.
    field_names: Box<[&'static str]>,
    /// Lookup table from a field name to its position in `fields`.
    field_indices: HashMap<&'static str, usize>,
    /// Custom attributes attached to the variant; empty unless set with
    /// [`StructVariantInfo::with_custom_attributes`].
    custom_attributes: Arc<CustomAttributes>,
    /// The variant's docstring, if one was provided.
    #[cfg(feature = "documentation")]
    docs: Option<&'static str>,
}

impl StructVariantInfo {
    /// Creates a new [`StructVariantInfo`] called `name` with the given `fields`.
    ///
    /// The fields are cloned into owned storage. Custom attributes start out empty,
    /// and (with the `documentation` feature) docs start out as `None`.
    pub fn new(name: &'static str, fields: &[NamedField]) -> Self {
        let owned_fields: Box<[NamedField]> = Box::from(fields);
        let field_names = owned_fields.iter().map(NamedField::name).collect();
        let field_indices = Self::collect_field_indices(&owned_fields);
        Self {
            name,
            fields: owned_fields,
            field_names,
            field_indices,
            custom_attributes: Arc::new(CustomAttributes::default()),
            #[cfg(feature = "documentation")]
            docs: None,
        }
    }

    /// Replaces the docstring of this variant.
    #[cfg(feature = "documentation")]
    pub fn with_docs(mut self, docs: Option<&'static str>) -> Self {
        self.docs = docs;
        self
    }

    /// Replaces the custom attributes of this variant.
    pub fn with_custom_attributes(mut self, custom_attributes: CustomAttributes) -> Self {
        self.custom_attributes = Arc::new(custom_attributes);
        self
    }

    /// Returns the name of this variant.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the names of all fields, in field-index order.
    pub fn field_names(&self) -> &[&'static str] {
        &self.field_names
    }

    /// Looks up a field by name.
    pub fn field(&self, name: &str) -> Option<&NamedField> {
        let index = *self.field_indices.get(name)?;
        // `field_indices` is built from `fields`, so every stored index is in bounds.
        Some(&self.fields[index])
    }

    /// Returns the field at `index`, or `None` if `index` is out of range.
    pub fn field_at(&self, index: usize) -> Option<&NamedField> {
        self.fields.get(index)
    }

    /// Returns the index of the field called `name`, if such a field exists.
    pub fn index_of(&self, name: &str) -> Option<usize> {
        self.field_indices.get(name).copied()
    }

    /// Iterates over the fields in index order.
    pub fn iter(&self) -> Iter<'_, NamedField> {
        self.fields.iter()
    }

    /// Returns the number of fields.
    pub fn field_len(&self) -> usize {
        self.fields.len()
    }

    /// Builds the name → index lookup table for `fields`.
    fn collect_field_indices(fields: &[NamedField]) -> HashMap<&'static str, usize> {
        fields.iter().map(NamedField::name).zip(0..).collect()
    }

    /// Returns the docstring of this variant, if any.
    #[cfg(feature = "documentation")]
    pub fn docs(&self) -> Option<&'static str> {
        self.docs
    }

    impl_custom_attribute_methods!(self.custom_attributes, "variant");
}
/// Type info for tuple-like enum variants, i.e. variants with unnamed, positional fields.
#[derive(Clone, Debug)]
pub struct TupleVariantInfo {
    /// The variant's name.
    name: &'static str,
    /// Field info, in the order the fields were passed to [`TupleVariantInfo::new`].
    fields: Box<[UnnamedField]>,
    /// Custom attributes attached to the variant; empty unless set with
    /// [`TupleVariantInfo::with_custom_attributes`].
    custom_attributes: Arc<CustomAttributes>,
    /// The variant's docstring, if one was provided.
    #[cfg(feature = "documentation")]
    docs: Option<&'static str>,
}

impl TupleVariantInfo {
    /// Creates a new [`TupleVariantInfo`] called `name` with the given `fields`.
    ///
    /// The fields are cloned into owned storage. Custom attributes start out empty,
    /// and (with the `documentation` feature) docs start out as `None`.
    pub fn new(name: &'static str, fields: &[UnnamedField]) -> Self {
        let custom_attributes = Arc::new(CustomAttributes::default());
        Self {
            name,
            fields: Box::from(fields),
            custom_attributes,
            #[cfg(feature = "documentation")]
            docs: None,
        }
    }

    /// Replaces the docstring of this variant.
    #[cfg(feature = "documentation")]
    pub fn with_docs(mut self, docs: Option<&'static str>) -> Self {
        self.docs = docs;
        self
    }

    /// Replaces the custom attributes of this variant.
    pub fn with_custom_attributes(mut self, custom_attributes: CustomAttributes) -> Self {
        self.custom_attributes = Arc::new(custom_attributes);
        self
    }

    /// Returns the name of this variant.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the number of fields.
    pub fn field_len(&self) -> usize {
        self.fields.len()
    }

    /// Returns the field at `index`, or `None` if `index` is out of range.
    pub fn field_at(&self, index: usize) -> Option<&UnnamedField> {
        self.fields.get(index)
    }

    /// Iterates over the fields in index order.
    pub fn iter(&self) -> Iter<'_, UnnamedField> {
        self.fields.iter()
    }

    /// Returns the docstring of this variant, if any.
    #[cfg(feature = "documentation")]
    pub fn docs(&self) -> Option<&'static str> {
        self.docs
    }

    impl_custom_attribute_methods!(self.custom_attributes, "variant");
}
/// Type info for unit enum variants, i.e. variants that carry no fields.
#[derive(Clone, Debug)]
pub struct UnitVariantInfo {
    /// The variant's name.
    name: &'static str,
    /// Custom attributes attached to the variant; empty unless set with
    /// [`UnitVariantInfo::with_custom_attributes`].
    custom_attributes: Arc<CustomAttributes>,
    /// The variant's docstring, if one was provided.
    #[cfg(feature = "documentation")]
    docs: Option<&'static str>,
}

impl UnitVariantInfo {
    /// Creates a new [`UnitVariantInfo`] called `name`.
    ///
    /// Custom attributes start out empty, and (with the `documentation` feature)
    /// docs start out as `None`.
    pub fn new(name: &'static str) -> Self {
        let custom_attributes = Arc::new(CustomAttributes::default());
        Self {
            name,
            custom_attributes,
            #[cfg(feature = "documentation")]
            docs: None,
        }
    }

    /// Replaces the docstring of this variant.
    #[cfg(feature = "documentation")]
    pub fn with_docs(mut self, docs: Option<&'static str>) -> Self {
        self.docs = docs;
        self
    }

    /// Replaces the custom attributes of this variant.
    pub fn with_custom_attributes(mut self, custom_attributes: CustomAttributes) -> Self {
        self.custom_attributes = Arc::new(custom_attributes);
        self
    }

    /// Returns the name of this variant.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the docstring of this variant, if any.
    #[cfg(feature = "documentation")]
    pub fn docs(&self) -> Option<&'static str> {
        self.docs
    }

    impl_custom_attribute_methods!(self.custom_attributes, "variant");
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate as bevy_reflect;
    use crate::{Reflect, Typed};

    /// Casting a unit variant to any other kind must fail and report both the
    /// requested and the actual variant kind.
    #[test]
    fn should_return_error_on_invalid_cast() {
        #[derive(Reflect)]
        enum Foo {
            Bar,
        }

        let info = Foo::type_info().as_enum().unwrap();
        let variant = info.variant_at(0).unwrap();

        assert!(matches!(
            variant.as_tuple_variant(),
            Err(VariantInfoError::TypeMismatch {
                expected: VariantType::Tuple,
                received: VariantType::Unit
            })
        ));
        assert!(matches!(
            variant.as_struct_variant(),
            Err(VariantInfoError::TypeMismatch {
                expected: VariantType::Struct,
                received: VariantType::Unit
            })
        ));
    }

    /// Casting to the variant's own kind must succeed and expose its info.
    #[test]
    fn should_cast_to_matching_variant() {
        #[derive(Reflect)]
        enum Foo {
            Bar,
        }

        let info = Foo::type_info().as_enum().unwrap();
        let variant = info.variant_at(0).unwrap();

        assert_eq!(variant.variant_type(), VariantType::Unit);
        let unit = variant.as_unit_variant().unwrap();
        assert_eq!(unit.name(), "Bar");
    }
}