bouncycastle_mlkem/
matrix.rs1use core::ops::{Index, IndexMut};
5
6use crate::mlkem::{q, N};
7use crate::polynomial;
8use crate::polynomial::{Polynomial};
9
10#[derive(Clone)]
11pub struct Matrix<const k: usize, const l: usize>{ mat: [[Polynomial; l]; k] }
13
14impl<const k: usize, const l: usize> Index<usize> for Matrix<k,l> {
16 type Output = [Polynomial; l];
17
18 fn index(&self, index: usize) -> &Self::Output {
19 &self.mat[index]
20 }
21}
22impl<const k: usize, const l: usize> IndexMut<usize> for Matrix<k,l> {
24 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
25 &mut self.mat[index]
26 }
27}
28
29impl<const k: usize, const l: usize> Matrix<k, l> {
30 pub(crate) fn new() -> Self {
31 Self{ mat: [[(); l]; k].map(|_| [(); l].map(|_| Polynomial::new())) }
32 }
33
34 pub(crate) fn matrix_vector_ntt<const transpose: bool>(&self, v: &Vector<l>) -> Vector<k> {
43 let mut w = Vector::<k>::new();
44 for i in 0 .. k {
45 w[i] = if transpose{
47 polynomial::base_mult_montgomery(&self.mat[0][i], &v[0])
48 } else {
49 polynomial::base_mult_montgomery(&self.mat[i][0], &v[0])
50 };
51
52 let mut w1: Polynomial;
53 for j in 1 .. l {
54 w1 = if transpose {
58 polynomial::base_mult_montgomery(&self.mat[j][i], &v[j])
59 } else {
60 polynomial::base_mult_montgomery(&self.mat[i][j], &v[j])
61 };
62
63 w[i].add(&w1);
64 }
65 }
66
67 if transpose {
69 w.reduce();
70 } else {
71 w.convert_to_mont();
72 }
73
74 w
75 }
76}
77
78#[derive(Clone)]
84pub(crate) struct Vector<const k: usize>{ pub(crate) vec: [Polynomial; k] }
85
86impl<const k: usize> Index<usize> for Vector<k> {
88 type Output = Polynomial;
89
90 fn index(&self, index: usize) -> &Self::Output {
91 &self.vec[index]
92 }
93}
94impl<const k: usize> IndexMut<usize> for Vector<k> {
96 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
97 &mut self.vec[index]
98 }
99}
100
101impl<const k: usize> Vector<k>
102{
103 pub(crate) fn new() -> Self {
104 Self {vec: [(); k].map(|_| Polynomial::new()) }
105 }
106
107 pub(crate) fn add_vector_ntt(&mut self, s: &Self) {
113 for i in 0 ..k {
114 self[i].add(&s[i]);
116 }
117 }
118
119 pub(crate) fn dot_product(&self, v: &Self) -> Polynomial {
120 let mut w = polynomial::base_mult_montgomery(&self[0], &v[0]);
122
123 for i in 1..k {
124 let w1 = polynomial::base_mult_montgomery(&self[i], &v[i]);
125 w.add(&w1);
126 }
127 w
133 }
134
135 pub(crate) fn reduce(&mut self) {
136 for i in 0..k {
137 self[i].poly_reduce();
138 }
139 }
140
141 pub(crate) fn ntt(&mut self){
142 for i in 0..k {
143 self[i].ntt();
144 }
145 }
146
147 pub(crate) fn inv_ntt(&mut self) {
148 for i in 0..k {
149 self[i].inv_ntt();
150 }
151 }
152
153 pub(crate) fn convert_to_mont(&mut self) {
154 for i in 0 ..k {
155 self[i].convert_to_mont();
156 }
157 }
158
159 pub(crate) fn compress_pol_vec<const du: i16>(&self, out: &mut [u8]) {
163 assert!(du == 10 || du == 11);
165
166 debug_assert_eq!(out.len(), k *(N * (du as usize) / 8));
169
170 let mut idx = 0;
175 match du {
176 10 => { let mut t = [0i16; 4];
178 for i in 0..k {
179 for j in 0..N/4 {
180 for (l, item) in t.iter_mut().enumerate() {
182 *item = (((((self[i][4 * j + l] as u32) << 10) as i32
183 + (q as i32 / 2))
184 / q as i32)
185 & 0x3FF) as i16;
186 }
187
188 out[idx] = t[0] as u8;
189 out[idx + 1] = ((t[0] >> 8) | (t[1] << 2)) as u8;
190 out[idx + 2] = ((t[1] >> 6) | (t[2] << 4)) as u8;
191 out[idx + 3] = ((t[2] >> 4) | (t[3] << 6)) as u8;
192 out[idx + 4] = (t[3] >> 2) as u8;
193 idx += 5;
194 }
195 }
196 },
197 11 => {
198 let mut t = [0i16; 8];
199 for i in 0..k {
200 for j in 0..N/8 {
201 for (l, item) in t.iter_mut().enumerate() {
202 *item = (((((self[i][8 * j + l] as u32) << 11) as i32
203 + (q as i32 / 2))
204 / q as i32)
205 & 0x7FF) as i16;
206 }
207
208 out[idx] = t[0] as u8;
209 out[idx + 1] = ((t[0] >> 8) | (t[1] << 3)) as u8;
210 out[idx + 2] = ((t[1] >> 5) | (t[2] << 6)) as u8;
211 out[idx + 3] = (t[2] >> 2) as u8;
212 out[idx + 4] = ((t[2] >> 10) | (t[3] << 1)) as u8;
213 out[idx + 5] = ((t[3] >> 7) | (t[4] << 4)) as u8;
214 out[idx + 6] = ((t[4] >> 4) | (t[5] << 7)) as u8;
215 out[idx + 7] = (t[5] >> 1) as u8;
216 out[idx + 8] = ((t[5] >> 9) | (t[6] << 2)) as u8;
217 out[idx + 9] = ((t[6] >> 6) | (t[7] << 5)) as u8;
218 out[idx + 10] = (t[7] >> 3) as u8;
219 idx += 11;
220 }
221 }
222 },
223 _ => unreachable!(),
224 }
225 }
226
227 pub(crate) fn decompress_pol_vec<const du: i16>(compressed_u: &[u8]) -> Vector<k> {
228 let mut u = Vector::<k>::new();
229
230 assert!(du == 10 || du == 11);
232
233 debug_assert_eq!(compressed_u.len(), k *(N * (du as usize) / 8));
236
237 let mut idx = 0;
238
239 match du {
240 10 => { let mut t = [0i16; 4];
242 for i in 0..k {
243 for j in 0..(N/4) {
244 t[0] = ((compressed_u[idx] as u16)
245 | (compressed_u[idx + 1] as u16) << 8)
246 as i16;
247 t[1] = (((compressed_u[idx + 1] as u16) >> 2)
248 | (compressed_u[idx + 2] as u16) << 6)
249 as i16;
250 t[2] = (((compressed_u[idx + 2] as u16) >> 4)
251 | (compressed_u[idx + 3] as u16) << 4)
252 as i16;
253 t[3] = (((compressed_u[idx + 3] as u16) >> 6)
254 | (compressed_u[idx + 4] as u16) << 2)
255 as i16;
256 idx += 5;
257 for (l, item) in t.iter().enumerate() {
258 u[i][4 * j + l] =
259 ((((*item & 0x3FF) as i32) * (q as i32) + 512) >> 10) as i16;
260 }
261 }
262 }
263 },
264 11 => { let mut t = [0i16; 8];
266 for i in 0..k {
267 for j in 0..N/8 {
268 t[0] = (compressed_u[idx] as i32
269 | ((compressed_u[idx + 1] as u16) as i32) << 8)
270 as i16;
271 t[1] = ((compressed_u[idx + 1] >> 3) as i32
272 | ((compressed_u[idx + 2] as u16) as i32) << 5)
273 as i16;
274 t[2] = ((compressed_u[idx + 2] >> 6) as i32
275 | ((compressed_u[idx + 3] as u16) as i32) << 2
276 | (((compressed_u[idx + 4] as i32) << 10) as u16) as i32)
277 as i16;
278 t[3] = ((compressed_u[idx + 4] >> 1) as i32
279 | ((compressed_u[idx + 5] as u16) as i32) << 7)
280 as i16;
281 t[4] = ((compressed_u[idx + 5] >> 4) as i32
282 | ((compressed_u[idx + 6] as u16) as i32) << 4)
283 as i16;
284 t[5] = ((compressed_u[idx + 6] >> 7) as i32
285 | ((compressed_u[idx + 7] as u16) as i32) << 1
286 | (((compressed_u[idx + 8] as i32) << 9) as u16) as i32)
287 as i16;
288 t[6] = ((compressed_u[idx + 8] >> 2) as i32
289 | ((compressed_u[idx + 9] as u16) as i32) << 6)
290 as i16;
291 t[7] = ((compressed_u[idx + 9] >> 5) as i32
292 | ((compressed_u[idx + 10] as u16) as i32) << 3)
293 as i16;
294 idx += 11;
295 for (l, item) in t.iter().enumerate() {
296 u[i][8 * j + l] =
297 ((((*item & 0x7FF) as i32) * (q as i32) + 1024) >> 11) as i16;
298 }
299 }
300 }
301 },
302 _ => unreachable!(),
303 }
304
305 u
306 }
307}