nalgebra_macros/
stack_impl.rs

1use crate::Matrix;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{format_ident, quote, quote_spanned};
4use syn::spanned::Spanned;
5use syn::{Error, Expr, Lit};
6
7#[allow(clippy::too_many_lines)]
8pub fn stack_impl(matrix: Matrix) -> syn::Result<TokenStream2> {
9    // The prefix is used to construct variable names
10    // that are extremely unlikely to collide with variable names used in e.g. expressions
11    // by the user. Although we could use a long, pseudo-random string, this makes the generated
12    // code very painful to parse, so we settle for something more semantic that is still
13    // very unlikely to collide
14    let prefix = "___na";
15    let n_block_rows = matrix.nrows();
16    let n_block_cols = matrix.ncols();
17
18    let mut output = quote! {};
19
20    // First assign data and shape for each matrix entry to variables
21    // (this is important so that we, for example, don't evaluate an expression more than once)
22    for i in 0..n_block_rows {
23        for j in 0..n_block_cols {
24            let expr = &matrix[(i, j)];
25            if !is_literal_zero(expr) {
26                let ident_block = format_ident!("{prefix}_stack_{i}_{j}_block");
27                let ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape");
28                output.extend(std::iter::once(quote_spanned! {expr.span()=>
29                    let ref #ident_block = #expr;
30                    let #ident_shape = #ident_block.shape_generic();
31                }));
32            }
33        }
34    }
35
36    // Determine the number of rows (dimension) in each block row,
37    // and write out variables that define block row dimensions and offsets into the
38    // output matrix
39    for i in 0..n_block_rows {
40        // The dimension of the block row is the result of trying to unify the row shape of
41        // all blocks in the block row
42        let dim = (0 ..n_block_cols)
43            .filter_map(|j| {
44                let expr = &matrix[(i, j)];
45                if !is_literal_zero(expr) {
46                    let mut ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape");
47                    ident_shape.set_span(ident_shape.span().located_at(expr.span()));
48                    Some(quote_spanned!{expr.span()=> #ident_shape.0 })
49                } else {
50                    None
51                }
52            }).reduce(|a, b| {
53                let expect_msg = format!("All blocks in block row {i} must have the same number of rows");
54                quote_spanned!{b.span()=>
55                    <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfRows<_, _>>::representative(#a, #b)
56                        .expect(#expect_msg)
57                }
58            }).ok_or(Error::new(Span::call_site(), format!("Block row {i} cannot consist entirely of implicit zero blocks.")))?;
59
60        let dim_ident = format_ident!("{prefix}_stack_row_{i}_dim");
61        let offset_ident = format_ident!("{prefix}_stack_row_{i}_offset");
62
63        let offset = if i == 0 {
64            quote! { 0 }
65        } else {
66            let prev_offset_ident = format_ident!("{prefix}_stack_row_{}_offset", i - 1);
67            let prev_dim_ident = format_ident!("{prefix}_stack_row_{}_dim", i - 1);
68            quote! { #prev_offset_ident + <_ as nalgebra::Dim>::value(&#prev_dim_ident) }
69        };
70
71        output.extend(std::iter::once(quote! {
72            let #dim_ident = #dim;
73            let #offset_ident = #offset;
74        }));
75    }
76
77    // Do the same thing for the block columns
78    for j in 0..n_block_cols {
79        let dim = (0 ..n_block_rows)
80            .filter_map(|i| {
81                let expr = &matrix[(i, j)];
82                if !is_literal_zero(expr) {
83                    let mut ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape");
84                    ident_shape.set_span(ident_shape.span().located_at(expr.span()));
85                    Some(quote_spanned!{expr.span()=> #ident_shape.1 })
86                } else {
87                    None
88                }
89            }).reduce(|a, b| {
90                let expect_msg = format!("All blocks in block column {j} must have the same number of columns");
91                quote_spanned!{b.span()=>
92                        <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfColumns<_, _>>::representative(#a, #b)
93                            .expect(#expect_msg)
94                }
95            }).ok_or(Error::new(Span::call_site(), format!("Block column {j} cannot consist entirely of implicit zero blocks.")))?;
96
97        let dim_ident = format_ident!("{prefix}_stack_col_{j}_dim");
98        let offset_ident = format_ident!("{prefix}_stack_col_{j}_offset");
99
100        let offset = if j == 0 {
101            quote! { 0 }
102        } else {
103            let prev_offset_ident = format_ident!("{prefix}_stack_col_{}_offset", j - 1);
104            let prev_dim_ident = format_ident!("{prefix}_stack_col_{}_dim", j - 1);
105            quote! { #prev_offset_ident + <_ as nalgebra::Dim>::value(&#prev_dim_ident) }
106        };
107
108        output.extend(std::iter::once(quote! {
109            let #dim_ident = #dim;
110            let #offset_ident = #offset;
111        }));
112    }
113
114    // Determine number of rows and cols in output matrix,
115    // by adding together dimensions of all block rows/cols
116    let num_rows = (0..n_block_rows)
117        .map(|i| {
118            let ident = format_ident!("{prefix}_stack_row_{i}_dim");
119            quote! { #ident }
120        })
121        .reduce(|a, b| {
122            quote! {
123                <_ as nalgebra::DimAdd<_>>::add(#a, #b)
124            }
125        })
126        .unwrap_or(quote! { nalgebra::dimension::U0 });
127
128    let num_cols = (0..n_block_cols)
129        .map(|j| {
130            let ident = format_ident!("{prefix}_stack_col_{j}_dim");
131            quote! { #ident }
132        })
133        .reduce(|a, b| {
134            quote! {
135                <_ as nalgebra::DimAdd<_>>::add(#a, #b)
136            }
137        })
138        .unwrap_or(quote! { nalgebra::dimension::U0 });
139
140    // It should be possible to use `uninitialized_generic` here instead
141    // however that would mean that the macro needs to generate unsafe code
142    // which does not seem like a great idea.
143    output.extend(std::iter::once(quote! {
144        let mut matrix = nalgebra::Matrix::zeros_generic(#num_rows, #num_cols);
145    }));
146
147    for i in 0..n_block_rows {
148        for j in 0..n_block_cols {
149            let row_dim = format_ident!("{prefix}_stack_row_{i}_dim");
150            let col_dim = format_ident!("{prefix}_stack_col_{j}_dim");
151            let row_offset = format_ident!("{prefix}_stack_row_{i}_offset");
152            let col_offset = format_ident!("{prefix}_stack_col_{j}_offset");
153            let expr = &matrix[(i, j)];
154            if !is_literal_zero(expr) {
155                let expr_ident = format_ident!("{prefix}_stack_{i}_{j}_block");
156                output.extend(std::iter::once(quote! {
157                    let start = (#row_offset, #col_offset);
158                    let shape = (#row_dim, #col_dim);
159                    let input_view = #expr_ident.generic_view((0, 0), shape);
160                    let mut output_view = matrix.generic_view_mut(start, shape);
161                    output_view.copy_from(&input_view);
162                }));
163            }
164        }
165    }
166
167    Ok(quote! {
168        {
169            #output
170            matrix
171        }
172    })
173}
174
175fn is_literal_zero(expr: &Expr) -> bool {
176    matches!(expr,
177        Expr::Lit(syn::ExprLit { lit: Lit::Int(integer_literal), .. })
178        if integer_literal.base10_digits() == "0")
179}
180
#[cfg(test)]
mod tests {
    use crate::Matrix;
    use crate::stack_impl::stack_impl;
    use quote::quote;

    /// Block-diagonal input: every block row/column has a single non-zero block, so its
    /// dimension is taken directly from that block's shape without any unification.
    #[test]
    fn stack_simple_generation() {
        let input: Matrix = syn::parse_quote![
            a, 0;
            0, b;
        ];

        let result = stack_impl(input).unwrap();

        let expected = quote! {{
            let ref ___na_stack_0_0_block = a;
            let ___na_stack_0_0_shape = ___na_stack_0_0_block.shape_generic();
            let ref ___na_stack_1_1_block = b;
            let ___na_stack_1_1_shape = ___na_stack_1_1_block.shape_generic();
            let ___na_stack_row_0_dim = ___na_stack_0_0_shape.0;
            let ___na_stack_row_0_offset = 0;
            let ___na_stack_row_1_dim = ___na_stack_1_1_shape.0;
            let ___na_stack_row_1_offset = ___na_stack_row_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_0_dim);
            let ___na_stack_col_0_dim = ___na_stack_0_0_shape.1;
            let ___na_stack_col_0_offset = 0;
            let ___na_stack_col_1_dim = ___na_stack_1_1_shape.1;
            let ___na_stack_col_1_offset = ___na_stack_col_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_0_dim);
            let mut matrix = nalgebra::Matrix::zeros_generic(
                <_ as nalgebra::DimAdd<_>>::add(___na_stack_row_0_dim, ___na_stack_row_1_dim),
                <_ as nalgebra::DimAdd<_>>::add(___na_stack_col_0_dim, ___na_stack_col_1_dim)
            );
            let start = (___na_stack_row_0_offset, ___na_stack_col_0_offset);
            let shape = (___na_stack_row_0_dim, ___na_stack_col_0_dim);
            let input_view = ___na_stack_0_0_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            let start = (___na_stack_row_1_offset, ___na_stack_col_1_offset);
            let shape = (___na_stack_row_1_dim, ___na_stack_col_1_dim);
            let input_view = ___na_stack_1_1_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            matrix
        }};

        assert_eq!(result.to_string(), expected.to_string());
    }

    /// 3x3 block input mixing rows/columns with one and with several non-zero blocks, so
    /// both the direct-shape and the `representative` unification paths are exercised, as
    /// well as the left-nested `DimAdd` sums for the output dimensions.
    #[test]
    fn stack_complex_generation() {
        let input: Matrix = syn::parse_quote![
            a, 0, b;
            0, c, d;
            e, 0, 0;
        ];

        let result = stack_impl(input).unwrap();

        let expected = quote! {{
            let ref ___na_stack_0_0_block = a;
            let ___na_stack_0_0_shape = ___na_stack_0_0_block.shape_generic();
            let ref ___na_stack_0_2_block = b;
            let ___na_stack_0_2_shape = ___na_stack_0_2_block.shape_generic();
            let ref ___na_stack_1_1_block = c;
            let ___na_stack_1_1_shape = ___na_stack_1_1_block.shape_generic();
            let ref ___na_stack_1_2_block = d;
            let ___na_stack_1_2_shape = ___na_stack_1_2_block.shape_generic();
            let ref ___na_stack_2_0_block = e;
            let ___na_stack_2_0_shape = ___na_stack_2_0_block.shape_generic();
            let ___na_stack_row_0_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfRows < _ , _ >> :: representative (___na_stack_0_0_shape . 0 , ___na_stack_0_2_shape . 0) . expect ("All blocks in block row 0 must have the same number of rows") ;
            let ___na_stack_row_0_offset = 0;
            let ___na_stack_row_1_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfRows < _ , _ >> :: representative (___na_stack_1_1_shape . 0 , ___na_stack_1_2_shape . 0) . expect ("All blocks in block row 1 must have the same number of rows") ;
            let ___na_stack_row_1_offset = ___na_stack_row_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_0_dim);
            let ___na_stack_row_2_dim = ___na_stack_2_0_shape.0;
            let ___na_stack_row_2_offset = ___na_stack_row_1_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_1_dim);
            let ___na_stack_col_0_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfColumns < _ , _ >> :: representative (___na_stack_0_0_shape . 1 , ___na_stack_2_0_shape . 1) . expect ("All blocks in block column 0 must have the same number of columns") ;
            let ___na_stack_col_0_offset = 0;
            let ___na_stack_col_1_dim = ___na_stack_1_1_shape.1;
            let ___na_stack_col_1_offset = ___na_stack_col_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_0_dim);
            let ___na_stack_col_2_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfColumns < _ , _ >> :: representative (___na_stack_0_2_shape . 1 , ___na_stack_1_2_shape . 1) . expect ("All blocks in block column 2 must have the same number of columns") ;
            let ___na_stack_col_2_offset = ___na_stack_col_1_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_1_dim);
            let mut matrix = nalgebra::Matrix::zeros_generic(
                <_ as nalgebra::DimAdd<_>>::add(
                    <_ as nalgebra::DimAdd<_>>::add(___na_stack_row_0_dim, ___na_stack_row_1_dim),
                    ___na_stack_row_2_dim
                ),
                <_ as nalgebra::DimAdd<_>>::add(
                    <_ as nalgebra::DimAdd<_>>::add(___na_stack_col_0_dim, ___na_stack_col_1_dim),
                    ___na_stack_col_2_dim
                )
            );
            let start = (___na_stack_row_0_offset, ___na_stack_col_0_offset);
            let shape = (___na_stack_row_0_dim, ___na_stack_col_0_dim);
            let input_view = ___na_stack_0_0_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            let start = (___na_stack_row_0_offset, ___na_stack_col_2_offset);
            let shape = (___na_stack_row_0_dim, ___na_stack_col_2_dim);
            let input_view = ___na_stack_0_2_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            let start = (___na_stack_row_1_offset, ___na_stack_col_1_offset);
            let shape = (___na_stack_row_1_dim, ___na_stack_col_1_dim);
            let input_view = ___na_stack_1_1_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            let start = (___na_stack_row_1_offset, ___na_stack_col_2_offset);
            let shape = (___na_stack_row_1_dim, ___na_stack_col_2_dim);
            let input_view = ___na_stack_1_2_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            let start = (___na_stack_row_2_offset, ___na_stack_col_0_offset);
            let shape = (___na_stack_row_2_dim, ___na_stack_col_0_dim);
            let input_view = ___na_stack_2_0_block.generic_view((0,0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);
            matrix
        }};

        assert_eq!(result.to_string(), expected.to_string());
    }

    /// A block row made up solely of implicit zero blocks has no determinable number of rows,
    /// so macro expansion must fail with a descriptive error.
    #[test]
    fn stack_all_zero_block_row_is_error() {
        let input: Matrix = syn::parse_quote![
            a, b;
            0, 0;
        ];

        let err = stack_impl(input).expect_err("an all-zero block row must be rejected");

        assert_eq!(
            err.to_string(),
            "Block row 1 cannot consist entirely of implicit zero blocks."
        );
    }

    /// A block column made up solely of implicit zero blocks has no determinable number of
    /// columns, so macro expansion must fail with a descriptive error.
    #[test]
    fn stack_all_zero_block_column_is_error() {
        let input: Matrix = syn::parse_quote![
            a, 0;
            b, 0;
        ];

        let err = stack_impl(input).expect_err("an all-zero block column must be rejected");

        assert_eq!(
            err.to_string(),
            "Block column 1 cannot consist entirely of implicit zero blocks."
        );
    }
}