bouncycastle_mlkem_lowmemory/polynomial.rs
//! Represents a polynomial over the ML-KEM ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6use bouncycastle_core::traits::Secret;
7use crate::aux_functions::{barrett_reduce, cond_sub_q, montgomery_reduce, mul_mont, ntt_base_mult, ZETAS, ZETAS_INV};
8use crate::mlkem::{N, q};
9
10/// A polynomial over the ML-KEM ring.
11/// Dev note: this doesn't strictly need to be pub ... ie there's no good reason for a caller to use this class directly,
12/// but in order to test the Debug and Display traits, you need STD, so those can't be tested from inline tests in this file
13/// and the real unit tests are in a different crate, so here we are.
14#[derive(Clone)]
15pub struct Polynomial{ pub(crate) coeffs: [i16; N] }
16
17/// Convenience function to avoid ".0" all over the place.
18impl Index<usize> for Polynomial {
19 type Output = i16;
20
21 fn index(&self, index: usize) -> &Self::Output {
22 &self.coeffs[index]
23 }
24}
25/// Convenience function to avoid ".0" all over the place.
26impl IndexMut<usize> for Polynomial {
27 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
28 &mut self.coeffs[index]
29 }
30}
31
32impl Polynomial {
33 /// Create a new polynomial with all coefficients set to zero.
34 pub const fn new() -> Self {
35 Self { coeffs: [0i16; N] }
36 }
37
38 /// Create a Polynomial from the message m
39 pub(crate) fn from_msg(m: [u8; 32]) -> Self {
40 let mut w = Polynomial::new();
41
42 for (i, b) in m.iter().enumerate().take(N/8) {
43 for j in 0..8 {
44 let mask = -(((*b >> j) & 1) as i16);
45 w[8 * i + j] = mask /*as i32*/ & ((q + 1) / 2);
46 }
47 }
48
49 w
50 }
51
52 /// Convert a Polynomial back into a message m
53 pub(crate) fn to_msg(mut self) -> [u8; 32] {
54
55 const LOWER: i32 = q as i32 >> 2; // 832
56 const UPPER: i32 = q as i32 - LOWER; // 2497
57
58 let mut msg = [0u8; 32];
59
60 // you would expect to use a full reduce() here, but since this is data coming from
61 // out matrix math and not from an attacker, we can get away with the lighter cond_sub_q()
62 self.cond_sub_q();
63
64 // for (i, item) in msg.iter_mut().enumerate().take(N/8) {
65 for i in 0 .. N/8 {
66 for j in 0..8 {
67 let c_j = self[8 * i + j] as i32;
68 let t = (((LOWER - c_j) & (c_j - UPPER)) >> 31) & 0x0000000000000001;
69 msg[i] |= (t << j) as u8;
70 }
71 }
72
73 msg
74 }
75
76 // not currently used, but I'll leave it here because it's useful for debugging if you want to output values
77 // that are normalized to [0,q] to compare against intermediate results from other libraries.
78 // pub(crate) fn conditional_add_q(&mut self) {
79 // for x in self.0.iter_mut() {
80 // *x = conditional_add_q(*x);
81 // }
82 // }
83
84 pub(crate) fn add(&mut self, w: &Self) {
85 for i in 0..N {
86 self[i] += w[i];
87 }
88 }
89
90 pub(crate) fn sub(&mut self, w: &Self) {
91 for i in 0..N {
92 self[i] -= w[i];
93 }
94 }
95
96 /// Multiplication of two polynomials in NTT domain
97 ///
98 /// Borrowed from:
99 /// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L290
100 pub(crate) fn base_mult_montgomery(&mut self, b: &Polynomial) {
101 for i in 0..(N/4) {
102 let a1: i16 = self[4 * i];
103 let a2: i16 = self[4 * i + 1];
104 ntt_base_mult(
105 &mut self.coeffs,
106 4 * i,
107 a1,
108 a2,
109 b[4 * i],
110 b[4 * i + 1],
111 ZETAS[64 + i],
112 );
113
114 let a1: i16 = self[4 * i + 2];
115 let a2: i16 = self[4 * i + 3];
116 ntt_base_mult(
117 &mut self.coeffs,
118 4 * i + 2,
119 a1,
120 a2,
121 b[4 * i + 2],
122 b[4 * i + 3],
123 -ZETAS[64 + i],
124 );
125 }
126 }
127
128 pub(crate) fn poly_reduce(&mut self) {
129 for i in 0..N {
130 self[i] = barrett_reduce(self[i]);
131 }
132 }
133
134 /// In-place conversion of all coefficients of a polynomial
135 /// from normal domain to Montgomery domain
136 ///
137 /// Borrowed from:
138 /// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L307
139 pub(crate) fn convert_to_mont(&mut self) {
140 const F: i16 = ((1u64 << 32) % q as u64) as i16;
141 for i in 0..N {
142 self[i] = montgomery_reduce((self[i] as i32) * (F as i32));
143 }
144 }
145
146 /// This is an optimized version of
147 /// ByteEncode_𝑑𝑣( Compress_𝑑𝑣(𝑣) )
148 /// which packs a single polynomial according to the packing coefficient dv
149 pub(crate) fn compress_poly<const dv: i16>(&self, out: &mut [u8]) {
150 // make sure we have received a dv
151 debug_assert!(dv == 4 || dv == 5);
152
153 // make sure we were given the right size output buffer
154 // each of the N i16's will take dv bits
155 debug_assert_eq!(out.len(), N * (dv as usize) / 8);
156
157 let mut t = [0u8; 8];
158 let mut idx = 0;
159
160 // bc-java has a cond_sub_q() here, but unit tests show that we don't need it.
161 // let mut s = self.clone();
162 // s.cond_sub_q();
163
164 match dv {
165 4 => { // MLKEM512 and MLKEM768
166 for i in 0..N/8 {
167 // fill the temp array t
168 for (j, item) in t.iter_mut().enumerate() {
169 *item = ((((self[8 * i + j] as i32) << 4) + (q as i32 /2))
170 / (q as i32)
171 & 15) as u8;
172 }
173
174 out[idx] = t[0] | (t[1] << 4);
175 out[idx + 1] = t[2] | (t[3] << 4);
176 out[idx + 2] = t[4] | (t[5] << 4);
177 out[idx + 3] = t[6] | (t[7] << 4);
178 idx += 4;
179 }
180 },
181 5 => { // MLKEM1024
182 for i in 0..N/8 {
183 // fill the temp array t
184 for (j, item) in t.iter_mut().enumerate() {
185 *item = (((((self[8 * i + j] as i32) << 5) + (q as i32 /2))
186 / (q as i32))
187 & 31) as u8;
188 }
189
190 out[idx] = t[0] | (t[1] << 5);
191 out[idx + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
192 out[idx + 2] = (t[3] >> 1) | (t[4] << 4);
193 out[idx + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
194 out[idx + 4] = (t[6] >> 2) | (t[7] << 3);
195 idx += 5;
196 }
197 },
198 _ => unreachable!(),
199 };
200 }
201
202 /// This is an optimized version of
203 /// Decompress_𝑑𝑣( ByteDecode_𝑑𝑣(𝑐2) )
204 /// which unpacks a single polynomial according to the packing coefficient dv
205 pub(crate) fn decompress_poly<const dv: i16>(compressed_v: &[u8]) -> Polynomial {
206 // make sure we have received a dv
207 debug_assert!(dv == 4 || dv == 5);
208
209 // make sure we were given the right size output buffer
210 // each of the N i16's will take dv bits
211 debug_assert_eq!(compressed_v.len(), N * (dv as usize) / 8);
212
213 let mut v = Polynomial::new();
214
215 let mut idx = 0usize;
216
217 // if self.m_engine.poly_compressed_bytes() == 128 {
218 match dv {
219 4 => { // MLKEM512 and MLKEM768
220 for i in 0..N/2 {
221 v[2 * i] =
222 (((((compressed_v[idx] & 15) as i16) as i32 * (q as i32)) + 8) >> 4)
223 as i16;
224 v[2 * i + 1] =
225 (((((compressed_v[idx] >> 4) as i16) as i32 * (q as i32)) + 8) >> 4)
226 as i16;
227 idx += 1;
228 }
229 },
230 5 => { // MLKEM1024
231 let mut t = [0u8; 8];
232 for i in 0..N/8 {
233 t[0] = compressed_v[idx];
234 t[1] =
235 (compressed_v[idx] >> 5) | (compressed_v[idx + 1] << 3);
236 t[2] = compressed_v[idx + 1] >> 2;
237 t[3] = (compressed_v[idx + 1] >> 7)
238 | (compressed_v[idx + 2] << 1);
239 t[4] = (compressed_v[idx + 2] >> 4)
240 | (compressed_v[idx + 3] << 4);
241 t[5] = compressed_v[idx + 3] >> 1;
242 t[6] = (compressed_v[idx + 3] >> 6)
243 | (compressed_v[idx + 4] << 2);
244 t[7] = compressed_v[idx + 4] >> 3;
245 idx += 5;
246 for (j, item) in t.iter_mut().enumerate() {
247 v[8 * i + j] = (((*item & 31) as i32 * (q as i32) + 16) >> 5) as i16;
248 }
249 }
250 },
251 _ => unreachable!(),
252 }
253
254 v
255 }
256
257 pub(crate) fn cond_sub_q(&mut self) {
258 for i in 0..N {
259 self[i] = cond_sub_q(self[i]);
260 }
261 }
262
263 /// Algorithm 9 NTT(𝑓)
264 /// Computes the NTT representation 𝑓_hat of the given polynomial 𝑓 ∈ 𝑅𝑞.
265 /// Input: array 𝑓 ∈ ℤ256 ▷ the coefficients of the input polynomial
266 /// Output: array 𝑓_hat ∈ ℤ256 ▷ the coefficients of the NTT of the input polynomial
267 pub(crate) fn ntt(&mut self) {
268 let mut len = 128;
269 let mut k = 1;
270
271 while len >= 2 {
272 let mut start = 0;
273 while start < 256 {
274 let zeta = ZETAS[k];
275 k += 1;
276 let mut j = start;
277 while j < start + len {
278 let t = mul_mont(zeta, self[j + len]);
279 self[j + len] = self[j] - t;
280 self[j] += t;
281 j += 1;
282 }
283 start = j + len;
284 }
285 len >>= 1;
286 }
287 }
288
289 /// Algorithm 10 NTT (𝑓_hat)
290 /// Computes the polynomial 𝑓 ∈ 𝑅𝑞 that corresponds to the given NTT representation 𝑓 ∈ 𝑇𝑞.
291 /// Input: array 𝑓 ∈ ℤ256 ▷ the coefficients of input NTT representation
292 /// Output: array 𝑓 ∈ ℤ256 ▷ the coefficients of the inverse NTT of the input
293 pub(crate) fn inv_ntt(&mut self) {
294 // FIPS 203 ALg 10 wants you to copy f_hat into f, and then act of f
295 // but we're going to do this in-place for memory-saving reasons.
296
297 let mut len = 2;
298 let mut k = 0;
299
300 while len <= 128 {
301 let mut start = 0;
302 while start < 256 {
303 let zeta = ZETAS_INV[k];
304 k += 1;
305 let mut j = start;
306 while j < start + len {
307 let t = self[j];
308 let u = self[j + len];
309
310 self[j] = barrett_reduce(t + u);
311 self[j + len] = mul_mont(zeta, t - u);
312 j += 1;
313 }
314 start = j + len;
315 }
316 len <<= 1;
317 }
318
319 // 14: 𝑓 ← 𝑓 ⋅ 3303 mod 𝑞
320 // ▷ multiply every entry by 3303 ≡ 128−1 mod 𝑞
321 for i in 0..N {
322 self[i] = mul_mont(self[i], ZETAS_INV[127]);
323 }
324 }
325}
326
327impl Secret for Polynomial {}
328
329impl Drop for Polynomial {
330 fn drop(&mut self) {
331 self.coeffs.fill(0i16);
332 }
333}
334
335impl Debug for Polynomial {
336 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
337 write!(f, "Polynomial (data masked)")
338 }
339}
340
341impl Display for Polynomial {
342 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
343 write!(f, "Polynomial (data masked)")
344 }
345}
346
347// Not currently used, but I'll leave it here because it's useful for debugging if you want to output values
348// that are normalized to [0,q] to compare against intermediate results from other libraries.
349// /// if a is in \[-q..0], then it shifts it up by q to be in \[0..q]
350// pub(crate) fn conditional_add_q(a: i16) -> i16 {
351// a + ((a >> 15) & q)
352// }
353//
354// #[test]
355// /// These are the results it's giving; I'm not sure if these are "correct" or not.
356// fn test_conditional_add_q() {
357// assert_eq!(conditional_add_q(-q -1), -1);
358// assert_eq!(conditional_add_q(-q), 0);
359// assert_eq!(conditional_add_q(-q -2), -2);
360// assert_eq!(conditional_add_q(-q +1), 1);
361// assert_eq!(conditional_add_q(-1), q -1);
362// assert_eq!(conditional_add_q(0), 0);
363// assert_eq!(conditional_add_q(1), 1);
364// assert_eq!(conditional_add_q(q -1), q -1);
365// assert_eq!(conditional_add_q(q), q);
366// assert_eq!(conditional_add_q(q +1), q +1);
367// }