Skip to main content

bouncycastle_mldsa/
polynomial.rs

1//! Represents a polynomial over the ML-DSA ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6use bouncycastle_core::traits::Secret;
7use crate::mldsa::{N, MLDSA44_POLY_W1_PACKED_LEN, MLDSA65_POLY_W1_PACKED_LEN, q};
8use crate::aux_functions::{conditional_add_q, high_bits, low_bits, make_hint, montgomery_reduce, ZETAS};
9
10/// A polynomial over the ML-DSA ring.
11/// Dev note: this doesn't strictly need to be pub ... ie there's no good reason for a caller to use this class directly,
12/// but in order to test the Debug and Display traits, you need STD, so those can't be tested from inline tests in this file
13/// and the real unit tests are in a different crate, so here we are.
14#[derive(Clone)]
15pub struct Polynomial{ pub(crate) coeffs: [i32; N] }
16
17/// Convenience function to avoid ".0" all over the place.
18impl Index<usize> for Polynomial {
19    type Output = i32;
20
21    fn index(&self, index: usize) -> &Self::Output {
22        &self.coeffs[index]
23    }
24}
25/// Convenience function to avoid ".0" all over the place.
26impl IndexMut<usize> for Polynomial {
27    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
28        &mut self.coeffs[index]
29    }
30}
31
32impl Polynomial {
33    /// Create a new polynomial with all coefficients set to zero.
34    pub const fn new() -> Self {
35        Self{ coeffs: [0i32; N] }
36    }
37
38    pub(crate) fn conditional_add_q(&mut self) {
39        for x in self.coeffs.iter_mut() {
40            *x = conditional_add_q(*x);
41        }
42    }
43    
44    pub(crate) fn reduce(&mut self) {
45        for i in 0..N {
46            self[i] = montgomery_reduce(self[i] as i64);
47        }
48    }
49
50    /// Algorithm 44 AddNTT(𝑎, 𝑏)̂
51    /// Computes the sum a + 𝑏 of two elements 𝑎, 𝑏 ∈ 𝑇𝑞.
52    /// Note: result could be up to 2q.
53    pub(crate) fn add_ntt(&mut self, w: &Self) {
54        for i in 0..N {
55            self[i] += w[i];
56        }
57    }
58
59    pub(crate) fn sub(&mut self, w: &Self) {
60        for i in 0..N {
61            self[i] -= w[i];
62        }
63    }
64
65    pub(crate) fn high_bits<const GAMMA2: i32>(&self) -> Self {
66        let mut w = Self::new();
67        for i in 0..N {
68            w[i] = high_bits::<GAMMA2>(self[i]);
69        }
70
71        w
72    }
73
74    pub(crate) fn low_bits<const GAMMA2: i32>(&self) -> Self {
75        let mut w = Self::new();
76        for i in 0..N {
77            w[i] = low_bits::<GAMMA2>(self[i]);
78        }
79
80        w
81    }
82
83    pub(crate) fn check_norm<const BOUND: i32>(&self) -> bool {
84        // Fine that this is not constant-time (returns true early) because it is used in a rejection loop.
85        // IE the early quit here leads to rejection and continuing to the top of the rejection loop, or failing the signature validation.
86        // So the i32 that we just checked in a non-constant-time manner is about to get thrown away.
87        
88        // Note: this formulation of the check_norm function usually requires this bounds check
89        //  if bound > (q - 1) / 8 {
90        //     return true;
91        //  }
92        // but since BOUND is a constant here, we'll just do a debug_assert to make sure the value is what we expect.
93        debug_assert!(BOUND <= (q - 1) / 8);
94        
95        
96        let mut t: i32;
97        for x in self.coeffs.iter() {
98            t = *x >> 31;
99            t = *x - (t & (2 * *x));
100
101            if t >= BOUND {
102                return true;
103            }
104        }
105        false
106    }
107
108    pub(crate) fn shift_left<const d: i32>(&mut self) {
109        for x in self.coeffs.iter_mut() {
110            *x <<= d;
111        }
112    }
113
114    /// Creates the hint vector, and also returns its hamming weight (ie the number of 1's).
115    pub(crate) fn make_hint<const GAMMA2: i32>(&self, r: &Self) -> (Self, i32) {
116        let mut out = Polynomial::new();
117        let mut count = 0i32;
118        for i in 0..N {
119            let x = make_hint::<GAMMA2>(self[i], r[i]);
120            out[i] = x;
121            count += x;
122        }
123
124        (out, count)
125    }
126
127    pub(crate) fn w1_encode<const POLY_W1_PACKED_LEN: usize>(&self) -> [u8; POLY_W1_PACKED_LEN] {
128        let mut r = [0u8; POLY_W1_PACKED_LEN];
129
130        match POLY_W1_PACKED_LEN {
131            MLDSA44_POLY_W1_PACKED_LEN => {
132                for i in 0..N/4 {
133                    r[3 * i] =
134                        ((self[4 * i]) as u8) | ((self[4 * i + 1] << 6) as u8);
135                    r[3 * i + 1] =
136                        ((self[4 * i + 1] >> 2) as u8) | ((self[4 * i + 2] << 4) as u8);
137                    r[3 * i + 2] =
138                        ((self[4 * i + 2] >> 4) as u8) | ((self[4 * i + 3] << 2) as u8);
139                }
140            },
141            // ML-DSA65 and 87 share a POLY_W1_PACKED_LEN value
142            MLDSA65_POLY_W1_PACKED_LEN => {
143                for i in 0..N/2 {
144                    r[i] = ((self[2 * i]) | (self[2 * i + 1] << 4)) as u8;
145                }
146            },
147            _ => { unreachable!() }
148        }
149
150        r
151    }
152
153    /// Algorithm 41 NTT(𝑤)
154    /// Computes the NTT.
155    /// Input: Polynomial 𝑤(𝑋)
156    /// 𝑗=0 𝑤𝑗𝑋𝑗 ∈ 𝑅𝑞.
157    /// Output: 𝑤_hat = (𝑤_hat\[0], ..., 𝑤_hat\[255]) ∈ 𝑇𝑞.
158    ///
159    /// Note: by convention, variables holding the output of the NTT function should be named "_ntt"
160    /// to indicate that they are in the NTT domain (sometimes called the frequency domain), not the natural domain.
161    /// I considered using the rust type system to enforce this, but it seemed like overkill, cause that's what
162    /// NIST test vectors are for.
163    ///
164    /// Design choice: don't do the NTT in-place, but copy data to a new array.
165    /// This uses slightly more memory and requires a copy, but makes the code easier to read
166    /// and less likely to contain a bug. But this optimization could be considered in the future.
167    pub(crate) fn ntt(&mut self) {
168        let mut m: usize = 0;
169        let mut len: usize = 128;
170
171        while len >= 1 {
172            let mut start: usize = 0;
173            while start < N {
174                m += 1;
175                let z: i32 = ZETAS[m];
176
177                for j in start..start + len {
178                    let t = montgomery_reduce(z as i64 * self[j + len] as i64);
179                    self[j + len] = self[j] - t; // '% q' not strictly needed cause it gets reduced at some point later. Removing it gave +5% in benchmarking
180                    self[j] = self[j] + t; // '% q' not strictly needed
181                }
182                start = start + 2 * len;
183            }
184            len >>= 1;
185        }
186    }
187
188    /// Algorithm 42 NTT−1(𝑤)̂
189    /// Computes the inverse of the NTT.
190    /// Input: ̂̂ ̂ 𝑤 = (𝑤\[0], … , 𝑤\[255]) ∈ 𝑇𝑞.
191    /// Output: Polynomial 𝑤(𝑋) = ∑255
192    /// 𝑗=0 𝑤𝑗𝑋𝑗 ∈ 𝑅𝑞
193    pub(crate) fn inv_ntt(&mut self) {
194        let mut m: usize = N;
195        let mut len: usize = 1;
196
197        while len < N {
198            let mut start: usize = 0;
199            while start < N {
200                m -= 1;
201                let z = (-1) * ZETAS[m];
202
203                // j = start;
204                // while j < start + len {
205                for j in start..start + len {
206                    // 𝑡 ← 𝑤𝑗
207                    let t: i32 = self[j];
208
209                    // 𝑤𝑗 ← (𝑡 + 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
210                    self[j] = t + self[j + len];
211
212                    // 𝑤𝑗+𝑙𝑒𝑛 ← (𝑡 − 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
213                    self[j + len] = t - self[j + len];
214
215                    // 𝑤𝑗+𝑙𝑒𝑛 ← (𝑧 ⋅ 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
216                    self[j + len] = montgomery_reduce(z as i64 * self[j + len] as i64);
217                }
218                start = start + 2 * len; // could be optimized to save the multiply-by-two since j finishes as `start + len`. That said 2* is just << 1, which is basically free.
219            }
220            len <<= 1;
221        }
222
223        // f = 256^-1 mod q
224        // const f: i64 = 8347681;
225        // bc-java uses this value rather than the one in FIPS 204
226        const f: i64 = 41978;
227        for j in 0..N {
228            // equiv. to the global constant N
229            self[j] = montgomery_reduce(f * self[j] as i64);
230        }
231    }
232}
233
234impl Secret for Polynomial {}
235
236impl Drop for Polynomial {
237    fn drop(&mut self) {
238        self.coeffs.fill(0i32);
239    }
240}
241
242impl Debug for Polynomial {
243    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244        write!(f, "Polynomial (data masked)")
245    }
246}
247
248impl Display for Polynomial {
249    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
250        write!(f, "Polynomial (data masked)")
251    }
252}