1use crate::Matrix;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{format_ident, quote, quote_spanned};
4use syn::spanned::Spanned;
5use syn::{Error, Expr, Lit};
6
7#[allow(clippy::too_many_lines)]
8pub fn stack_impl(matrix: Matrix) -> syn::Result<TokenStream2> {
9 let prefix = "___na";
15 let n_block_rows = matrix.nrows();
16 let n_block_cols = matrix.ncols();
17
18 let mut output = quote! {};
19
20 for i in 0..n_block_rows {
23 for j in 0..n_block_cols {
24 let expr = &matrix[(i, j)];
25 if !is_literal_zero(expr) {
26 let ident_block = format_ident!("{prefix}_stack_{i}_{j}_block");
27 let ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape");
28 output.extend(std::iter::once(quote_spanned! {expr.span()=>
29 let ref #ident_block = #expr;
30 let #ident_shape = #ident_block.shape_generic();
31 }));
32 }
33 }
34 }
35
36 for i in 0..n_block_rows {
40 let dim = (0 ..n_block_cols)
43 .filter_map(|j| {
44 let expr = &matrix[(i, j)];
45 if !is_literal_zero(expr) {
46 let mut ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape");
47 ident_shape.set_span(ident_shape.span().located_at(expr.span()));
48 Some(quote_spanned!{expr.span()=> #ident_shape.0 })
49 } else {
50 None
51 }
52 }).reduce(|a, b| {
53 let expect_msg = format!("All blocks in block row {i} must have the same number of rows");
54 quote_spanned!{b.span()=>
55 <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfRows<_, _>>::representative(#a, #b)
56 .expect(#expect_msg)
57 }
58 }).ok_or(Error::new(Span::call_site(), format!("Block row {i} cannot consist entirely of implicit zero blocks.")))?;
59
60 let dim_ident = format_ident!("{prefix}_stack_row_{i}_dim");
61 let offset_ident = format_ident!("{prefix}_stack_row_{i}_offset");
62
63 let offset = if i == 0 {
64 quote! { 0 }
65 } else {
66 let prev_offset_ident = format_ident!("{prefix}_stack_row_{}_offset", i - 1);
67 let prev_dim_ident = format_ident!("{prefix}_stack_row_{}_dim", i - 1);
68 quote! { #prev_offset_ident + <_ as nalgebra::Dim>::value(&#prev_dim_ident) }
69 };
70
71 output.extend(std::iter::once(quote! {
72 let #dim_ident = #dim;
73 let #offset_ident = #offset;
74 }));
75 }
76
77 for j in 0..n_block_cols {
79 let dim = (0 ..n_block_rows)
80 .filter_map(|i| {
81 let expr = &matrix[(i, j)];
82 if !is_literal_zero(expr) {
83 let mut ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape");
84 ident_shape.set_span(ident_shape.span().located_at(expr.span()));
85 Some(quote_spanned!{expr.span()=> #ident_shape.1 })
86 } else {
87 None
88 }
89 }).reduce(|a, b| {
90 let expect_msg = format!("All blocks in block column {j} must have the same number of columns");
91 quote_spanned!{b.span()=>
92 <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfColumns<_, _>>::representative(#a, #b)
93 .expect(#expect_msg)
94 }
95 }).ok_or(Error::new(Span::call_site(), format!("Block column {j} cannot consist entirely of implicit zero blocks.")))?;
96
97 let dim_ident = format_ident!("{prefix}_stack_col_{j}_dim");
98 let offset_ident = format_ident!("{prefix}_stack_col_{j}_offset");
99
100 let offset = if j == 0 {
101 quote! { 0 }
102 } else {
103 let prev_offset_ident = format_ident!("{prefix}_stack_col_{}_offset", j - 1);
104 let prev_dim_ident = format_ident!("{prefix}_stack_col_{}_dim", j - 1);
105 quote! { #prev_offset_ident + <_ as nalgebra::Dim>::value(&#prev_dim_ident) }
106 };
107
108 output.extend(std::iter::once(quote! {
109 let #dim_ident = #dim;
110 let #offset_ident = #offset;
111 }));
112 }
113
114 let num_rows = (0..n_block_rows)
117 .map(|i| {
118 let ident = format_ident!("{prefix}_stack_row_{i}_dim");
119 quote! { #ident }
120 })
121 .reduce(|a, b| {
122 quote! {
123 <_ as nalgebra::DimAdd<_>>::add(#a, #b)
124 }
125 })
126 .unwrap_or(quote! { nalgebra::dimension::U0 });
127
128 let num_cols = (0..n_block_cols)
129 .map(|j| {
130 let ident = format_ident!("{prefix}_stack_col_{j}_dim");
131 quote! { #ident }
132 })
133 .reduce(|a, b| {
134 quote! {
135 <_ as nalgebra::DimAdd<_>>::add(#a, #b)
136 }
137 })
138 .unwrap_or(quote! { nalgebra::dimension::U0 });
139
140 output.extend(std::iter::once(quote! {
144 let mut matrix = nalgebra::Matrix::zeros_generic(#num_rows, #num_cols);
145 }));
146
147 for i in 0..n_block_rows {
148 for j in 0..n_block_cols {
149 let row_dim = format_ident!("{prefix}_stack_row_{i}_dim");
150 let col_dim = format_ident!("{prefix}_stack_col_{j}_dim");
151 let row_offset = format_ident!("{prefix}_stack_row_{i}_offset");
152 let col_offset = format_ident!("{prefix}_stack_col_{j}_offset");
153 let expr = &matrix[(i, j)];
154 if !is_literal_zero(expr) {
155 let expr_ident = format_ident!("{prefix}_stack_{i}_{j}_block");
156 output.extend(std::iter::once(quote! {
157 let start = (#row_offset, #col_offset);
158 let shape = (#row_dim, #col_dim);
159 let input_view = #expr_ident.generic_view((0, 0), shape);
160 let mut output_view = matrix.generic_view_mut(start, shape);
161 output_view.copy_from(&input_view);
162 }));
163 }
164 }
165 }
166
167 Ok(quote! {
168 {
169 #output
170 matrix
171 }
172 })
173}
174
175fn is_literal_zero(expr: &Expr) -> bool {
176 matches!(expr,
177 Expr::Lit(syn::ExprLit { lit: Lit::Int(integer_literal), .. })
178 if integer_literal.base10_digits() == "0")
179}
180
#[cfg(test)]
mod tests {
    use crate::stack_impl::stack_impl;
    use crate::Matrix;
    use quote::quote;

    /// Expands `input` with `stack_impl` and checks that the rendered tokens equal the
    /// rendered `expected` tokens. Panics if `stack_impl` returns an error.
    #[track_caller]
    fn assert_expansion(input: Matrix, expected: proc_macro2::TokenStream) {
        let result = stack_impl(input).unwrap();
        assert_eq!(result.to_string(), expected.to_string());
    }

    /// Two blocks on the diagonal: every row/column has a single non-zero block, so no
    /// `representative` calls are expected in the output.
    #[test]
    fn stack_simple_generation() {
        let input: Matrix = syn::parse_quote![
            a, 0;
            0, b;
        ];

        let expected = quote! {{
            let ref ___na_stack_0_0_block = a;
            let ___na_stack_0_0_shape = ___na_stack_0_0_block.shape_generic();
            let ref ___na_stack_1_1_block = b;
            let ___na_stack_1_1_shape = ___na_stack_1_1_block.shape_generic();

            let ___na_stack_row_0_dim = ___na_stack_0_0_shape.0;
            let ___na_stack_row_0_offset = 0;
            let ___na_stack_row_1_dim = ___na_stack_1_1_shape.0;
            let ___na_stack_row_1_offset = ___na_stack_row_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_0_dim);

            let ___na_stack_col_0_dim = ___na_stack_0_0_shape.1;
            let ___na_stack_col_0_offset = 0;
            let ___na_stack_col_1_dim = ___na_stack_1_1_shape.1;
            let ___na_stack_col_1_offset = ___na_stack_col_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_0_dim);

            let mut matrix = nalgebra::Matrix::zeros_generic(
                <_ as nalgebra::DimAdd<_>>::add(___na_stack_row_0_dim, ___na_stack_row_1_dim),
                <_ as nalgebra::DimAdd<_>>::add(___na_stack_col_0_dim, ___na_stack_col_1_dim)
            );

            let start = (___na_stack_row_0_offset, ___na_stack_col_0_offset);
            let shape = (___na_stack_row_0_dim, ___na_stack_col_0_dim);
            let input_view = ___na_stack_0_0_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            let start = (___na_stack_row_1_offset, ___na_stack_col_1_offset);
            let shape = (___na_stack_row_1_dim, ___na_stack_col_1_dim);
            let input_view = ___na_stack_1_1_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            matrix
        }};

        assert_expansion(input, expected);
    }

    /// A 3x3 grid mixing explicit and implicit zero blocks: rows/columns holding two
    /// non-zero blocks are expected to go through `representative(..).expect(..)`.
    #[test]
    fn stack_complex_generation() {
        let input: Matrix = syn::parse_quote![
            a, 0, b;
            0, c, d;
            e, 0, 0;
        ];

        let expected = quote! {{
            let ref ___na_stack_0_0_block = a;
            let ___na_stack_0_0_shape = ___na_stack_0_0_block.shape_generic();
            let ref ___na_stack_0_2_block = b;
            let ___na_stack_0_2_shape = ___na_stack_0_2_block.shape_generic();
            let ref ___na_stack_1_1_block = c;
            let ___na_stack_1_1_shape = ___na_stack_1_1_block.shape_generic();
            let ref ___na_stack_1_2_block = d;
            let ___na_stack_1_2_shape = ___na_stack_1_2_block.shape_generic();
            let ref ___na_stack_2_0_block = e;
            let ___na_stack_2_0_shape = ___na_stack_2_0_block.shape_generic();

            let ___na_stack_row_0_dim = <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfRows<_, _>>::representative(___na_stack_0_0_shape.0, ___na_stack_0_2_shape.0)
                .expect("All blocks in block row 0 must have the same number of rows");
            let ___na_stack_row_0_offset = 0;
            let ___na_stack_row_1_dim = <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfRows<_, _>>::representative(___na_stack_1_1_shape.0, ___na_stack_1_2_shape.0)
                .expect("All blocks in block row 1 must have the same number of rows");
            let ___na_stack_row_1_offset = ___na_stack_row_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_0_dim);
            let ___na_stack_row_2_dim = ___na_stack_2_0_shape.0;
            let ___na_stack_row_2_offset = ___na_stack_row_1_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_1_dim);

            let ___na_stack_col_0_dim = <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfColumns<_, _>>::representative(___na_stack_0_0_shape.1, ___na_stack_2_0_shape.1)
                .expect("All blocks in block column 0 must have the same number of columns");
            let ___na_stack_col_0_offset = 0;
            let ___na_stack_col_1_dim = ___na_stack_1_1_shape.1;
            let ___na_stack_col_1_offset = ___na_stack_col_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_0_dim);
            let ___na_stack_col_2_dim = <nalgebra::constraint::ShapeConstraint as nalgebra::constraint::SameNumberOfColumns<_, _>>::representative(___na_stack_0_2_shape.1, ___na_stack_1_2_shape.1)
                .expect("All blocks in block column 2 must have the same number of columns");
            let ___na_stack_col_2_offset = ___na_stack_col_1_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_1_dim);

            let mut matrix = nalgebra::Matrix::zeros_generic(
                <_ as nalgebra::DimAdd<_>>::add(
                    <_ as nalgebra::DimAdd<_>>::add(___na_stack_row_0_dim, ___na_stack_row_1_dim),
                    ___na_stack_row_2_dim
                ),
                <_ as nalgebra::DimAdd<_>>::add(
                    <_ as nalgebra::DimAdd<_>>::add(___na_stack_col_0_dim, ___na_stack_col_1_dim),
                    ___na_stack_col_2_dim
                )
            );

            let start = (___na_stack_row_0_offset, ___na_stack_col_0_offset);
            let shape = (___na_stack_row_0_dim, ___na_stack_col_0_dim);
            let input_view = ___na_stack_0_0_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            let start = (___na_stack_row_0_offset, ___na_stack_col_2_offset);
            let shape = (___na_stack_row_0_dim, ___na_stack_col_2_dim);
            let input_view = ___na_stack_0_2_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            let start = (___na_stack_row_1_offset, ___na_stack_col_1_offset);
            let shape = (___na_stack_row_1_dim, ___na_stack_col_1_dim);
            let input_view = ___na_stack_1_1_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            let start = (___na_stack_row_1_offset, ___na_stack_col_2_offset);
            let shape = (___na_stack_row_1_dim, ___na_stack_col_2_dim);
            let input_view = ___na_stack_1_2_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            let start = (___na_stack_row_2_offset, ___na_stack_col_0_offset);
            let shape = (___na_stack_row_2_dim, ___na_stack_col_0_dim);
            let input_view = ___na_stack_2_0_block.generic_view((0, 0), shape);
            let mut output_view = matrix.generic_view_mut(start, shape);
            output_view.copy_from(&input_view);

            matrix
        }};

        assert_expansion(input, expected);
    }
}