Skip to main content

bouncycastle_mldsa_lowmemory/
polynomial.rs

1//! Represents a polynomial over the ML-DSA ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6use bouncycastle_core::traits::Secret;
7use crate::mldsa::{N, q, q_inv, MLDSA44_POLY_W1_PACKED_LEN, MLDSA65_POLY_W1_PACKED_LEN};
8use crate::aux_functions::{high_bits, low_bits, make_hint, use_hint};
9
10/// A polynomial over the ML-DSA ring.
11/// Dev note: this doesn't strictly need to be pub ... ie there's no good reason for a caller to use this class directly,
12/// but in order to test the Debug and Display traits, you need STD, so those can't be tested from inline tests in this file
13/// and the real unit tests are in a different crate, so here we are.
14#[derive(Clone)]
15pub struct Polynomial{ pub(crate) coeffs: [i32; N] }
16
17/// Convenience function to avoid ".0" all over the place.
18impl Index<usize> for Polynomial {
19    type Output = i32;
20
21    fn index(&self, index: usize) -> &Self::Output {
22        &self.coeffs[index]
23    }
24}
25/// Convenience function to avoid ".0" all over the place.
26impl IndexMut<usize> for Polynomial {
27    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
28        &mut self.coeffs[index]
29    }
30}
31
32impl Polynomial {
33    /// Create a new polynomial with all coefficients set to zero.
34    pub const fn new() -> Self {
35        Self{ coeffs: [0i32; N] }
36    }
37
38    pub(crate) fn conditional_add_q(&mut self) {
39        for x in self.coeffs.iter_mut() {
40            *x = conditional_add_q(*x);
41        }
42    }
43
44    /// Algorithm 44 AddNTT(𝑎, 𝑏)̂
45    /// Computes the sum a + 𝑏 of two elements 𝑎, 𝑏 ∈ 𝑇𝑞.
46    /// Note: result could be up to 2q.
47    pub(crate) fn add_ntt(&mut self, w: &Self) {
48        for i in 0..N {
49            self[i] += w[i];
50        }
51    }
52
53    pub(crate) fn sub(&mut self, w: &Self) {
54        for i in 0..N {
55            self[i] -= w[i];
56        }
57    }
58
59    /// Algorithm 45 MultiplyNTT(𝑎, 𝑏)̂
60    /// Computes the product 𝑎 ∘̂ 𝑏 of two elements 𝑎, 𝑏 ∈ 𝑇𝑞.
61    /// Input: 𝑎, 𝑏 ∈ 𝑇𝑞.
62    /// Output: 𝑐 ∈ 𝑇𝑞.
63    /// Multiply the coefficients in this polynomial by those in another polynomial and perform montgomery reduction.
64    /// Also called pointwise montgomery multiplication
65    pub(crate) fn multiply_ntt(&mut self, b: &Polynomial) {
66        for i in 0..N {
67            self[i] = montgomery_reduce((self[i] as i64) * (b[i] as i64));
68        }
69    }
70
71    pub(crate) fn high_bits<const GAMMA2: i32>(&mut self) {
72        for i in 0..N {
73            self[i] = high_bits::<GAMMA2>(self[i]);
74        }
75    }
76
77    pub(crate) fn low_bits<const GAMMA2: i32>(&mut self) {
78        for i in 0..N {
79            self[i] = low_bits::<GAMMA2>(self[i]);
80        }
81    }
82
83    pub(crate) fn check_norm(&self, bound: i32) -> bool {
84        // Fine that this is not constant-time (returns true early) because it is used in a rejection loop.
85        // IE the early quit here leads to rejection and continuing to the top of the rejection loop, or failing the signature validation.
86        // So the i32 that we just checked in a non-constant-time manner is about to get thrown away.
87        if bound > (q - 1) / 8 {
88            return true;
89        }
90
91        let mut t: i32;
92        for x in self.coeffs.iter() {
93            t = *x >> 31;
94            t = *x - (t & (2 * *x));
95
96            if t >= bound {
97                return true;
98            }
99        }
100        false
101    }
102
103    pub(crate) fn shift_left<const d: i32>(&mut self) {
104        for x in self.coeffs.iter_mut() {
105            *x <<= d;
106        }
107    }
108
109    /// Creates the hint vector, and also returns its hamming weight (ie the number of 1's).
110    pub(crate) fn make_hint_row<const GAMMA2: i32>(&self, r: &Self) -> (Self, i32) {
111        let mut out = Polynomial::new();
112        let mut count = 0i32;
113        for i in 0..N {
114            let x = make_hint::<GAMMA2>(self[i], r[i]);
115            out[i] = x;
116            count += x;
117        }
118
119        (out, count)
120    }
121
122    pub(crate) fn w1_encode<const POLY_W1_PACKED_LEN: usize>(&self) -> [u8; POLY_W1_PACKED_LEN] {
123        // It might seem counter-intuitive for a low-memory implementation to create a tmp buffer
124        // rather than work in the provided buffer, but my benchmarking shows that for whatever
125        // reason, rust is like an order of magnitude faster working in a scope-local array than
126        // in a referenced piece of memory.
127        // My guess is that when you tell the compiler that the intermediate values are scope-local,
128        // then it's free to optimize all of the computation into CPU registers and skip, in this case,
129        // several hundred physical memory writes.
130        // So while it looks odd to use a scope variable in a low-memory implementation, it's way faster
131        // and I'm not convinced that it uses any more physical memory.        
132        let mut r = [0u8; POLY_W1_PACKED_LEN];
133
134        match POLY_W1_PACKED_LEN {
135            MLDSA44_POLY_W1_PACKED_LEN => {
136                for i in 0..N/4 {
137                    r[3 * i] =
138                        ((self[4 * i]) as u8) | ((self[4 * i + 1] << 6) as u8);
139                    r[3 * i + 1] =
140                        ((self[4 * i + 1] >> 2) as u8) | ((self[4 * i + 2] << 4) as u8);
141                    r[3 * i + 2] =
142                        ((self[4 * i + 2] >> 4) as u8) | ((self[4 * i + 3] << 2) as u8);
143                }
144            },
145            // ML-DSA65 and 87 share a POLY_W1_PACKED_LEN value
146            MLDSA65_POLY_W1_PACKED_LEN => {
147                for i in 0..N/2 {
148                    r[i] = ((self[2 * i]) | (self[2 * i + 1] << 4)) as u8;
149                }
150            },
151            _ => { unreachable!() }
152        }
153        
154        r
155    }
156
157    /// Algorithm 41 NTT(𝑤)
158    /// Computes the NTT.
159    /// Input: Polynomial 𝑤(𝑋)
160    /// 𝑗=0 𝑤𝑗𝑋𝑗 ∈ 𝑅𝑞.
161    /// Output: 𝑤_hat = (𝑤_hat\[0], ..., 𝑤_hat\[255]) ∈ 𝑇𝑞.
162    ///
163    /// Note: by convention, variables holding the output of the NTT function should be named "_ntt"
164    /// to indicate that they are in the NTT domain (sometimes called the frequency domain), not the natural domain.
165    /// I considered using the rust type system to enforce this, but it seemed like overkill, cause that's what
166    /// NIST test vectors are for.
167    ///
168    /// Design choice: don't do the NTT in-place, but copy data to a new array.
169    /// This uses slightly more memory and requires a copy, but makes the code easier to read
170    /// and less likely to contain a bug. But this optimization could be considered in the future.
171    pub(crate) fn ntt(&mut self) {
172        let mut m: usize = 0;
173        let mut len: usize = 128;
174
175        while len >= 1 {
176            let mut start: usize = 0;
177            while start < N {
178                m += 1;
179                let z: i32 = ZETAS[m];
180
181                for j in start..start + len {
182                    let t = montgomery_reduce(z as i64 * self[j + len] as i64);
183                    self[j + len] = self[j] - t; // '% q' not strictly needed cause it gets reduced at some point later. Removing it gave +5% in benchmarking
184                    self[j] = self[j] + t; // '% q' not strictly needed
185                }
186                start = start + 2 * len;
187            }
188            len >>= 1;
189        }
190    }
191
192    /// Algorithm 42 NTT−1(𝑤)̂
193    /// Computes the inverse of the NTT.
194    /// Input: ̂̂ ̂ 𝑤 = (𝑤\[0], … , 𝑤\[255]) ∈ 𝑇𝑞.
195    /// Output: Polynomial 𝑤(𝑋) = ∑255
196    /// 𝑗=0 𝑤𝑗𝑋𝑗 ∈ 𝑅𝑞
197    pub(crate) fn inv_ntt(&mut self) {
198        let mut m: usize = N;
199        let mut len: usize = 1;
200
201        while len < N {
202            let mut start: usize = 0;
203            while start < N {
204                m -= 1;
205                let z = (-1) * ZETAS[m];
206
207                // j = start;
208                // while j < start + len {
209                for j in start..start + len {
210                    // 𝑡 ← 𝑤𝑗
211                    let t: i32 = self[j];
212
213                    // 𝑤𝑗 ← (𝑡 + 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
214                    self[j] = t + self[j + len];
215
216                    // 𝑤𝑗+𝑙𝑒𝑛 ← (𝑡 − 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
217                    self[j + len] = t - self[j + len];
218
219                    // 𝑤𝑗+𝑙𝑒𝑛 ← (𝑧 ⋅ 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
220                    self[j + len] = montgomery_reduce(z as i64 * self[j + len] as i64);
221                }
222                start = start + 2 * len; // could be optimized to save the multiply-by-two since j finishes as `start + len`. That said 2* is just << 1, which is basically free.
223            }
224            len <<= 1;
225        }
226
227        // f = 256^-1 mod q
228        // const f: i64 = 8347681;
229        // bc-java uses this value rather than the one in FIPS 204
230        const f: i64 = 41978;
231        for j in 0..N {
232            // equiv. to the global constant N
233            self[j] = montgomery_reduce(f * self[j] as i64);
234        }
235    }
236
237
238    pub(crate) fn use_hint<const GAMMA2: i32>(
239        &mut self,
240        h: &Polynomial,
241    ) {
242        for i in 0..N {
243            self[i] = use_hint::<GAMMA2>(self[i], h[i]);
244        }
245    }
246}
247
248impl Secret for Polynomial {}
249
250impl Drop for Polynomial {
251    fn drop(&mut self) {
252        self.coeffs.fill(0i32);
253    }
254}
255
256impl Debug for Polynomial {
257    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
258        write!(f, "Polynomial (data masked)")
259    }
260}
261
262impl Display for Polynomial {
263    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
264        write!(f, "Polynomial (data masked)")
265    }
266}
267
268/// FIPS 204 Algorithm 49
269/// As described in FIPS 204 Appendix A, montgomery reduction allows for efficient computation
270/// of expressions of the form c = a * b (mod q).
271/// The output is not necessarily less than q in absolute value, but it is less than 2q in absolute value
272pub(crate) fn montgomery_reduce(a: i64) -> i32 {
273    debug_assert!(a > - ((q as i64) <<31) && a < ((q as i64) <<31));
274
275    // 2: 𝑡 ← ((𝑎 mod 2^32) ⋅ QINV) mod 2^32
276    let t: i32 = (a as i32).wrapping_mul(q_inv);
277
278    // 3: 𝑟 ← (𝑎 − 𝑡 ⋅ 𝑞)/2^32
279    ((a - ((t as i64) * (q as i64))) >> 32) as i32
280}
281
282
283pub(crate) fn conditional_add_q(a: i32) -> i32 {
284    a + ((a >> 31) & q)
285}
286
/// Pins the behavior of `conditional_add_q`. `q` is added only to negative inputs, so this is a
/// sign fix-up rather than a full reduction mod q. For example `-q - 1` becomes `-1`, while `q`
/// and `q + 1` pass through unchanged; those expected values are intentional.
#[test]
fn test_conditional_add_q() {
    let cases = [
        (-q - 2, -2),
        (-q - 1, -1),
        (-q, 0),
        (-q + 1, 1),
        (-1, q - 1),
        (0, 0),
        (1, 1),
        (q - 1, q - 1),
        (q, q),
        (q + 1, q + 1),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(conditional_add_q(input), expected);
    }
}
301
/// Precomputed twiddle factors ("zetas") for `Polynomial::ntt` and `Polynomial::inv_ntt`.
///
/// These are the FIPS 204 zetas `ζ^BitRev8(k) mod q` (with `ζ = 1753`) converted to the
/// Montgomery domain, i.e. multiplied by `2^32 mod q`. They are stored as centered (possibly
/// negative) representatives. For example, `ZETAS[1] = 25847 ≡ 4808194 · 2^32 (mod q)`, where
/// 4808194 is entry 1 of the FIPS 204 table. Keeping them in Montgomery form means
/// `montgomery_reduce(zeta * x)` multiplies `x` by the plain zeta.
///
/// Index 0 is never read. `ntt` walks indices 1..=255 upward, and `inv_ntt` walks 255..=1
/// downward using the negated values.
const ZETAS: [i32; 256] = [
    0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251,
    -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488,
    -1119584, 2619752, -2108549, -2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497,
    280005, 2706023, 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, -3861115,
    -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267, -1643818, 3505694,
    -3821735, 3507263, -2140649, -1600420, 3699596, 811944, 531354, 954230, 3881043, 3900724,
    -2556880, 2071892, -2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950,
    2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922,
    3412210, -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, -671102, -1228525,
    -22981, -1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944, 508951, 3097992,
    44288, -1100098, 904516, 3958618, -3724342, -8578, 1653064, -3249728, 2389356, -210977, 759969,
    -1316856, 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, 1285669,
    -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, 2091667, 3407706, 2316500,
    3817976, -3342478, 2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181, -3520352,
    -3759364, -1197226, -3193378, 900702, 1859098, 909542, 819034, 495491, -1613174, -43260,
    -522500, -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297,
    286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 2842341, 2691481, -2590150,
    1265009, 4055324, 1247620, 2486353, 1595974, -3767016, 1250494, 2635921, -3548272, -2994039,
    1869119, 1903435, -1050970, -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115,
    -1962642, -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, -542412,
    -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, -2013608, 2432395, 2454455,
    -164721, 1957272, 3369112, 185531, -1207385, -3183426, 162844, 1616392, 3014001, 810149,
    1652634, -3694233, -1799107, -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735,
    472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, -2939036,
    -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, -554416, 3919660, -48306,
    -1362209, 3937738, 1400424, -846154, 1976782,
];