nalgebra_macros/
matrix_vector_impl.rs1use proc_macro::TokenStream;
2use quote::{quote, ToTokens, TokenStreamExt};
3use std::ops::Index;
4use syn::parse::{Error, Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::Expr;
8use syn::{parse_macro_input, Token};
9
10use proc_macro2::{Delimiter, Spacing, TokenStream as TokenStream2, TokenTree};
11use proc_macro2::{Group, Punct};
12
13pub struct Matrix {
15 data: Vec<Expr>,
17 nrows: usize,
18 ncols: usize,
19}
20
21impl Index<(usize, usize)> for Matrix {
22 type Output = Expr;
23
24 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
25 let linear_idx = self.ncols * row + col;
26 &self.data[linear_idx]
27 }
28}
29
30impl Matrix {
31 pub fn nrows(&self) -> usize {
32 self.nrows
33 }
34
35 pub fn ncols(&self) -> usize {
36 self.ncols
37 }
38
39 pub fn to_col_major_nested_array_tokens(&self) -> TokenStream2 {
41 let mut result = TokenStream2::new();
42 for j in 0..self.ncols() {
43 let mut col = TokenStream2::new();
44 let col_iter = (0..self.nrows()).map(|i| &self[(i, j)]);
45 col.append_separated(col_iter, Punct::new(',', Spacing::Alone));
46 result.append(Group::new(Delimiter::Bracket, col));
47 result.append(Punct::new(',', Spacing::Alone));
48 }
49 TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result)))
50 }
51
52 pub fn to_col_major_flat_array_tokens(&self) -> TokenStream2 {
55 let mut data = TokenStream2::new();
56 for j in 0..self.ncols() {
57 for i in 0..self.nrows() {
58 self[(i, j)].to_tokens(&mut data);
59 data.append(Punct::new(',', Spacing::Alone));
60 }
61 }
62 TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data)))
63 }
64}
65
66type MatrixRowSyntax = Punctuated<Expr, Token![,]>;
67
68impl Parse for Matrix {
69 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
70 let mut data = Vec::new();
71 let mut ncols = None;
72 let mut nrows = 0;
73
74 while !input.is_empty() {
75 let row = MatrixRowSyntax::parse_separated_nonempty(input)?;
76 let row_span = row.span();
77
78 if let Some(ncols) = ncols {
79 if row.len() != ncols {
80 let error_msg = format!(
81 "Unexpected number of entries in row {}. Expected {}, found {} entries.",
82 nrows,
83 ncols,
84 row.len()
85 );
86 return Err(Error::new(row_span, error_msg));
87 }
88 } else {
89 ncols = Some(row.len());
90 }
91 data.extend(row.into_iter());
92 nrows += 1;
93
94 if !input.is_empty() {
97 input.parse::<Token![;]>()?;
98 }
99 }
100
101 Ok(Self {
102 data,
103 nrows,
104 ncols: ncols.unwrap_or(0),
105 })
106 }
107}
108
109pub struct Vector {
110 elements: Vec<Expr>,
111}
112
113impl Vector {
114 pub fn to_array_tokens(&self) -> TokenStream2 {
115 let mut data = TokenStream2::new();
116 data.append_separated(&self.elements, Punct::new(',', Spacing::Alone));
117 TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data)))
118 }
119
120 pub fn len(&self) -> usize {
121 self.elements.len()
122 }
123}
124
125impl Parse for Vector {
126 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
127 if input.is_empty() {
129 Ok(Self {
130 elements: Vec::new(),
131 })
132 } else {
133 let elements = MatrixRowSyntax::parse_terminated(input)?
134 .into_iter()
135 .collect();
136 Ok(Self { elements })
137 }
138 }
139}
140
141pub fn matrix_impl(stream: TokenStream) -> TokenStream {
142 let matrix = parse_macro_input!(stream as Matrix);
143
144 let row_dim = matrix.nrows();
145 let col_dim = matrix.ncols();
146
147 let array_tokens = matrix.to_col_major_nested_array_tokens();
148
149 let output = quote! {
151 nalgebra::SMatrix::<_, #row_dim, #col_dim>
152 ::from_array_storage(nalgebra::ArrayStorage(#array_tokens))
153 };
154
155 proc_macro::TokenStream::from(output)
156}
157
158pub fn dmatrix_impl(stream: TokenStream) -> TokenStream {
159 let matrix = parse_macro_input!(stream as Matrix);
160
161 let row_dim = matrix.nrows();
162 let col_dim = matrix.ncols();
163
164 let array_tokens = matrix.to_col_major_flat_array_tokens();
165
166 let output = quote! {
168 nalgebra::DMatrix::<_>
169 ::from_vec_storage(nalgebra::VecStorage::new(
170 nalgebra::Dyn(#row_dim),
171 nalgebra::Dyn(#col_dim),
172 vec!#array_tokens))
173 };
174
175 proc_macro::TokenStream::from(output)
176}
177
178pub fn vector_impl(stream: TokenStream) -> TokenStream {
179 let vector = parse_macro_input!(stream as Vector);
180 let len = vector.len();
181 let array_tokens = vector.to_array_tokens();
182 let output = quote! {
183 nalgebra::SVector::<_, #len>
184 ::from_array_storage(nalgebra::ArrayStorage([#array_tokens]))
185 };
186 proc_macro::TokenStream::from(output)
187}
188
189pub fn dvector_impl(stream: TokenStream) -> TokenStream {
190 let vector = parse_macro_input!(stream as Vector);
191 let len = vector.len();
192 let array_tokens = vector.to_array_tokens();
193 let output = quote! {
194 nalgebra::DVector::<_>
195 ::from_vec_storage(nalgebra::VecStorage::new(
196 nalgebra::Dyn(#len),
197 nalgebra::Const::<1>,
198 vec!#array_tokens))
199 };
200 proc_macro::TokenStream::from(output)
201}