1use crate::Matrix;
4use crate::matrix::Vector;
5use crate::mlkem::{N, q, q_inv};
6use crate::polynomial::Polynomial;
7use bouncycastle_core::traits::XOF;
8use bouncycastle_sha3::{SHAKE128, SHAKE256};
9
10pub(crate) fn expandA<const k: usize>(rho: &[u8; 32]) -> Matrix<k, k> {
11 let mut A_hat = Matrix::<k, k>::new();
12 for i in 0..k {
13 for j in 0..k {
15 A_hat[i][j] = sample_ntt(rho, &[j as u8, i as u8]);
18 }
19 }
20
21 A_hat
22}
23
24pub fn byte_encode<const d: usize, const PACK_LEN: usize>(F: &Polynomial) -> [u8; PACK_LEN] {
30 debug_assert_eq!(PACK_LEN, 32 * d);
31
32 let mut B = [0u8; PACK_LEN];
33
34 for i in 0..N {
35 let mut alpha = F[i];
36
37 alpha = barrett_reduce(alpha);
40
41 for j in 0..d {
42 let tmp = (alpha & 1) as u8;
44
45 B[(i * d + j) / 8] |= tmp << ((i * d + j) % 8);
49
50 alpha >>= 1;
59 }
60 }
61
62 B
63}
64
65pub fn byte_decode<const d: usize, const PACK_LEN: usize>(B: &[u8; PACK_LEN]) -> Polynomial {
71 debug_assert_eq!(PACK_LEN, 32 * d);
72
73 let mut F = Polynomial::new();
74
75 for i in 0..N {
76 for j in 0..d {
78 F[i] |= (((B[(i * d + j) / 8] >> (i * d + j) % 8) & 1) as i16) << j; }
81 }
82
83 F
84}
85
86pub fn sample_ntt(rho: &[u8; 32], nonce: &[u8; 2]) -> Polynomial {
92 let mut a_hat = Polynomial::new();
93
94 let mut xof = SHAKE128::new();
97 xof.absorb(rho);
98 xof.absorb(nonce);
99
100 let mut j = 0usize;
102
103 let mut C = [0u8; 288];
109 xof.squeeze_out(&mut C);
110 let mut idx: usize = 0;
111
112 while j < N {
114 if idx == C.len() {
117 xof.squeeze_out(&mut C);
118 idx = 0;
119 }
120
121 let d1: i16 = (C[idx + 0] as i16) | ((C[idx + 1] as i32) << 8) as i16 & 0xFFF;
124 debug_assert!(d1 < 2 << 12);
125
126 let d2: i16 = ((C[idx + 1] as i16) >> 4) | ((C[idx + 2] as i32) << 4) as i16 & 0xFFF;
129 debug_assert!(d2 < 2 << 12);
130
131 if d1 < q {
137 a_hat[j] = d1;
138 j += 1;
139 }
140
141 if d2 < q && j < N {
146 a_hat[j] = d2;
147 j += 1;
148 }
149
150 idx += 3;
151 }
152
153 a_hat
154}
155
156pub fn sample_poly_cbd<const eta: i16>(bytes: &[u8]) -> Polynomial {
162 debug_assert_eq!(bytes.len(), 64 * eta as usize);
163
164 let mut f = Polynomial::new();
165
166 match eta {
167 2 => {
168 for i in 0..N / 8 {
169 let t = u32::from_le_bytes(bytes[4 * i..4 * i + 4].try_into().unwrap());
170 let mut d = t & 0x55555555;
171 d += (t >> 1) & 0x55555555;
172 for j in 0..8usize {
173 let a = ((d >> (4 * j)) & 0x3) as i16;
174 let b = ((d >> (4 * j + eta as usize)) & 0x3) as i16;
175 f[8 * i + j] = a - b;
176
177 debug_assert!(-eta <= f[8 * i + j] && f[8 * i + j] <= eta);
180 }
181 }
182 }
183 3 => {
184 for i in 0..N / 4 {
185 let t = little_endian_to_u24(bytes, 3 * i);
186 let mut d = t & 0x00249249;
187 d += (t >> 1) & 0x00249249;
188 d += (t >> 2) & 0x00249249;
189 for j in 0..4usize {
190 let a = ((d >> (6 * j)) & 0x7) as i16;
191 let b = ((d >> (6 * j + eta as usize)) & 0x7) as i16;
192 f[4 * i + j] = a - b;
193
194 debug_assert!(-eta <= f[4 * i + j] && f[4 * i + j] <= eta);
197 }
198 }
199 }
200 _ => unreachable!("Wrong Eta"),
201 }
202
203 f
204}
205
206pub(crate) fn sample_poly_CBD<const eta: i16>(b: &[u8; 32], n: u8) -> Polynomial {
209 match eta {
212 2 => {
213 let buf = {
214 let mut xof = SHAKE256::new();
215 xof.absorb(b);
216 xof.absorb(&n.to_le_bytes());
217
218 let mut buf = [0u8; 2 * 64];
219 xof.squeeze_out(&mut buf);
220 buf
221 };
222
223 sample_poly_cbd::<eta>(&buf)
224 }
225 3 => {
226 let buf = {
227 let mut xof = SHAKE256::new();
228 xof.absorb(b);
229 xof.absorb(&n.to_le_bytes());
230 let mut buf = [0u8; 3 * 64];
231 xof.squeeze_out(&mut buf);
232 buf
233 };
234
235 sample_poly_cbd::<eta>(&buf)
236 }
237 _ => unreachable!(),
238 }
239}
240
241pub(crate) fn sample_vector_CBD<const k: usize, const eta: i16>(
243 b: &[u8; 32],
244 mut n: u8,
245) -> Vector<k> {
246 let mut v = Vector::<k>::new();
247
248 for i in 0..k {
249 v[i] = sample_poly_CBD::<eta>(b, n);
250
251 n += 1;
253 }
254
255 v
256}
257
/// Reads three bytes starting at `off` as a little-endian 24-bit unsigned integer.
///
/// # Panics
///
/// Panics if `bs` has fewer than `off + 3` bytes.
fn little_endian_to_u24(bs: &[u8], off: usize) -> u32 {
    // Zero-extend to four bytes so the standard little-endian decoder can be used.
    u32::from_le_bytes([bs[off], bs[off + 1], bs[off + 2], 0])
}
263
/// Forward-NTT twiddle factors: 128 residues modulo q = 3329.
///
/// The values are consistent with powers of the primitive 256-th root of unity ζ = 17,
/// taken in bit-reversed order and in Montgomery form:
/// `ZETAS[i] = ζ^BitRev7(i) · 2^16 mod q`. For example, `ZETAS[0] = 2285 = 2^16 mod q`
/// and `ZETAS[1] = 1729 · 2^16 mod q = 2571`.
///
/// NOTE(review): this layout is inferred from the table values; confirm it against
/// the NTT routine that consumes this table.
pub(crate) const ZETAS: [i16; 128] = [
    2285, 2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468, 573,
    2004, 264, 383, 2500, 1458, 1727, 3199, 2648, 1017, 732, 608, 1787, 411, 3124, 1758, 1223, 652,
    2777, 1015, 2036, 1491, 3047, 1785, 516, 3321, 3009, 2663, 1711, 2167, 126, 1469, 2476, 3239,
    3058, 830, 107, 1908, 3082, 2378, 2931, 961, 1821, 2604, 448, 2264, 677, 2054, 2226, 430, 555,
    843, 2078, 871, 1550, 105, 422, 587, 177, 3094, 3038, 2869, 1574, 1653, 3083, 778, 1159, 3182,
    2552, 1483, 2727, 1119, 1739, 644, 2457, 349, 418, 329, 3173, 3254, 817, 1097, 603, 610, 1322,
    2044, 1864, 384, 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, 2144, 1799, 2051, 794, 1819,
    2475, 2459, 478, 3221, 3021, 996, 991, 958, 1869, 1522, 1628,
];
274
/// Inverse-NTT twiddle factors: 128 residues modulo q = 3329.
///
/// Presumably the `zetas_inv` table from the round-2/3 Kyber reference implementation
/// (Montgomery-form inverse twiddles). The final entry, 1441, equals `2^25 mod q`,
/// i.e. `(2^16)^2 / 128 mod q`. That is the factor that both divides by 128 and
/// compensates a Montgomery reduction.
///
/// NOTE(review): confirm the ordering and the role of the last entry against the
/// inverse-NTT routine that consumes this table.
pub(crate) const ZETAS_INV: [i16; 128] = [
    1701, 1807, 1460, 2371, 2338, 2333, 308, 108, 2851, 870, 854, 1510, 2535, 1278, 1530, 1185,
    1659, 1187, 3109, 874, 1335, 2111, 136, 1215, 2945, 1465, 1285, 2007, 2719, 2726, 2232, 2512,
    75, 156, 3000, 2911, 2980, 872, 2685, 1590, 2210, 602, 1846, 777, 147, 2170, 2551, 246, 1676,
    1755, 460, 291, 235, 3152, 2742, 2907, 3224, 1779, 2458, 1251, 2486, 2774, 2899, 1103, 1275,
    2652, 1065, 2881, 725, 1508, 2368, 398, 951, 247, 1421, 3222, 2499, 271, 90, 853, 1860, 3203,
    1162, 1618, 666, 320, 8, 2813, 1544, 282, 1838, 1293, 2314, 552, 2677, 2106, 1571, 205, 2918,
    1542, 2721, 2597, 2312, 681, 130, 1602, 1871, 829, 2946, 3065, 1325, 2756, 1861, 1474, 1202,
    2367, 3147, 1752, 2707, 171, 3127, 3042, 1907, 1836, 1517, 359, 758, 1441,
];
285
286pub(crate) fn mul_mont(a: i16, b: i16) -> i16 {
287 montgomery_reduce((a as i32) * (b as i32))
288}
289
290pub(crate) fn montgomery_reduce(a: i32) -> i16 {
291 let u = a.wrapping_mul(q_inv) as i16;
292 let mut t = (u as i32) * q as i32;
293 t = a - t;
294 t >>= 16;
295 t as i16
296}
297
298pub(crate) fn barrett_reduce(a: i16) -> i16 {
299 let v = (((1u32 << 26) + ((q / 2) as u32)) / (q as u32)) as i16;
300 let t = (((v as i32) * (a as i32)) >> 26) as i16;
301 a - (((t as i32) * q as i32) as i16)
302}
303
304pub(super) fn cond_sub_q(a: i16) -> i16 {
305 let tmp = a - q;
306 tmp + ((tmp >> 15) & q)
307}
308
309pub(crate) fn ntt_base_mult(
315 r: &mut [i16],
316 off: usize,
317 a0: i16,
318 a1: i16,
319 b0: i16,
320 b1: i16,
321 zeta: i16,
322) {
323 let mut out_val0 = mul_mont(a1, b1);
324 out_val0 = mul_mont(out_val0, zeta);
325 out_val0 += mul_mont(a0, b0);
326 r[off] = out_val0;
327
328 let mut out_val1 = mul_mont(a0, b1);
329 out_val1 += mul_mont(a1, b0);
330 r[off + 1] = out_val1;
331}
332
333pub(crate) fn pack_ciphertext<const k: usize, const CT_LEN: usize, const du: i16, const dv: i16>(
334 u: &Vector<k>,
335 v: &Polynomial,
336) -> [u8; CT_LEN] {
337 let mut out = [0u8; CT_LEN];
338
339 let lim: usize = k * (N * (du as usize) / 8);
341
342 u.compress_pol_vec::<du>(&mut out[..lim]);
343 v.compress_poly::<dv>(&mut out[lim..]);
344 out
345}
346
347pub(crate) fn unpack_ciphertext_u<
348 const k: usize,
349 const CT_LEN: usize,
350 const du: i16,
351 const dv: i16,
352>(
353 c: &[u8; CT_LEN],
354) -> Vector<k> {
355 let lim: usize = k * (N * (du as usize) / 8);
357
358 let u = Vector::<k>::decompress_pol_vec::<du>(&c[..lim]);
359
360 u
361}
362
363pub(crate) fn unpack_ciphertext_v<
364 const k: usize,
365 const CT_LEN: usize,
366 const du: i16,
367 const dv: i16,
368>(
369 c: &[u8; CT_LEN],
370) -> Polynomial {
371 let lim: usize = k * (N * (du as usize) / 8);
373
374 let v = Polynomial::decompress_poly::<dv>(&c[lim..]);
375
376 v
377}