Skip to main content

bouncycastle_mlkem_lowmemory/
polynomial.rs

//! Represents a polynomial over the ML-KEM ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6use bouncycastle_core::traits::Secret;
7use crate::aux_functions::{barrett_reduce, cond_sub_q, montgomery_reduce, mul_mont, ntt_base_mult, ZETAS, ZETAS_INV};
8use crate::mlkem::{N, q};
9
10/// A polynomial over the ML-KEM ring.
11/// Dev note: this doesn't strictly need to be pub ... ie there's no good reason for a caller to use this class directly,
12/// but in order to test the Debug and Display traits, you need STD, so those can't be tested from inline tests in this file
13/// and the real unit tests are in a different crate, so here we are.
14#[derive(Clone)]
15pub struct Polynomial{ pub(crate) coeffs: [i16; N] }
16
17/// Convenience function to avoid ".0" all over the place.
18impl Index<usize> for Polynomial {
19    type Output = i16;
20
21    fn index(&self, index: usize) -> &Self::Output {
22        &self.coeffs[index]
23    }
24}
25/// Convenience function to avoid ".0" all over the place.
26impl IndexMut<usize> for Polynomial {
27    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
28        &mut self.coeffs[index]
29    }
30}
31
32impl Polynomial {
33    /// Create a new polynomial with all coefficients set to zero.
34    pub const fn new() -> Self {
35        Self { coeffs: [0i16; N] }
36    }
37
38    /// Create a Polynomial from the message m
39    pub(crate) fn from_msg(m: [u8; 32]) -> Self {
40        let mut w = Polynomial::new();
41
42        for (i, b) in m.iter().enumerate().take(N/8) {
43            for j in 0..8 {
44                let mask = -(((*b >> j) & 1) as i16);
45                w[8 * i + j] = mask /*as i32*/ & ((q + 1) / 2);
46            }
47        }
48
49        w
50    }
51
52    /// Convert a Polynomial back into a message m
53    pub(crate) fn to_msg(mut self) -> [u8; 32] {
54
55        const LOWER: i32 = q as i32 >> 2;  // 832
56        const UPPER: i32 = q as i32 - LOWER; // 2497
57
58        let mut msg = [0u8; 32];
59
60        // you would expect to use a full reduce() here, but since this is data coming from
61        // out matrix math and not from an attacker, we can get away with the lighter cond_sub_q()
62        self.cond_sub_q();
63
64        // for (i, item) in msg.iter_mut().enumerate().take(N/8) {
65        for i in 0 .. N/8 {
66            for j in 0..8 {
67                let c_j = self[8 * i + j] as i32;
68                let t = (((LOWER - c_j) & (c_j - UPPER)) >> 31) & 0x0000000000000001;
69                msg[i] |= (t << j) as u8;
70            }
71        }
72
73        msg
74    }
75
76    // not currently used, but I'll leave it here because it's useful for debugging if you want to output values
77    // that are normalized to [0,q] to compare against intermediate results from other libraries.
78    // pub(crate) fn conditional_add_q(&mut self) {
79    //     for x in self.0.iter_mut() {
80    //         *x = conditional_add_q(*x);
81    //     }
82    // }
83
84    pub(crate) fn add(&mut self, w: &Self) {
85        for i in 0..N {
86            self[i] += w[i];
87        }
88    }
89
90    pub(crate) fn sub(&mut self, w: &Self) {
91        for i in 0..N {
92            self[i] -= w[i];
93        }
94    }
95
96    /// Multiplication of two polynomials in NTT domain
97    ///
98    /// Borrowed from:
99    /// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L290
100    pub(crate) fn base_mult_montgomery(&mut self, b: &Polynomial) {
101        for i in 0..(N/4) {
102            let a1: i16 = self[4 * i];
103            let a2: i16 = self[4 * i + 1];
104            ntt_base_mult(
105                &mut self.coeffs,
106                4 * i,
107                a1,
108                a2,
109                b[4 * i],
110                b[4 * i + 1],
111                ZETAS[64 + i],
112            );
113
114            let a1: i16 = self[4 * i + 2];
115            let a2: i16 = self[4 * i + 3];
116            ntt_base_mult(
117                &mut self.coeffs,
118                4 * i + 2,
119                a1,
120                a2,
121                b[4 * i + 2],
122                b[4 * i + 3],
123                -ZETAS[64 + i],
124            );
125        }
126    }
127
128    pub(crate) fn poly_reduce(&mut self) {
129        for i in 0..N {
130            self[i] = barrett_reduce(self[i]);
131        }
132    }
133
134    /// In-place conversion of all coefficients of a polynomial
135    /// from normal domain to Montgomery domain
136    ///
137    /// Borrowed from:
138    /// https://github.com/pq-crystals/kyber/blob/main/ref/poly.c#L307
139    pub(crate) fn convert_to_mont(&mut self) {
140        const F: i16 = ((1u64 << 32) % q as u64) as i16;
141        for i in 0..N {
142            self[i] = montgomery_reduce((self[i] as i32) * (F as i32));
143        }
144    }
145
146    /// This is an optimized version of
147    ///   ByteEncode_𝑑𝑣( Compress_𝑑𝑣(𝑣) )
148    /// which packs a single polynomial according to the packing coefficient dv
149    pub(crate) fn compress_poly<const dv: i16>(&self, out: &mut [u8]) {
150        // make sure we have received a dv
151        debug_assert!(dv == 4 || dv == 5);
152
153        // make sure we were given the right size output buffer
154        // each of the N i16's will take dv bits
155        debug_assert_eq!(out.len(), N * (dv as usize) / 8);
156
157        let mut t = [0u8; 8];
158        let mut idx = 0;
159
160        // bc-java has a cond_sub_q() here, but unit tests show that we don't need it.
161        // let mut s = self.clone();
162        // s.cond_sub_q();
163
164        match dv {
165            4 => { // MLKEM512 and MLKEM768
166                for i in 0..N/8 {
167                    // fill the temp array t
168                    for (j, item) in t.iter_mut().enumerate() {
169                        *item = ((((self[8 * i + j] as i32) << 4) + (q as i32 /2))
170                            / (q as i32)
171                            & 15) as u8;
172                    }
173
174                    out[idx] = t[0] | (t[1] << 4);
175                    out[idx + 1] = t[2] | (t[3] << 4);
176                    out[idx + 2] = t[4] | (t[5] << 4);
177                    out[idx + 3] = t[6] | (t[7] << 4);
178                    idx += 4;
179                }
180            },
181            5 => { // MLKEM1024
182                for i in 0..N/8 {
183                    // fill the temp array t
184                    for (j, item) in t.iter_mut().enumerate() {
185                        *item = (((((self[8 * i + j] as i32) << 5) + (q as i32 /2))
186                            / (q as i32))
187                            & 31) as u8;
188                    }
189
190                    out[idx] = t[0] | (t[1] << 5);
191                    out[idx + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
192                    out[idx + 2] = (t[3] >> 1) | (t[4] << 4);
193                    out[idx + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
194                    out[idx + 4] = (t[6] >> 2) | (t[7] << 3);
195                    idx += 5;
196                }
197            },
198            _ => unreachable!(),
199        };
200    }
201
202    /// This is an optimized version of
203    ///   Decompress_𝑑𝑣( ByteDecode_𝑑𝑣(𝑐2) )
204    /// which unpacks a single polynomial according to the packing coefficient dv
205    pub(crate) fn decompress_poly<const dv: i16>(compressed_v: &[u8]) -> Polynomial {
206        // make sure we have received a dv
207        debug_assert!(dv == 4 || dv == 5);
208
209        // make sure we were given the right size output buffer
210        // each of the N i16's will take dv bits
211        debug_assert_eq!(compressed_v.len(), N * (dv as usize) / 8);
212
213        let mut v = Polynomial::new();
214
215        let mut idx = 0usize;
216
217        // if self.m_engine.poly_compressed_bytes() == 128 {
218        match dv {
219            4 => { // MLKEM512 and MLKEM768
220                for i in 0..N/2 {
221                    v[2 * i] =
222                        (((((compressed_v[idx] & 15) as i16) as i32 * (q as i32)) + 8) >> 4)
223                            as i16;
224                    v[2 * i + 1] =
225                        (((((compressed_v[idx] >> 4) as i16) as i32 * (q as i32)) + 8) >> 4)
226                            as i16;
227                    idx += 1;
228                }
229            },
230            5 => { // MLKEM1024
231                let mut t = [0u8; 8];
232                for i in 0..N/8 {
233                    t[0] = compressed_v[idx];
234                    t[1] =
235                        (compressed_v[idx] >> 5) | (compressed_v[idx + 1] << 3);
236                    t[2] = compressed_v[idx + 1] >> 2;
237                    t[3] = (compressed_v[idx + 1] >> 7)
238                        | (compressed_v[idx + 2] << 1);
239                    t[4] = (compressed_v[idx + 2] >> 4)
240                        | (compressed_v[idx + 3] << 4);
241                    t[5] = compressed_v[idx + 3] >> 1;
242                    t[6] = (compressed_v[idx + 3] >> 6)
243                        | (compressed_v[idx + 4] << 2);
244                    t[7] = compressed_v[idx + 4] >> 3;
245                    idx += 5;
246                    for (j, item) in t.iter_mut().enumerate() {
247                        v[8 * i + j] = (((*item & 31) as i32 * (q as i32) + 16) >> 5) as i16;
248                    }
249                }
250            },
251            _ => unreachable!(),
252        }
253
254        v
255    }
256
257    pub(crate) fn cond_sub_q(&mut self) {
258        for i in 0..N {
259            self[i] = cond_sub_q(self[i]);
260        }
261    }
262
263    /// Algorithm 9 NTT(𝑓)
264    /// Computes the NTT representation 𝑓_hat of the given polynomial 𝑓 ∈ 𝑅𝑞.
265    /// Input: array 𝑓 ∈ ℤ256  ▷ the coefficients of the input polynomial
266    /// Output: array 𝑓_hat ∈ ℤ256  ▷ the coefficients of the NTT of the input polynomial
267    pub(crate) fn ntt(&mut self) {
268        let mut len = 128;
269        let mut k = 1;
270
271        while len >= 2 {
272            let mut start = 0;
273            while start < 256 {
274                let zeta = ZETAS[k];
275                k += 1;
276                let mut j = start;
277                while j < start + len {
278                    let t = mul_mont(zeta, self[j + len]);
279                    self[j + len] = self[j] - t;
280                    self[j] += t;
281                    j += 1;
282                }
283                start = j + len;
284            }
285            len >>= 1;
286        }
287    }
288
289    /// Algorithm 10 NTT (𝑓_hat)
290    /// Computes the polynomial 𝑓 ∈ 𝑅𝑞 that corresponds to the given NTT representation 𝑓 ∈ 𝑇𝑞.
291    /// Input: array 𝑓 ∈ ℤ256  ▷ the coefficients of input NTT representation
292    /// Output: array 𝑓 ∈ ℤ256  ▷ the coefficients of the inverse NTT of the input
293    pub(crate) fn inv_ntt(&mut self) {
294        // FIPS 203 ALg 10 wants you to copy f_hat into f, and then act of f
295        // but we're going to do this in-place for memory-saving reasons.
296
297        let mut len = 2;
298        let mut k = 0;
299
300        while len <= 128 {
301            let mut start = 0;
302            while start < 256 {
303                let zeta = ZETAS_INV[k];
304                k += 1;
305                let mut j = start;
306                while j < start + len {
307                    let t = self[j];
308                    let u = self[j + len];
309
310                    self[j] = barrett_reduce(t + u);
311                    self[j + len] = mul_mont(zeta, t - u);
312                    j += 1;
313                }
314                start = j + len;
315            }
316            len <<= 1;
317        }
318
319        // 14: 𝑓 ← 𝑓 ⋅ 3303 mod 𝑞
320        //   ▷ multiply every entry by 3303 ≡ 128−1 mod 𝑞
321        for i in 0..N {
322            self[i] = mul_mont(self[i], ZETAS_INV[127]);
323        }
324    }
325}
326
327impl Secret for Polynomial {}
328
329impl Drop for Polynomial {
330    fn drop(&mut self) {
331        self.coeffs.fill(0i16);
332    }
333}
334
335impl Debug for Polynomial {
336    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
337        write!(f, "Polynomial (data masked)")
338    }
339}
340
341impl Display for Polynomial {
342    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
343        write!(f, "Polynomial (data masked)")
344    }
345}
346
347// Not currently used, but I'll leave it here because it's useful for debugging if you want to output values
348// that are normalized to [0,q] to compare against intermediate results from other libraries.
349// /// if a is in \[-q..0], then it shifts it up by q to be in \[0..q]
350// pub(crate) fn conditional_add_q(a: i16) -> i16 {
351//     a + ((a >> 15) & q)
352// }
353//
354// #[test]
355// /// These are the results it's giving; I'm not sure if these are "correct" or not.
356// fn test_conditional_add_q() {
357//     assert_eq!(conditional_add_q(-q -1), -1);
358//     assert_eq!(conditional_add_q(-q), 0);
359//     assert_eq!(conditional_add_q(-q -2), -2);
360//     assert_eq!(conditional_add_q(-q +1), 1);
361//     assert_eq!(conditional_add_q(-1), q -1);
362//     assert_eq!(conditional_add_q(0), 0);
363//     assert_eq!(conditional_add_q(1), 1);
364//     assert_eq!(conditional_add_q(q -1), q -1);
365//     assert_eq!(conditional_add_q(q), q);
366//     assert_eq!(conditional_add_q(q +1), q +1);
367// }