bouncycastle_mldsa/polynomial.rs
1//! Represents a polynomial over the ML-DSA ring.
2
3use core::fmt;
4use core::fmt::{Debug, Display, Formatter};
5use core::ops::{Index, IndexMut};
6use bouncycastle_core::traits::Secret;
7use crate::mldsa::{N, MLDSA44_POLY_W1_PACKED_LEN, MLDSA65_POLY_W1_PACKED_LEN, q};
8use crate::aux_functions::{conditional_add_q, high_bits, low_bits, make_hint, montgomery_reduce, ZETAS};
9
10/// A polynomial over the ML-DSA ring.
11/// Dev note: this doesn't strictly need to be pub ... ie there's no good reason for a caller to use this class directly,
12/// but in order to test the Debug and Display traits, you need STD, so those can't be tested from inline tests in this file
13/// and the real unit tests are in a different crate, so here we are.
14#[derive(Clone)]
15pub struct Polynomial{ pub(crate) coeffs: [i32; N] }
16
17/// Convenience function to avoid ".0" all over the place.
18impl Index<usize> for Polynomial {
19 type Output = i32;
20
21 fn index(&self, index: usize) -> &Self::Output {
22 &self.coeffs[index]
23 }
24}
25/// Convenience function to avoid ".0" all over the place.
26impl IndexMut<usize> for Polynomial {
27 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
28 &mut self.coeffs[index]
29 }
30}
31
32impl Polynomial {
33 /// Create a new polynomial with all coefficients set to zero.
34 pub const fn new() -> Self {
35 Self{ coeffs: [0i32; N] }
36 }
37
38 pub(crate) fn conditional_add_q(&mut self) {
39 for x in self.coeffs.iter_mut() {
40 *x = conditional_add_q(*x);
41 }
42 }
43
44 pub(crate) fn reduce(&mut self) {
45 for i in 0..N {
46 self[i] = montgomery_reduce(self[i] as i64);
47 }
48 }
49
50 /// Algorithm 44 AddNTT(𝑎, 𝑏)̂
51 /// Computes the sum a + 𝑏 of two elements 𝑎, 𝑏 ∈ 𝑇𝑞.
52 /// Note: result could be up to 2q.
53 pub(crate) fn add_ntt(&mut self, w: &Self) {
54 for i in 0..N {
55 self[i] += w[i];
56 }
57 }
58
59 pub(crate) fn sub(&mut self, w: &Self) {
60 for i in 0..N {
61 self[i] -= w[i];
62 }
63 }
64
65 pub(crate) fn high_bits<const GAMMA2: i32>(&self) -> Self {
66 let mut w = Self::new();
67 for i in 0..N {
68 w[i] = high_bits::<GAMMA2>(self[i]);
69 }
70
71 w
72 }
73
74 pub(crate) fn low_bits<const GAMMA2: i32>(&self) -> Self {
75 let mut w = Self::new();
76 for i in 0..N {
77 w[i] = low_bits::<GAMMA2>(self[i]);
78 }
79
80 w
81 }
82
83 pub(crate) fn check_norm<const BOUND: i32>(&self) -> bool {
84 // Fine that this is not constant-time (returns true early) because it is used in a rejection loop.
85 // IE the early quit here leads to rejection and continuing to the top of the rejection loop, or failing the signature validation.
86 // So the i32 that we just checked in a non-constant-time manner is about to get thrown away.
87
88 // Note: this formulation of the check_norm function usually requires this bounds check
89 // if bound > (q - 1) / 8 {
90 // return true;
91 // }
92 // but since BOUND is a constant here, we'll just do a debug_assert to make sure the value is what we expect.
93 debug_assert!(BOUND <= (q - 1) / 8);
94
95
96 let mut t: i32;
97 for x in self.coeffs.iter() {
98 t = *x >> 31;
99 t = *x - (t & (2 * *x));
100
101 if t >= BOUND {
102 return true;
103 }
104 }
105 false
106 }
107
108 pub(crate) fn shift_left<const d: i32>(&mut self) {
109 for x in self.coeffs.iter_mut() {
110 *x <<= d;
111 }
112 }
113
114 /// Creates the hint vector, and also returns its hamming weight (ie the number of 1's).
115 pub(crate) fn make_hint<const GAMMA2: i32>(&self, r: &Self) -> (Self, i32) {
116 let mut out = Polynomial::new();
117 let mut count = 0i32;
118 for i in 0..N {
119 let x = make_hint::<GAMMA2>(self[i], r[i]);
120 out[i] = x;
121 count += x;
122 }
123
124 (out, count)
125 }
126
127 pub(crate) fn w1_encode<const POLY_W1_PACKED_LEN: usize>(&self) -> [u8; POLY_W1_PACKED_LEN] {
128 let mut r = [0u8; POLY_W1_PACKED_LEN];
129
130 match POLY_W1_PACKED_LEN {
131 MLDSA44_POLY_W1_PACKED_LEN => {
132 for i in 0..N/4 {
133 r[3 * i] =
134 ((self[4 * i]) as u8) | ((self[4 * i + 1] << 6) as u8);
135 r[3 * i + 1] =
136 ((self[4 * i + 1] >> 2) as u8) | ((self[4 * i + 2] << 4) as u8);
137 r[3 * i + 2] =
138 ((self[4 * i + 2] >> 4) as u8) | ((self[4 * i + 3] << 2) as u8);
139 }
140 },
141 // ML-DSA65 and 87 share a POLY_W1_PACKED_LEN value
142 MLDSA65_POLY_W1_PACKED_LEN => {
143 for i in 0..N/2 {
144 r[i] = ((self[2 * i]) | (self[2 * i + 1] << 4)) as u8;
145 }
146 },
147 _ => { unreachable!() }
148 }
149
150 r
151 }
152
153 /// Algorithm 41 NTT(𝑤)
154 /// Computes the NTT.
155 /// Input: Polynomial 𝑤(𝑋)
156 /// 𝑗=0 𝑤𝑗𝑋𝑗 ∈ 𝑅𝑞.
157 /// Output: 𝑤_hat = (𝑤_hat\[0], ..., 𝑤_hat\[255]) ∈ 𝑇𝑞.
158 ///
159 /// Note: by convention, variables holding the output of the NTT function should be named "_ntt"
160 /// to indicate that they are in the NTT domain (sometimes called the frequency domain), not the natural domain.
161 /// I considered using the rust type system to enforce this, but it seemed like overkill, cause that's what
162 /// NIST test vectors are for.
163 ///
164 /// Design choice: don't do the NTT in-place, but copy data to a new array.
165 /// This uses slightly more memory and requires a copy, but makes the code easier to read
166 /// and less likely to contain a bug. But this optimization could be considered in the future.
167 pub(crate) fn ntt(&mut self) {
168 let mut m: usize = 0;
169 let mut len: usize = 128;
170
171 while len >= 1 {
172 let mut start: usize = 0;
173 while start < N {
174 m += 1;
175 let z: i32 = ZETAS[m];
176
177 for j in start..start + len {
178 let t = montgomery_reduce(z as i64 * self[j + len] as i64);
179 self[j + len] = self[j] - t; // '% q' not strictly needed cause it gets reduced at some point later. Removing it gave +5% in benchmarking
180 self[j] = self[j] + t; // '% q' not strictly needed
181 }
182 start = start + 2 * len;
183 }
184 len >>= 1;
185 }
186 }
187
188 /// Algorithm 42 NTT−1(𝑤)̂
189 /// Computes the inverse of the NTT.
190 /// Input: ̂̂ ̂ 𝑤 = (𝑤\[0], … , 𝑤\[255]) ∈ 𝑇𝑞.
191 /// Output: Polynomial 𝑤(𝑋) = ∑255
192 /// 𝑗=0 𝑤𝑗𝑋𝑗 ∈ 𝑅𝑞
193 pub(crate) fn inv_ntt(&mut self) {
194 let mut m: usize = N;
195 let mut len: usize = 1;
196
197 while len < N {
198 let mut start: usize = 0;
199 while start < N {
200 m -= 1;
201 let z = (-1) * ZETAS[m];
202
203 // j = start;
204 // while j < start + len {
205 for j in start..start + len {
206 // 𝑡 ← 𝑤𝑗
207 let t: i32 = self[j];
208
209 // 𝑤𝑗 ← (𝑡 + 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
210 self[j] = t + self[j + len];
211
212 // 𝑤𝑗+𝑙𝑒𝑛 ← (𝑡 − 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
213 self[j + len] = t - self[j + len];
214
215 // 𝑤𝑗+𝑙𝑒𝑛 ← (𝑧 ⋅ 𝑤𝑗+𝑙𝑒𝑛) mod 𝑞
216 self[j + len] = montgomery_reduce(z as i64 * self[j + len] as i64);
217 }
218 start = start + 2 * len; // could be optimized to save the multiply-by-two since j finishes as `start + len`. That said 2* is just << 1, which is basically free.
219 }
220 len <<= 1;
221 }
222
223 // f = 256^-1 mod q
224 // const f: i64 = 8347681;
225 // bc-java uses this value rather than the one in FIPS 204
226 const f: i64 = 41978;
227 for j in 0..N {
228 // equiv. to the global constant N
229 self[j] = montgomery_reduce(f * self[j] as i64);
230 }
231 }
232}
233
// Opts `Polynomial` into the `Secret` trait from `bouncycastle_core`. The impl body is empty,
// so it relies entirely on whatever the trait provides by default; the trait's exact contract
// is defined outside this file.
impl Secret for Polynomial {}
235
236impl Drop for Polynomial {
237 fn drop(&mut self) {
238 self.coeffs.fill(0i32);
239 }
240}
241
242impl Debug for Polynomial {
243 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244 write!(f, "Polynomial (data masked)")
245 }
246}
247
248impl Display for Polynomial {
249 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
250 write!(f, "Polynomial (data masked)")
251 }
252}