nalgebra_macros/
matrix_vector_impl.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens, TokenStreamExt};
3use std::ops::Index;
4use syn::parse::{Error, Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::Expr;
8use syn::{parse_macro_input, Token};
9
10use proc_macro2::{Delimiter, Spacing, TokenStream as TokenStream2, TokenTree};
11use proc_macro2::{Group, Punct};
12
13/// A matrix of expressions
14pub struct Matrix {
15    // Represent the matrix data in row-major format
16    data: Vec<Expr>,
17    nrows: usize,
18    ncols: usize,
19}
20
21impl Index<(usize, usize)> for Matrix {
22    type Output = Expr;
23
24    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
25        let linear_idx = self.ncols * row + col;
26        &self.data[linear_idx]
27    }
28}
29
30impl Matrix {
31    pub fn nrows(&self) -> usize {
32        self.nrows
33    }
34
35    pub fn ncols(&self) -> usize {
36        self.ncols
37    }
38
39    /// Produces a stream of tokens representing this matrix as a column-major nested array.
40    pub fn to_col_major_nested_array_tokens(&self) -> TokenStream2 {
41        let mut result = TokenStream2::new();
42        for j in 0..self.ncols() {
43            let mut col = TokenStream2::new();
44            let col_iter = (0..self.nrows()).map(|i| &self[(i, j)]);
45            col.append_separated(col_iter, Punct::new(',', Spacing::Alone));
46            result.append(Group::new(Delimiter::Bracket, col));
47            result.append(Punct::new(',', Spacing::Alone));
48        }
49        TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result)))
50    }
51
52    /// Produces a stream of tokens representing this matrix as a column-major flat array
53    /// (suitable for representing e.g. a `DMatrix`).
54    pub fn to_col_major_flat_array_tokens(&self) -> TokenStream2 {
55        let mut data = TokenStream2::new();
56        for j in 0..self.ncols() {
57            for i in 0..self.nrows() {
58                self[(i, j)].to_tokens(&mut data);
59                data.append(Punct::new(',', Spacing::Alone));
60            }
61        }
62        TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data)))
63    }
64}
65
66type MatrixRowSyntax = Punctuated<Expr, Token![,]>;
67
68impl Parse for Matrix {
69    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
70        let mut data = Vec::new();
71        let mut ncols = None;
72        let mut nrows = 0;
73
74        while !input.is_empty() {
75            let row = MatrixRowSyntax::parse_separated_nonempty(input)?;
76            let row_span = row.span();
77
78            if let Some(ncols) = ncols {
79                if row.len() != ncols {
80                    let error_msg = format!(
81                        "Unexpected number of entries in row {}. Expected {}, found {} entries.",
82                        nrows,
83                        ncols,
84                        row.len()
85                    );
86                    return Err(Error::new(row_span, error_msg));
87                }
88            } else {
89                ncols = Some(row.len());
90            }
91            data.extend(row.into_iter());
92            nrows += 1;
93
94            // We've just read a row, so if there are more tokens, there must be a semi-colon,
95            // otherwise the input is malformed
96            if !input.is_empty() {
97                input.parse::<Token![;]>()?;
98            }
99        }
100
101        Ok(Self {
102            data,
103            nrows,
104            ncols: ncols.unwrap_or(0),
105        })
106    }
107}
108
109pub struct Vector {
110    elements: Vec<Expr>,
111}
112
113impl Vector {
114    pub fn to_array_tokens(&self) -> TokenStream2 {
115        let mut data = TokenStream2::new();
116        data.append_separated(&self.elements, Punct::new(',', Spacing::Alone));
117        TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data)))
118    }
119
120    pub fn len(&self) -> usize {
121        self.elements.len()
122    }
123}
124
125impl Parse for Vector {
126    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
127        // The syntax of a vector is just the syntax of a single matrix row
128        if input.is_empty() {
129            Ok(Self {
130                elements: Vec::new(),
131            })
132        } else {
133            let elements = MatrixRowSyntax::parse_terminated(input)?
134                .into_iter()
135                .collect();
136            Ok(Self { elements })
137        }
138    }
139}
140
141pub fn matrix_impl(stream: TokenStream) -> TokenStream {
142    let matrix = parse_macro_input!(stream as Matrix);
143
144    let row_dim = matrix.nrows();
145    let col_dim = matrix.ncols();
146
147    let array_tokens = matrix.to_col_major_nested_array_tokens();
148
149    //  TODO: Use quote_spanned instead??
150    let output = quote! {
151        nalgebra::SMatrix::<_, #row_dim, #col_dim>
152            ::from_array_storage(nalgebra::ArrayStorage(#array_tokens))
153    };
154
155    proc_macro::TokenStream::from(output)
156}
157
158pub fn dmatrix_impl(stream: TokenStream) -> TokenStream {
159    let matrix = parse_macro_input!(stream as Matrix);
160
161    let row_dim = matrix.nrows();
162    let col_dim = matrix.ncols();
163
164    let array_tokens = matrix.to_col_major_flat_array_tokens();
165
166    //  TODO: Use quote_spanned instead??
167    let output = quote! {
168        nalgebra::DMatrix::<_>
169            ::from_vec_storage(nalgebra::VecStorage::new(
170                nalgebra::Dyn(#row_dim),
171                nalgebra::Dyn(#col_dim),
172                vec!#array_tokens))
173    };
174
175    proc_macro::TokenStream::from(output)
176}
177
178pub fn vector_impl(stream: TokenStream) -> TokenStream {
179    let vector = parse_macro_input!(stream as Vector);
180    let len = vector.len();
181    let array_tokens = vector.to_array_tokens();
182    let output = quote! {
183        nalgebra::SVector::<_, #len>
184            ::from_array_storage(nalgebra::ArrayStorage([#array_tokens]))
185    };
186    proc_macro::TokenStream::from(output)
187}
188
189pub fn dvector_impl(stream: TokenStream) -> TokenStream {
190    let vector = parse_macro_input!(stream as Vector);
191    let len = vector.len();
192    let array_tokens = vector.to_array_tokens();
193    let output = quote! {
194        nalgebra::DVector::<_>
195            ::from_vec_storage(nalgebra::VecStorage::new(
196                nalgebra::Dyn(#len),
197                nalgebra::Const::<1>,
198                vec!#array_tokens))
199    };
200    proc_macro::TokenStream::from(output)
201}