// bevy_reflect_derive/where_clause_options.rs
1use crate::derive_data::ReflectMeta;
2use bevy_macro_utils::fq_std::{FQAny, FQSend, FQSync};
3use indexmap::IndexSet;
4use proc_macro2::{TokenStream, TokenTree};
5use quote::{quote, ToTokens};
6use syn::{punctuated::Punctuated, Ident, Token, Type, WhereClause};
7
/// Options defining how to extend the `where` clause for reflection.
pub(crate) struct WhereClauseOptions<'a, 'b> {
    /// Reflection metadata for the type being derived. It supplies the type's generics, its
    /// reflection attributes, and the path to `bevy_reflect` used in generated bounds.
    meta: &'a ReflectMeta<'b>,
    /// Field types that may receive reflection bounds. Only the types that mention one of the
    /// container's type parameters actually get bounds. This is presumably the set of reflected
    /// (non-ignored) field types; the caller decides what goes in it.
    active_types: IndexSet<Type>,
}
13
14impl<'a, 'b> WhereClauseOptions<'a, 'b> {
15    pub fn new(meta: &'a ReflectMeta<'b>) -> Self {
16        Self {
17            meta,
18            active_types: IndexSet::new(),
19        }
20    }
21
22    pub fn new_with_types(meta: &'a ReflectMeta<'b>, active_types: IndexSet<Type>) -> Self {
23        Self { meta, active_types }
24    }
25
26    pub fn meta(&self) -> &'a ReflectMeta<'b> {
27        self.meta
28    }
29
30    /// Extends the `where` clause for a type with additional bounds needed for the reflection
31    /// impls.
32    ///
33    /// The default bounds added are as follows:
34    /// - `Self` has:
35    ///   - `Any + Send + Sync` bounds, if generic over types
36    ///   - An `Any` bound, if generic over lifetimes but not types
37    ///   - No bounds, if generic over neither types nor lifetimes
38    /// - Any given bounds in a `where` clause on the type
39    /// - Type parameters have the bound `TypePath` unless `#[reflect(type_path = false)]` is
40    ///   present
41    /// - Active fields with non-generic types have the bounds `TypePath`, either `PartialReflect`
42    ///   if `#[reflect(from_reflect = false)]` is present or `FromReflect` otherwise,
43    ///   `MaybeTyped`, and `RegisterForReflection` (or no bounds at all if
44    ///   `#[reflect(no_field_bounds)]` is present)
45    ///
46    /// When the derive is used with `#[reflect(where)]`, the bounds specified in the attribute are
47    /// added as well.
48    ///
49    /// # Example
50    ///
51    /// ```ignore (bevy_reflect is not accessible from this crate)
52    /// #[derive(Reflect)]
53    /// struct Foo<T, U> {
54    ///   a: T,
55    ///   #[reflect(ignore)]
56    ///   b: U
57    /// }
58    /// ```
59    ///
60    /// Generates the following where clause:
61    ///
62    /// ```ignore (bevy_reflect is not accessible from this crate)
63    /// where
64    ///   // `Self` bounds:
65    ///   Foo<T, U>: Any + Send + Sync,
66    ///   // Type parameter bounds:
67    ///   T: TypePath,
68    ///   U: TypePath,
69    ///   // Active non-generic field bounds
70    ///   T: FromReflect + TypePath + MaybeTyped + RegisterForReflection,
71    ///
72    /// ```
73    ///
74    /// If we add various things to the type:
75    ///
76    /// ```ignore (bevy_reflect is not accessible from this crate)
77    /// #[derive(Reflect)]
78    /// #[reflect(where T: MyTrait)]
79    /// #[reflect(no_field_bounds)]
80    /// struct Foo<T, U>
81    ///     where T: Clone
82    /// {
83    ///   a: T,
84    ///   #[reflect(ignore)]
85    ///   b: U
86    /// }
87    /// ```
88    ///
89    /// It will instead generate the following where clause:
90    ///
91    /// ```ignore (bevy_reflect is not accessible from this crate)
92    /// where
93    ///   // `Self` bounds:
94    ///   Foo<T, U>: Any + Send + Sync,
95    ///   // Given bounds:
96    ///   T: Clone,
97    ///   // Type parameter bounds:
98    ///   T: TypePath,
99    ///   U: TypePath,
100    ///   // No active non-generic field bounds
101    ///   // Custom bounds
102    ///   T: MyTrait,
103    /// ```
104    pub fn extend_where_clause(&self, where_clause: Option<&WhereClause>) -> TokenStream {
105        let mut generic_where_clause = quote! { where };
106
107        // Bounds on `Self`. We would normally just use `Self`, but that won't work for generating
108        // things like assertion functions and trait impls for a type's reference (e.g. `impl
109        // FromArg for &MyType`).
110        let generics = self.meta.type_path().generics();
111        if generics.type_params().next().is_some() {
112            // Generic over types? We need `Any + Send + Sync`.
113            let this = self.meta.type_path().true_type();
114            generic_where_clause.extend(quote! { #this: #FQAny + #FQSend + #FQSync, });
115        } else if generics.lifetimes().next().is_some() {
116            // Generic only over lifetimes? We need `'static`.
117            let this = self.meta.type_path().true_type();
118            generic_where_clause.extend(quote! { #this: 'static, });
119        }
120
121        // Maintain existing where clause bounds, if any.
122        if let Some(where_clause) = where_clause {
123            let predicates = where_clause.predicates.iter();
124            generic_where_clause.extend(quote! { #(#predicates,)* });
125        }
126
127        // Add additional reflection trait bounds.
128        let predicates = self.predicates();
129        generic_where_clause.extend(quote! {
130            #predicates
131        });
132
133        generic_where_clause
134    }
135
136    /// Returns an iterator the where clause predicates to extended the where clause with.
137    fn predicates(&self) -> Punctuated<TokenStream, Token![,]> {
138        let mut predicates = Punctuated::new();
139
140        if let Some(type_param_predicates) = self.type_param_predicates() {
141            predicates.extend(type_param_predicates);
142        }
143
144        if let Some(field_predicates) = self.active_field_predicates() {
145            predicates.extend(field_predicates);
146        }
147
148        if let Some(custom_where) = self.meta.attrs().custom_where() {
149            predicates.push(custom_where.predicates.to_token_stream());
150        }
151
152        predicates
153    }
154
155    /// Returns an iterator over the where clause predicates for the type parameters
156    /// if they require one.
157    fn type_param_predicates(&self) -> Option<impl Iterator<Item = TokenStream> + '_> {
158        self.type_path_bound().map(|type_path_bound| {
159            self.meta
160                .type_path()
161                .generics()
162                .type_params()
163                .map(move |param| {
164                    let ident = ¶m.ident;
165
166                    quote!(#ident : #type_path_bound)
167                })
168        })
169    }
170
171    /// Returns an iterator over the where clause predicates for the active fields.
172    fn active_field_predicates(&self) -> Option<impl Iterator<Item = TokenStream> + '_> {
173        if self.meta.attrs().no_field_bounds() {
174            None
175        } else {
176            let bevy_reflect_path = self.meta.bevy_reflect_path();
177            let reflect_bound = self.reflect_bound();
178
179            // Get the identifiers of all type parameters.
180            let type_param_idents = self
181                .meta
182                .type_path()
183                .generics()
184                .type_params()
185                .map(|type_param| type_param.ident.clone())
186                .collect::<Vec<Ident>>();
187
188            // Do any of the identifiers in `idents` appear in `token_stream`?
189            fn is_any_ident_in_token_stream(idents: &[Ident], token_stream: TokenStream) -> bool {
190                for token_tree in token_stream {
191                    match token_tree {
192                        TokenTree::Ident(ident) => {
193                            if idents.contains(&ident) {
194                                return true;
195                            }
196                        }
197                        TokenTree::Group(group) => {
198                            if is_any_ident_in_token_stream(idents, group.stream()) {
199                                return true;
200                            }
201                        }
202                        TokenTree::Punct(_) | TokenTree::Literal(_) => {}
203                    }
204                }
205                false
206            }
207
208            Some(self.active_types.iter().filter_map(move |ty| {
209                // Field type bounds are only required if `ty` is generic. How to determine that?
210                // Search `ty`s token stream for identifiers that match the identifiers from the
211                // function's type params. E.g. if `T` and `U` are the type param identifiers and
212                // `ty` is `Vec<[T; 4]>` then the `T` identifiers match. This is a bit hacky, but
213                // it works.
214                let is_generic =
215                    is_any_ident_in_token_stream(&type_param_idents, ty.to_token_stream());
216
217                is_generic.then(|| {
218                    quote!(
219                        #ty: #reflect_bound
220                            // Needed to construct `NamedField` and `UnnamedField` instances for
221                            // the `Typed` impl.
222                            + #bevy_reflect_path::TypePath
223                            // Needed for `Typed` impls
224                            + #bevy_reflect_path::MaybeTyped
225                            // Needed for registering type dependencies in the
226                            // `GetTypeRegistration` impl.
227                            + #bevy_reflect_path::__macro_exports::RegisterForReflection
228                    )
229                })
230            }))
231        }
232    }
233
234    /// The `PartialReflect` or `FromReflect` bound to use based on `#[reflect(from_reflect = false)]`.
235    fn reflect_bound(&self) -> TokenStream {
236        let bevy_reflect_path = self.meta.bevy_reflect_path();
237
238        if self.meta.from_reflect().should_auto_derive() {
239            quote!(#bevy_reflect_path::FromReflect)
240        } else {
241            quote!(#bevy_reflect_path::PartialReflect)
242        }
243    }
244
245    /// The `TypePath` bounds to use based on `#[reflect(type_path = false)]`.
246    fn type_path_bound(&self) -> Option<TokenStream> {
247        if self.meta.type_path_attrs().should_auto_derive() {
248            let bevy_reflect_path = self.meta.bevy_reflect_path();
249            Some(quote!(#bevy_reflect_path::TypePath))
250        } else {
251            None
252        }
253    }
254}