bouncycastle_mlkem/polynomial.rs
//! Represents a polynomial over the ML-KEM ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6
7use crate::aux_functions::{
8 ZETAS, ZETAS_INV, barrett_reduce, cond_sub_q, montgomery_reduce, mul_mont, ntt_base_mult,
9};
10use crate::mlkem::{N, q};
11use bouncycastle_core::traits::Secret;
12
13/// A polynomial over the ML-KEM ring.
14/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
15#[derive(Clone)]
16pub struct Polynomial {
17 /// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
18 pub coeffs: [i16; N],
19}
20
21/// Convenience function to avoid ".0" all over the place.
22impl Index<usize> for Polynomial {
23 type Output = i16;
24
25 fn index(&self, index: usize) -> &Self::Output {
26 &self.coeffs[index]
27 }
28}
29/// Convenience function to avoid ".0" all over the place.
30impl IndexMut<usize> for Polynomial {
31 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
32 &mut self.coeffs[index]
33 }
34}
35
36impl Polynomial {
37 /// Create a new polynomial with all coefficients set to zero.
38 pub const fn new() -> Self {
39 Self { coeffs: [0i16; N] }
40 }
41
42 /// Create a Polynomial from the message m
43 pub(crate) fn from_msg(m: [u8; 32]) -> Self {
44 let mut w = Polynomial::new();
45
46 for (i, b) in m.iter().enumerate().take(N / 8) {
47 for j in 0..8 {
48 let mask = -(((*b >> j) & 1) as i16);
49 w[8 * i + j] = mask /*as i32*/ & ((q + 1) / 2);
50 }
51 }
52
53 w
54 }
55
56 /// Convert a Polynomial back into a message m
57 pub(crate) fn to_msg(mut self) -> [u8; 32] {
58 const LOWER: i32 = q as i32 >> 2; // 832
59 const UPPER: i32 = q as i32 - LOWER; // 2497
60
61 let mut msg = [0u8; 32];
62
63 // you would expect to use a full reduce() here, but since this is data coming from
64 // out matrix math and not from an attacker, we can get away with the lighter cond_sub_q()
65 self.cond_sub_q();
66
67 // for (i, item) in msg.iter_mut().enumerate().take(N/8) {
68 for i in 0..N / 8 {
69 for j in 0..8 {
70 let c_j = self[8 * i + j] as i32;
71 let t = (((LOWER - c_j) & (c_j - UPPER)) >> 31) & 0x0000000000000001;
72 msg[i] |= (t << j) as u8;
73 }
74 }
75
76 msg
77 }
78
79 // not currently used, but I'll leave it here because it's useful for debugging if you want to output values
80 // that are normalized to [0,q] to compare against intermediate results from other libraries.
81 // pub(crate) fn conditional_add_q(&mut self) {
82 // for x in self.0.iter_mut() {
83 // *x = conditional_add_q(*x);
84 // }
85 // }
86
87 pub(crate) fn add(&mut self, w: &Self) {
88 for i in 0..N {
89 self[i] += w[i];
90 }
91 }
92
93 pub(crate) fn sub(&mut self, w: &Self) {
94 for i in 0..N {
95 self[i] -= w[i];
96 }
97 }
98
99 pub(crate) fn poly_reduce(&mut self) {
100 for i in 0..N {
101 self[i] = barrett_reduce(self[i]);
102 }
103 }
104
105 /// In-place conversion of all coefficients of a polynomial
106 /// from normal domain to Montgomery domain
107 ///
108 /// Borrowed from:
109 /// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L307
110 pub(crate) fn convert_to_mont(&mut self) {
111 const F: i16 = ((1u64 << 32) % q as u64) as i16;
112 for i in 0..N {
113 self[i] = montgomery_reduce((self[i] as i32) * (F as i32));
114 }
115 }
116
117 /// This is an optimized version of
118 /// ByteEncode_𝑑𝑣( Compress_𝑑𝑣(𝑣) )
119 /// which packs a single polynomial according to the packing coefficient dv
120 pub(crate) fn compress_poly<const dv: i16>(&self, out: &mut [u8]) {
121 // make sure we have received a dv
122 debug_assert!(dv == 4 || dv == 5);
123
124 // make sure we were given the right size output buffer
125 // each of the N i16's will take dv bits
126 debug_assert_eq!(out.len(), N * (dv as usize) / 8);
127
128 let mut t = [0u8; 8];
129 let mut idx = 0;
130
131 // bc-java has a cond_sub_q() here, but unit tests show that we don't need it.
132 // let mut s = self.clone();
133 // s.cond_sub_q();
134
135 match dv {
136 4 => {
137 // MLKEM512 and MLKEM768
138 for i in 0..N / 8 {
139 // fill the temp array t
140 for (j, item) in t.iter_mut().enumerate() {
141 *item = ((((self[8 * i + j] as i32) << 4) + (q as i32 / 2)) / (q as i32)
142 & 15) as u8;
143 }
144
145 out[idx] = t[0] | (t[1] << 4);
146 out[idx + 1] = t[2] | (t[3] << 4);
147 out[idx + 2] = t[4] | (t[5] << 4);
148 out[idx + 3] = t[6] | (t[7] << 4);
149 idx += 4;
150 }
151 }
152 5 => {
153 // MLKEM1024
154 for i in 0..N / 8 {
155 // fill the temp array t
156 for (j, item) in t.iter_mut().enumerate() {
157 *item = (((((self[8 * i + j] as i32) << 5) + (q as i32 / 2)) / (q as i32))
158 & 31) as u8;
159 }
160
161 out[idx] = t[0] | (t[1] << 5);
162 out[idx + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
163 out[idx + 2] = (t[3] >> 1) | (t[4] << 4);
164 out[idx + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
165 out[idx + 4] = (t[6] >> 2) | (t[7] << 3);
166 idx += 5;
167 }
168 }
169 _ => unreachable!(),
170 };
171 }
172
173 /// This is an optimized version of
174 /// Decompress_𝑑𝑣( ByteDecode_𝑑𝑣(𝑐2) )
175 /// which unpacks a single polynomial according to the packing coefficient dv
176 pub(crate) fn decompress_poly<const dv: i16>(compressed_v: &[u8]) -> Polynomial {
177 // make sure we have received a dv
178 debug_assert!(dv == 4 || dv == 5);
179
180 // make sure we were given the right size output buffer
181 // each of the N i16's will take dv bits
182 debug_assert_eq!(compressed_v.len(), N * (dv as usize) / 8);
183
184 let mut v = Polynomial::new();
185
186 let mut idx = 0usize;
187
188 // if self.m_engine.poly_compressed_bytes() == 128 {
189 match dv {
190 4 => {
191 // MLKEM512 and MLKEM768
192 for i in 0..N / 2 {
193 v[2 * i] =
194 (((((compressed_v[idx] & 15) as i16) as i32 * (q as i32)) + 8) >> 4) as i16;
195 v[2 * i + 1] =
196 (((((compressed_v[idx] >> 4) as i16) as i32 * (q as i32)) + 8) >> 4) as i16;
197 idx += 1;
198 }
199 }
200 5 => {
201 // MLKEM1024
202 let mut t = [0u8; 8];
203 for i in 0..N / 8 {
204 t[0] = compressed_v[idx];
205 t[1] = (compressed_v[idx] >> 5) | (compressed_v[idx + 1] << 3);
206 t[2] = compressed_v[idx + 1] >> 2;
207 t[3] = (compressed_v[idx + 1] >> 7) | (compressed_v[idx + 2] << 1);
208 t[4] = (compressed_v[idx + 2] >> 4) | (compressed_v[idx + 3] << 4);
209 t[5] = compressed_v[idx + 3] >> 1;
210 t[6] = (compressed_v[idx + 3] >> 6) | (compressed_v[idx + 4] << 2);
211 t[7] = compressed_v[idx + 4] >> 3;
212 idx += 5;
213 for (j, item) in t.iter_mut().enumerate() {
214 v[8 * i + j] = (((*item & 31) as i32 * (q as i32) + 16) >> 5) as i16;
215 }
216 }
217 }
218 _ => unreachable!(),
219 }
220
221 v
222 }
223
224 pub(crate) fn cond_sub_q(&mut self) {
225 for i in 0..N {
226 self[i] = cond_sub_q(self[i]);
227 }
228 }
229
230 /// Algorithm 9 NTT(𝑓)
231 /// Computes the NTT representation 𝑓_hat of the given polynomial 𝑓 ∈ 𝑅𝑞.
232 /// Input: array 𝑓 ∈ ℤ256 ▷ the coefficients of the input polynomial
233 /// Output: array 𝑓_hat ∈ ℤ256 ▷ the coefficients of the NTT of the input polynomial
234 /// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
235 pub fn ntt(&mut self) {
236 let mut len = 128;
237 let mut k = 1;
238
239 while len >= 2 {
240 let mut start = 0;
241 while start < 256 {
242 let zeta = ZETAS[k];
243 k += 1;
244 let mut j = start;
245 while j < start + len {
246 let t = mul_mont(zeta, self[j + len]);
247 self[j + len] = self[j] - t;
248 self[j] += t;
249 j += 1;
250 }
251 start = j + len;
252 }
253 len >>= 1;
254 }
255 }
256
257 /// Algorithm 10 NTT (𝑓_hat)
258 /// Computes the polynomial 𝑓 ∈ 𝑅𝑞 that corresponds to the given NTT representation 𝑓 ∈ 𝑇𝑞.
259 /// Input: array 𝑓 ∈ ℤ256 ▷ the coefficients of input NTT representation
260 /// Output: array 𝑓 ∈ ℤ256 ▷ the coefficients of the inverse NTT of the input
261 /// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
262 pub fn inv_ntt(&mut self) {
263 // FIPS 203 Alg 10 wants you to copy f_hat into f, and then act on f
264 // but we're going to do this in-place for memory-saving reasons.
265
266 let mut len = 2;
267 let mut k = 0;
268
269 while len <= 128 {
270 let mut start = 0;
271 while start < 256 {
272 let zeta = ZETAS_INV[k];
273 k += 1;
274 let mut j = start;
275 while j < start + len {
276 let t = self[j];
277 let u = self[j + len];
278
279 self[j] = barrett_reduce(t + u);
280 self[j + len] = mul_mont(zeta, t - u);
281 j += 1;
282 }
283 start = j + len;
284 }
285 len <<= 1;
286 }
287
288 // 14: 𝑓 ← 𝑓 ⋅ 3303 mod 𝑞
289 // ▷ multiply every entry by 3303 ≡ 128−1 mod 𝑞
290 for i in 0..N {
291 self[i] = mul_mont(self[i], ZETAS_INV[127]);
292 }
293 }
294}
295
296/// Multiplication of two polynomials in NTT domain
297///
298/// Borrowed from:
299/// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L290
300/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
301pub fn base_mult_montgomery(a: &Polynomial, b: &Polynomial) -> Polynomial {
302 let mut r = Polynomial::new();
303
304 for i in 0..(N / 4) {
305 ntt_base_mult(
306 &mut r.coeffs,
307 4 * i,
308 a[4 * i],
309 a[4 * i + 1],
310 b[4 * i],
311 b[4 * i + 1],
312 ZETAS[64 + i],
313 );
314 ntt_base_mult(
315 &mut r.coeffs,
316 4 * i + 2,
317 a[4 * i + 2],
318 a[4 * i + 3],
319 b[4 * i + 2],
320 b[4 * i + 3],
321 -ZETAS[64 + i],
322 );
323 }
324
325 r
326}
327
328impl Secret for Polynomial {}
329
330impl Drop for Polynomial {
331 fn drop(&mut self) {
332 self.coeffs.fill(0i16);
333 }
334}
335
336impl Debug for Polynomial {
337 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
338 write!(f, "Polynomial (data masked)")
339 }
340}
341
342impl Display for Polynomial {
343 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
344 write!(f, "Polynomial (data masked)")
345 }
346}
347
348// Not currently used, but I'll leave it here because it's useful for debugging if you want to output values
349// that are normalized to [0,q] to compare against intermediate results from other libraries.
350// /// if a is in \[-q..0], then it shifts it up by q to be in \[0..q]
351// pub(crate) fn conditional_add_q(a: i16) -> i16 {
352// a + ((a >> 15) & q)
353// }
354//
355// #[test]
356// /// These are the results it's giving; I'm not sure if these are "correct" or not.
357// fn test_conditional_add_q() {
358// assert_eq!(conditional_add_q(-q -1), -1);
359// assert_eq!(conditional_add_q(-q), 0);
360// assert_eq!(conditional_add_q(-q -2), -2);
361// assert_eq!(conditional_add_q(-q +1), 1);
362// assert_eq!(conditional_add_q(-1), q -1);
363// assert_eq!(conditional_add_q(0), 0);
364// assert_eq!(conditional_add_q(1), 1);
365// assert_eq!(conditional_add_q(q -1), q -1);
366// assert_eq!(conditional_add_q(q), q);
367// assert_eq!(conditional_add_q(q +1), q +1);
368// }