Skip to main content

bouncycastle_mlkem/
polynomial.rs

//! Represents a polynomial over the ML-KEM ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6
7use crate::aux_functions::{
8    ZETAS, ZETAS_INV, barrett_reduce, cond_sub_q, montgomery_reduce, mul_mont, ntt_base_mult,
9};
10use crate::mlkem::{N, q};
11use bouncycastle_core::traits::Secret;
12
13/// A polynomial over the ML-KEM ring.
14/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
15#[derive(Clone)]
16pub struct Polynomial {
17    /// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
18    pub coeffs: [i16; N],
19}
20
21/// Convenience function to avoid ".0" all over the place.
22impl Index<usize> for Polynomial {
23    type Output = i16;
24
25    fn index(&self, index: usize) -> &Self::Output {
26        &self.coeffs[index]
27    }
28}
29/// Convenience function to avoid ".0" all over the place.
30impl IndexMut<usize> for Polynomial {
31    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
32        &mut self.coeffs[index]
33    }
34}
35
36impl Polynomial {
37    /// Create a new polynomial with all coefficients set to zero.
38    pub const fn new() -> Self {
39        Self { coeffs: [0i16; N] }
40    }
41
42    /// Create a Polynomial from the message m
43    pub(crate) fn from_msg(m: [u8; 32]) -> Self {
44        let mut w = Polynomial::new();
45
46        for (i, b) in m.iter().enumerate().take(N / 8) {
47            for j in 0..8 {
48                let mask = -(((*b >> j) & 1) as i16);
49                w[8 * i + j] = mask /*as i32*/ & ((q + 1) / 2);
50            }
51        }
52
53        w
54    }
55
56    /// Convert a Polynomial back into a message m
57    pub(crate) fn to_msg(mut self) -> [u8; 32] {
58        const LOWER: i32 = q as i32 >> 2; // 832
59        const UPPER: i32 = q as i32 - LOWER; // 2497
60
61        let mut msg = [0u8; 32];
62
63        // you would expect to use a full reduce() here, but since this is data coming from
64        // out matrix math and not from an attacker, we can get away with the lighter cond_sub_q()
65        self.cond_sub_q();
66
67        // for (i, item) in msg.iter_mut().enumerate().take(N/8) {
68        for i in 0..N / 8 {
69            for j in 0..8 {
70                let c_j = self[8 * i + j] as i32;
71                let t = (((LOWER - c_j) & (c_j - UPPER)) >> 31) & 0x0000000000000001;
72                msg[i] |= (t << j) as u8;
73            }
74        }
75
76        msg
77    }
78
79    // not currently used, but I'll leave it here because it's useful for debugging if you want to output values
80    // that are normalized to [0,q] to compare against intermediate results from other libraries.
81    // pub(crate) fn conditional_add_q(&mut self) {
82    //     for x in self.0.iter_mut() {
83    //         *x = conditional_add_q(*x);
84    //     }
85    // }
86
87    pub(crate) fn add(&mut self, w: &Self) {
88        for i in 0..N {
89            self[i] += w[i];
90        }
91    }
92
93    pub(crate) fn sub(&mut self, w: &Self) {
94        for i in 0..N {
95            self[i] -= w[i];
96        }
97    }
98
99    pub(crate) fn poly_reduce(&mut self) {
100        for i in 0..N {
101            self[i] = barrett_reduce(self[i]);
102        }
103    }
104
105    /// In-place conversion of all coefficients of a polynomial
106    /// from normal domain to Montgomery domain
107    ///
108    /// Borrowed from:
109    /// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L307
110    pub(crate) fn convert_to_mont(&mut self) {
111        const F: i16 = ((1u64 << 32) % q as u64) as i16;
112        for i in 0..N {
113            self[i] = montgomery_reduce((self[i] as i32) * (F as i32));
114        }
115    }
116
117    /// This is an optimized version of
118    ///   ByteEncode_𝑑𝑣( Compress_𝑑𝑣(𝑣) )
119    /// which packs a single polynomial according to the packing coefficient dv
120    pub(crate) fn compress_poly<const dv: i16>(&self, out: &mut [u8]) {
121        // make sure we have received a dv
122        debug_assert!(dv == 4 || dv == 5);
123
124        // make sure we were given the right size output buffer
125        // each of the N i16's will take dv bits
126        debug_assert_eq!(out.len(), N * (dv as usize) / 8);
127
128        let mut t = [0u8; 8];
129        let mut idx = 0;
130
131        // bc-java has a cond_sub_q() here, but unit tests show that we don't need it.
132        // let mut s = self.clone();
133        // s.cond_sub_q();
134
135        match dv {
136            4 => {
137                // MLKEM512 and MLKEM768
138                for i in 0..N / 8 {
139                    // fill the temp array t
140                    for (j, item) in t.iter_mut().enumerate() {
141                        *item = ((((self[8 * i + j] as i32) << 4) + (q as i32 / 2)) / (q as i32)
142                            & 15) as u8;
143                    }
144
145                    out[idx] = t[0] | (t[1] << 4);
146                    out[idx + 1] = t[2] | (t[3] << 4);
147                    out[idx + 2] = t[4] | (t[5] << 4);
148                    out[idx + 3] = t[6] | (t[7] << 4);
149                    idx += 4;
150                }
151            }
152            5 => {
153                // MLKEM1024
154                for i in 0..N / 8 {
155                    // fill the temp array t
156                    for (j, item) in t.iter_mut().enumerate() {
157                        *item = (((((self[8 * i + j] as i32) << 5) + (q as i32 / 2)) / (q as i32))
158                            & 31) as u8;
159                    }
160
161                    out[idx] = t[0] | (t[1] << 5);
162                    out[idx + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
163                    out[idx + 2] = (t[3] >> 1) | (t[4] << 4);
164                    out[idx + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
165                    out[idx + 4] = (t[6] >> 2) | (t[7] << 3);
166                    idx += 5;
167                }
168            }
169            _ => unreachable!(),
170        };
171    }
172
173    /// This is an optimized version of
174    ///   Decompress_𝑑𝑣( ByteDecode_𝑑𝑣(𝑐2) )
175    /// which unpacks a single polynomial according to the packing coefficient dv
176    pub(crate) fn decompress_poly<const dv: i16>(compressed_v: &[u8]) -> Polynomial {
177        // make sure we have received a dv
178        debug_assert!(dv == 4 || dv == 5);
179
180        // make sure we were given the right size output buffer
181        // each of the N i16's will take dv bits
182        debug_assert_eq!(compressed_v.len(), N * (dv as usize) / 8);
183
184        let mut v = Polynomial::new();
185
186        let mut idx = 0usize;
187
188        // if self.m_engine.poly_compressed_bytes() == 128 {
189        match dv {
190            4 => {
191                // MLKEM512 and MLKEM768
192                for i in 0..N / 2 {
193                    v[2 * i] =
194                        (((((compressed_v[idx] & 15) as i16) as i32 * (q as i32)) + 8) >> 4) as i16;
195                    v[2 * i + 1] =
196                        (((((compressed_v[idx] >> 4) as i16) as i32 * (q as i32)) + 8) >> 4) as i16;
197                    idx += 1;
198                }
199            }
200            5 => {
201                // MLKEM1024
202                let mut t = [0u8; 8];
203                for i in 0..N / 8 {
204                    t[0] = compressed_v[idx];
205                    t[1] = (compressed_v[idx] >> 5) | (compressed_v[idx + 1] << 3);
206                    t[2] = compressed_v[idx + 1] >> 2;
207                    t[3] = (compressed_v[idx + 1] >> 7) | (compressed_v[idx + 2] << 1);
208                    t[4] = (compressed_v[idx + 2] >> 4) | (compressed_v[idx + 3] << 4);
209                    t[5] = compressed_v[idx + 3] >> 1;
210                    t[6] = (compressed_v[idx + 3] >> 6) | (compressed_v[idx + 4] << 2);
211                    t[7] = compressed_v[idx + 4] >> 3;
212                    idx += 5;
213                    for (j, item) in t.iter_mut().enumerate() {
214                        v[8 * i + j] = (((*item & 31) as i32 * (q as i32) + 16) >> 5) as i16;
215                    }
216                }
217            }
218            _ => unreachable!(),
219        }
220
221        v
222    }
223
224    pub(crate) fn cond_sub_q(&mut self) {
225        for i in 0..N {
226            self[i] = cond_sub_q(self[i]);
227        }
228    }
229
230    /// Algorithm 9 NTT(𝑓)
231    /// Computes the NTT representation 𝑓_hat of the given polynomial 𝑓 ∈ 𝑅𝑞.
232    /// Input: array 𝑓 ∈ ℤ256  ▷ the coefficients of the input polynomial
233    /// Output: array 𝑓_hat ∈ ℤ256  ▷ the coefficients of the NTT of the input polynomial
234    /// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
235    pub fn ntt(&mut self) {
236        let mut len = 128;
237        let mut k = 1;
238
239        while len >= 2 {
240            let mut start = 0;
241            while start < 256 {
242                let zeta = ZETAS[k];
243                k += 1;
244                let mut j = start;
245                while j < start + len {
246                    let t = mul_mont(zeta, self[j + len]);
247                    self[j + len] = self[j] - t;
248                    self[j] += t;
249                    j += 1;
250                }
251                start = j + len;
252            }
253            len >>= 1;
254        }
255    }
256
257    /// Algorithm 10 NTT (𝑓_hat)
258    /// Computes the polynomial 𝑓 ∈ 𝑅𝑞 that corresponds to the given NTT representation 𝑓 ∈ 𝑇𝑞.
259    /// Input: array 𝑓 ∈ ℤ256  ▷ the coefficients of input NTT representation
260    /// Output: array 𝑓 ∈ ℤ256  ▷ the coefficients of the inverse NTT of the input
261    /// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
262    pub fn inv_ntt(&mut self) {
263        // FIPS 203 Alg 10 wants you to copy f_hat into f, and then act on f
264        // but we're going to do this in-place for memory-saving reasons.
265
266        let mut len = 2;
267        let mut k = 0;
268
269        while len <= 128 {
270            let mut start = 0;
271            while start < 256 {
272                let zeta = ZETAS_INV[k];
273                k += 1;
274                let mut j = start;
275                while j < start + len {
276                    let t = self[j];
277                    let u = self[j + len];
278
279                    self[j] = barrett_reduce(t + u);
280                    self[j + len] = mul_mont(zeta, t - u);
281                    j += 1;
282                }
283                start = j + len;
284            }
285            len <<= 1;
286        }
287
288        // 14: 𝑓 ← 𝑓 ⋅ 3303 mod 𝑞
289        //   ▷ multiply every entry by 3303 ≡ 128−1 mod 𝑞
290        for i in 0..N {
291            self[i] = mul_mont(self[i], ZETAS_INV[127]);
292        }
293    }
294}
295
296/// Multiplication of two polynomials in NTT domain
297///
298/// Borrowed from:
299/// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L290
300/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
301pub fn base_mult_montgomery(a: &Polynomial, b: &Polynomial) -> Polynomial {
302    let mut r = Polynomial::new();
303
304    for i in 0..(N / 4) {
305        ntt_base_mult(
306            &mut r.coeffs,
307            4 * i,
308            a[4 * i],
309            a[4 * i + 1],
310            b[4 * i],
311            b[4 * i + 1],
312            ZETAS[64 + i],
313        );
314        ntt_base_mult(
315            &mut r.coeffs,
316            4 * i + 2,
317            a[4 * i + 2],
318            a[4 * i + 3],
319            b[4 * i + 2],
320            b[4 * i + 3],
321            -ZETAS[64 + i],
322        );
323    }
324
325    r
326}
327
328impl Secret for Polynomial {}
329
330impl Drop for Polynomial {
331    fn drop(&mut self) {
332        self.coeffs.fill(0i16);
333    }
334}
335
336impl Debug for Polynomial {
337    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
338        write!(f, "Polynomial (data masked)")
339    }
340}
341
342impl Display for Polynomial {
343    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
344        write!(f, "Polynomial (data masked)")
345    }
346}
347
348// Not currently used, but I'll leave it here because it's useful for debugging if you want to output values
349// that are normalized to [0,q] to compare against intermediate results from other libraries.
350// /// if a is in \[-q..0], then it shifts it up by q to be in \[0..q]
351// pub(crate) fn conditional_add_q(a: i16) -> i16 {
352//     a + ((a >> 15) & q)
353// }
354//
355// #[test]
356// /// These are the results it's giving; I'm not sure if these are "correct" or not.
357// fn test_conditional_add_q() {
358//     assert_eq!(conditional_add_q(-q -1), -1);
359//     assert_eq!(conditional_add_q(-q), 0);
360//     assert_eq!(conditional_add_q(-q -2), -2);
361//     assert_eq!(conditional_add_q(-q +1), 1);
362//     assert_eq!(conditional_add_q(-1), q -1);
363//     assert_eq!(conditional_add_q(0), 0);
364//     assert_eq!(conditional_add_q(1), 1);
365//     assert_eq!(conditional_add_q(q -1), q -1);
366//     assert_eq!(conditional_add_q(q), q);
367//     assert_eq!(conditional_add_q(q +1), q +1);
368// }