Skip to main content

bouncycastle_mlkem/
matrix.rs

//! Thin wrappers around plain arrays of polynomials. They are not strictly necessary, but they
//! keep the matrix and vector types, and their sizes, explicit.
3
4use core::ops::{Index, IndexMut};
5
6use crate::mlkem::{q, N};
7use crate::polynomial;
8use crate::polynomial::{Polynomial};
9
10#[derive(Clone)]
11/// A matrix over the ML-KEM ring.
12pub struct Matrix<const k: usize, const l: usize>{ /*pub(crate)*/ mat: [[Polynomial; l]; k] }
13
14/// Convenience function to avoid ".0" all over the place.
15impl<const k: usize, const l: usize> Index<usize> for Matrix<k,l> {
16    type Output = [Polynomial; l];
17
18    fn index(&self, index: usize) -> &Self::Output {
19        &self.mat[index]
20    }
21}
22/// Convenience function to avoid ".0" all over the place.
23impl<const k: usize, const l: usize> IndexMut<usize> for Matrix<k,l> {
24    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
25        &mut self.mat[index]
26    }
27}
28
29impl<const k: usize, const l: usize> Matrix<k, l> {
30    pub(crate) fn new() -> Self {
31        Self{ mat: [[(); l]; k].map(|_| [(); l].map(|_| Polynomial::new())) }
32    }
33
34    /// FIPS 204 Algorithm 48 MatrixVectorNTT(𝐌, 𝐯)
35    /// Computes the product 𝐌 ∘̂ 𝐯_hat of a matrix 𝐌_hat and a vector 𝐯_hat over 𝑇𝑞.
36    /// Input: 𝑘, ℓ ∈ ℕ, 𝐌 ∈ 𝑇𝑞 𝑘×ℓ
37    /// Performs dot product multiplication of this matrix by a vector
38    /// Input: vector of length l
39    /// Output: vector of length k
40    ///
41    /// transpose: False will multiply A, where as True will multiply A^T
42    pub(crate) fn matrix_vector_ntt<const transpose: bool>(&self, v: &Vector<l>) -> Vector<k> {
43        let mut w = Vector::<k>::new();
44        for i in 0 .. k {
45            // split out the 0 case to skip a no-op add_ntt()
46            w[i] = if transpose{
47                polynomial::base_mult_montgomery(&self.mat[0][i], &v[0])
48            } else {
49                polynomial::base_mult_montgomery(&self.mat[i][0], &v[0])
50            };
51
52            let mut w1: Polynomial;
53            for j in 1 .. l {
54                // dot product a vector into a matrix: multiply the input vector
55                // into each row of the matrix, then sum the results to produce a vector of
56                // length k.
57                w1 = if transpose {
58                    polynomial::base_mult_montgomery(&self.mat[j][i], &v[j])
59                } else {
60                    polynomial::base_mult_montgomery(&self.mat[i][j], &v[j])
61                };
62
63                w[i].add(&w1);
64            }
65        }
66
67        // In the non-transposed case (keygen), we act in montgomery domain; otherwise (encaps / decaps) we reduce normally.
68        if transpose {
69            w.reduce();
70        } else {
71            w.convert_to_mont();
72        }
73
74        w
75    }
76}
77
78// Matrix and Vector do not need to impl Secret because the actual data is in the polynomials, which have their own zeroizing drop.
79// Technically all matrices and some vectors are only part of the public key and might not need to be zeroized,
80// but I'll leave it zeroizing for now and leave this as a potential future optimization.
81
82
83#[derive(Clone)]
84pub(crate) struct Vector<const k: usize>{ pub(crate) vec: [Polynomial; k] }
85
86/// Convenience function to avoid ".0" all over the place.
87impl<const k: usize> Index<usize> for Vector<k> {
88    type Output = Polynomial;
89
90    fn index(&self, index: usize) -> &Self::Output {
91        &self.vec[index]
92    }
93}
94/// Convenience function to avoid ".0" all over the place.
95impl<const k: usize> IndexMut<usize> for Vector<k> {
96    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
97        &mut self.vec[index]
98    }
99}
100
101impl<const k: usize> Vector<k>
102{
103    pub(crate) fn new() -> Self {
104        Self {vec: [(); k].map(|_| Polynomial::new()) }
105    }
106
107    /// Algorithm 46 AddVectorNTT(𝐯, 𝐰)̂
108    /// Computes the sum 𝐯_hat + 𝐰_hat of two vectors 𝐯_hat, 𝐰_hat over 𝑇𝑞.
109    /// Input: ℓ ∈ ℕ, v_hat ∈ T^ℓ, w_hat ∈ 𝑇^ℓ
110    /// Output: u_hat ∈ T^ℓ_𝑞.
111    /// Add another vector to this vector
112    pub(crate) fn add_vector_ntt(&mut self, s: &Self) {
113        for i in 0 ..k {
114            // perform montgomery addition of each polynomial in the vector
115            self[i].add(&s[i]);
116        }
117    }
118
119    pub(crate) fn dot_product(&self, v: &Self) -> Polynomial {
120        // split out the 0 case to skip a no-op add_ntt()
121        let mut w = polynomial::base_mult_montgomery(&self[0], &v[0]);
122
123        for i in 1..k {
124            let w1 = polynomial::base_mult_montgomery(&self[i], &v[i]);
125            w.add(&w1);
126        }
127        // in theory, we need this here, but all unit tests pass without it since
128        // it actually doesn't matter if you go outside the [0, q] range as long as you
129        // reduce down before encoding out.
130        // w.poly_reduce();
131
132        w
133    }
134
135    pub(crate) fn reduce(&mut self) {
136        for i in 0..k {
137            self[i].poly_reduce();
138        }
139    }
140
141    pub(crate) fn ntt(&mut self){
142        for i in 0..k {
143            self[i].ntt();
144        }
145    }
146
147    pub(crate) fn inv_ntt(&mut self) {
148        for i in 0..k {
149            self[i].inv_ntt();
150        }
151    }
152
153    pub(crate) fn convert_to_mont(&mut self) {
154        for i in 0 ..k {
155            self[i].convert_to_mont();
156        }
157    }
158
159    /// This is an optimized version of
160    ///   ByteEncode_𝑑𝑢( Compress_𝑑𝑢(𝐮) )
161    /// which packs a polynomial vector according to the packing coefficient dv
162    pub(crate) fn compress_pol_vec<const du: i16>(&self, out: &mut [u8]) {
163        // make sure we have received a dv
164        assert!(du == 10 || du == 11);
165
166        // make sure we were given the right size output buffer
167        // each of the N i16's will take dv bits
168        debug_assert_eq!(out.len(), k *(N * (du as usize) / 8));
169
170        // bc-java has a conditional_sub_q() here, but I pass all unit tests without it, so I'm taking it out for performance.
171        // let mut s = self.clone();
172        // s.conditional_sub_q();
173
174        let mut idx = 0;
175        match du {
176            10 => { // MLKEM512 and MLKEM 768
177                let mut t = [0i16; 4];
178                for i in 0..k {
179                    for j in 0..N/4 {
180                        // fill the temp array t
181                        for (l, item) in t.iter_mut().enumerate() {
182                            *item = (((((self[i][4 * j + l] as u32) << 10) as i32
183                                + (q as i32 / 2))
184                                / q as i32)
185                                & 0x3FF) as i16;
186                        }
187
188                        out[idx] = t[0] as u8;
189                        out[idx + 1] = ((t[0] >> 8) | (t[1] << 2)) as u8;
190                        out[idx + 2] = ((t[1] >> 6) | (t[2] << 4)) as u8;
191                        out[idx + 3] = ((t[2] >> 4) | (t[3] << 6)) as u8;
192                        out[idx + 4] = (t[3] >> 2) as u8;
193                        idx += 5;
194                    }
195                }
196            },
197            11 => {
198                let mut t = [0i16; 8];
199                for i in 0..k {
200                    for j in 0..N/8 {
201                        for (l, item) in t.iter_mut().enumerate() {
202                            *item = (((((self[i][8 * j + l] as u32) << 11) as i32
203                                + (q as i32 / 2))
204                                / q as i32)
205                                & 0x7FF) as i16;
206                        }
207
208                        out[idx] = t[0] as u8;
209                        out[idx + 1] = ((t[0] >> 8) | (t[1] << 3)) as u8;
210                        out[idx + 2] = ((t[1] >> 5) | (t[2] << 6)) as u8;
211                        out[idx + 3] = (t[2] >> 2) as u8;
212                        out[idx + 4] = ((t[2] >> 10) | (t[3] << 1)) as u8;
213                        out[idx + 5] = ((t[3] >> 7) | (t[4] << 4)) as u8;
214                        out[idx + 6] = ((t[4] >> 4) | (t[5] << 7)) as u8;
215                        out[idx + 7] = (t[5] >> 1) as u8;
216                        out[idx + 8] = ((t[5] >> 9) | (t[6] << 2)) as u8;
217                        out[idx + 9] = ((t[6] >> 6) | (t[7] << 5)) as u8;
218                        out[idx + 10] = (t[7] >> 3) as u8;
219                        idx += 11;
220                    }
221                }
222            },
223            _ => unreachable!(),
224        }
225    }
226
227    pub(crate) fn decompress_pol_vec<const du: i16>(compressed_u: &[u8]) -> Vector<k> {
228        let mut u = Vector::<k>::new();
229
230        // make sure we have received a dv
231        assert!(du == 10 || du == 11);
232
233        // make sure we were given the right size output buffer
234        // each of the N i16's will take dv bits
235        debug_assert_eq!(compressed_u.len(), k *(N * (du as usize) / 8));
236
237        let mut idx = 0;
238
239        match du {
240            10 => { // MLKEM512 and MLKEM768
241                let mut t = [0i16; 4];
242                for i in 0..k {
243                    for j in 0..(N/4) {
244                        t[0] = ((compressed_u[idx] as u16)
245                            | (compressed_u[idx + 1] as u16) << 8)
246                            as i16;
247                        t[1] = (((compressed_u[idx + 1] as u16) >> 2)
248                            | (compressed_u[idx + 2] as u16) << 6)
249                            as i16;
250                        t[2] = (((compressed_u[idx + 2] as u16) >> 4)
251                            | (compressed_u[idx + 3] as u16) << 4)
252                            as i16;
253                        t[3] = (((compressed_u[idx + 3] as u16) >> 6)
254                            | (compressed_u[idx + 4] as u16) << 2)
255                            as i16;
256                        idx += 5;
257                        for (l, item) in t.iter().enumerate() {
258                            u[i][4 * j + l] =
259                                ((((*item & 0x3FF) as i32) * (q as i32) + 512) >> 10) as i16;
260                        }
261                    }
262                }
263        },
264            11 => { // MLKEM1024
265                let mut t = [0i16; 8];
266                for i in 0..k {
267                    for j in 0..N/8 {
268                        t[0] = (compressed_u[idx] as i32
269                            | ((compressed_u[idx + 1] as u16) as i32) << 8)
270                            as i16;
271                        t[1] = ((compressed_u[idx + 1] >> 3) as i32
272                            | ((compressed_u[idx + 2] as u16) as i32) << 5)
273                            as i16;
274                        t[2] = ((compressed_u[idx + 2] >> 6) as i32
275                            | ((compressed_u[idx + 3] as u16) as i32) << 2
276                            | (((compressed_u[idx + 4] as i32) << 10) as u16) as i32)
277                            as i16;
278                        t[3] = ((compressed_u[idx + 4] >> 1) as i32
279                            | ((compressed_u[idx + 5] as u16) as i32) << 7)
280                            as i16;
281                        t[4] = ((compressed_u[idx + 5] >> 4) as i32
282                            | ((compressed_u[idx + 6] as u16) as i32) << 4)
283                            as i16;
284                        t[5] = ((compressed_u[idx + 6] >> 7) as i32
285                            | ((compressed_u[idx + 7] as u16) as i32) << 1
286                            | (((compressed_u[idx + 8] as i32) << 9) as u16) as i32)
287                            as i16;
288                        t[6] = ((compressed_u[idx + 8] >> 2) as i32
289                            | ((compressed_u[idx + 9] as u16) as i32) << 6)
290                            as i16;
291                        t[7] = ((compressed_u[idx + 9] >> 5) as i32
292                            | ((compressed_u[idx + 10] as u16) as i32) << 3)
293                            as i16;
294                        idx += 11;
295                        for (l, item) in t.iter().enumerate() {
296                            u[i][8 * j + l] =
297                                ((((*item & 0x7FF) as i32) * (q as i32) + 1024) >> 11) as i16;
298                        }
299                    }
300                }
301            },
302            _ => unreachable!(),
303        }
304
305        u
306    }
307}