Skip to main content

bouncycastle_mlkem/
aux_functions.rs

//! Implements auxiliary functions for ML-KEM as defined in Section 4 of FIPS 203.
2
3use crate::Matrix;
4use crate::matrix::Vector;
5use crate::mlkem::{N, q, q_inv};
6use crate::polynomial::Polynomial;
7use bouncycastle_core::traits::XOF;
8use bouncycastle_sha3::{SHAKE128, SHAKE256};
9
10pub(crate) fn expandA<const k: usize>(rho: &[u8; 32]) -> Matrix<k, k> {
11    let mut A_hat = Matrix::<k, k>::new();
12    for i in 0..k {
13        // 5: for (𝑗 ← 0; 𝑗 < π‘˜; 𝑗++)
14        for j in 0..k {
15            // 6: 𝐀[𝑖, 𝑗] ← SampleNTT(πœŒβ€–π‘—β€–π‘–)
16            //  β–· 𝑗 and 𝑖 are bytes 33 and 34 of the input
17            A_hat[i][j] = sample_ntt(rho, &[j as u8, i as u8]);
18        }
19    }
20
21    A_hat
22}
23
24/// Algorithm 5 ByteEncode_d(𝐹)
25/// Encodes an array of 𝑑-bit integers into a byte array for 1 ≀ 𝑑 ≀ 12.
26/// Input: integer array 𝐹 ∈ β„€_M^256, where π‘š = 2^𝑑 if 𝑑 < 12, and π‘š = π‘ž if 𝑑 = 12.
27/// Output: byte array 𝐡 ∈ 𝔹32𝑑.
28/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
29pub fn byte_encode<const d: usize, const PACK_LEN: usize>(F: &Polynomial) -> [u8; PACK_LEN] {
30    debug_assert_eq!(PACK_LEN, 32 * d);
31
32    let mut B = [0u8; PACK_LEN];
33
34    for i in 0..N {
35        let mut alpha = F[i];
36
37        // For efficiency, the library is happy to work with values outside the range [0..q],
38        // but we need to reduce it for the canonical encoding.
39        alpha = barrett_reduce(alpha);
40
41        for j in 0..d {
42            // alpha % 2, but without using % for constant-time reasons
43            let tmp = (alpha & 1) as u8;
44
45            // 4: 𝑏[𝑖⋅𝑑 + 𝑗] ← π‘Ž mod 2
46            //  constant-time note: yes, % is not constant-time,
47            //   but all of the values in (i*d + j) % 8 are loop indices and not part of the secret key.
48            B[(i * d + j) / 8] |= tmp << ((i * d + j) % 8);
49
50            // 5: π‘Ž ← (π‘Ž βˆ’ 𝑏[𝑖⋅𝑑 + 𝑗])/2
51            //   β–· note π‘Ž βˆ’ 𝑏[𝑖⋅𝑑 + 𝑗] is always even
52            //
53            // Deviation from the FIPS:
54            //   the direct translation to rust would be:
55            //     alpha = (alpha - tmp as i16) >> 1;
56            //   but since 𝑏[𝑖⋅𝑑 + 𝑗] is a single bit, and π‘Ž βˆ’ 𝑏[𝑖⋅𝑑 + 𝑗] is always even,
57            //   and we're about to shift off the last bit anyway, this is literally equivalent to "alpha >> 1".
58            alpha >>= 1;
59        }
60    }
61
62    B
63}
64
65/// Algorithm 6 ByteDecode_d(𝐡)
66/// Decodes a byte array into an array of 𝑑-bit integers for 1 ≀ 𝑑 ≀ 12.
67/// Input: byte array 𝐡 ∈ 𝔹32𝑑 .
68/// Output: integer array 𝐹 ∈ β„€256 , where π‘š = 2𝑑 if 𝑑 < 12 and π‘š = π‘ž if 𝑑 = 12.
69/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
70pub fn byte_decode<const d: usize, const PACK_LEN: usize>(B: &[u8; PACK_LEN]) -> Polynomial {
71    debug_assert_eq!(PACK_LEN, 32 * d);
72
73    let mut F = Polynomial::new();
74
75    for i in 0..N {
76        // 3: F[i] = SUM_j=0..d-1{ 𝑏[𝑖 β‹… 𝑑 + 𝑗] β‹… 2𝑗 } mod m
77        for j in 0..d {
78            // select the next bit, according to bitcount, then shift it up by j
79            F[i] |= (((B[(i * d + j) / 8] >> (i * d + j) % 8) & 1) as i16) << j; // there's supposed to be a `mod m` here, but that shouldn't matter; we'll check it below anyway.
80        }
81    }
82
83    F
84}
85
86/// Algorithm 7 SampleNTT(𝐡)
87/// Takes a 32-byte seed and two indices as input and outputs a pseudorandom element of π‘‡π‘ž.
88/// Input: byte array 𝐡 ∈ 𝔹34 . β–· a 32-byte seed along with two indices
89/// Output: array π‘Ž_hat ∈ β„€256 β–· the coefficients of the NTT of a polynomial
90/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
91pub fn sample_ntt(rho: &[u8; 32], nonce: &[u8; 2]) -> Polynomial {
92    let mut a_hat = Polynomial::new();
93
94    // 1: ctx ← XOF.Init()
95    // 2: ctx ← XOF.Absorb(ctx, 𝐡) β–· input the given byte array into XOF
96    let mut xof = SHAKE128::new();
97    xof.absorb(rho);
98    xof.absorb(nonce);
99
100    // 3: 𝑗 ← 0
101    let mut j = 0usize;
102
103    // SHAKE is fairly inefficient if you just squeeze 3 bytes at a time, so we'll do a block.
104    // size doesn't really matter, so long as it's a multiple of 3.
105    // 288 seemed to be the sweet spot from playing with benchmarks
106    // It's probably around the average rejection rate, and 288 is a multiple of both 3 (required for this alg)
107    // and 8 (efficient for SHAKE).
108    let mut C = [0u8; 288];
109    xof.squeeze_out(&mut C);
110    let mut idx: usize = 0;
111
112    // 4: while 𝑗 < 256 do
113    while j < N {
114        // 5: (ctx, 𝐢) ← XOF.Squeeze(ctx, 3)
115        //   β–· get a fresh 3-byte array 𝐢 from XOF
116        if idx == C.len() {
117            xof.squeeze_out(&mut C);
118            idx = 0;
119        }
120
121        // 6: 𝑑1 ← 𝐢[0] + 256 β‹… (𝐢[1] mod 16)
122        //  β–· 0 ≀ 𝑑1 < 2^12
123        let d1: i16 = (C[idx + 0] as i16) | ((C[idx + 1] as i32) << 8) as i16 & 0xFFF;
124        debug_assert!(d1 < 2 << 12);
125
126        // 7: 𝑑2 ← ⌊𝐢[1]/16βŒ‹ + 16 β‹… 𝐢[2]
127        //  β–· 0 ≀ 𝑑2 < 2^12
128        let d2: i16 = ((C[idx + 1] as i16) >> 4) | ((C[idx + 2] as i32) << 4) as i16 & 0xFFF;
129        debug_assert!(d2 < 2 << 12);
130
131        // 8: if 𝑑1 < π‘ž then
132        // 9:   π‘Ž_hat[𝑗] ← 𝑑1 Μ‚
133        //         β–· π‘Ž_hat ∈ β„€256
134        // 10:  𝑗 ← 𝑗 + 1
135        // 11: end if
136        if d1 < q {
137            a_hat[j] = d1;
138            j += 1;
139        }
140
141        // 12: if 𝑑2 < π‘ž and 𝑗 < 256 then
142        // 13:  π‘Ž[𝑗] ← 𝑑2
143        // 14:  𝑗 ← 𝑗 + 1
144        // 15: end if
145        if d2 < q && j < N {
146            a_hat[j] = d2;
147            j += 1;
148        }
149
150        idx += 3;
151    }
152
153    a_hat
154}
155
156/// Algorithm 8 SamplePolyCBD (𝐡)πœ‚
157/// Takes a seed as input and outputs a pseudorandom sample from the distribution Dπœ‚(π‘…π‘ž).
158/// Input: byte array 𝐡 ∈ 𝔹64πœ‚.
159/// Output: array 𝑓 ∈ β„€256  β–· the coefficients of the sampled polynomial
160/// Note: this is exposed publicly only for testing purposes and there is no good reason to use it in production code.
161pub fn sample_poly_cbd<const eta: i16>(bytes: &[u8]) -> Polynomial {
162    debug_assert_eq!(bytes.len(), 64 * eta as usize);
163
164    let mut f = Polynomial::new();
165
166    match eta {
167        2 => {
168            for i in 0..N / 8 {
169                let t = u32::from_le_bytes(bytes[4 * i..4 * i + 4].try_into().unwrap());
170                let mut d = t & 0x55555555;
171                d += (t >> 1) & 0x55555555;
172                for j in 0..8usize {
173                    let a = ((d >> (4 * j)) & 0x3) as i16;
174                    let b = ((d >> (4 * j + eta as usize)) & 0x3) as i16;
175                    f[8 * i + j] = a - b;
176
177                    // β–· 0 ≀ 𝑓[𝑖] ≀ πœ‚ or π‘ž βˆ’ πœ‚ ≀ 𝑓[𝑖] ≀ π‘ž βˆ’ 1
178                    //  this version is in [-eta, eta] instead of [0..eta] \U [q-eta..q-1]
179                    debug_assert!(-eta <= f[8 * i + j] && f[8 * i + j] <= eta);
180                }
181            }
182        }
183        3 => {
184            for i in 0..N / 4 {
185                let t = little_endian_to_u24(bytes, 3 * i);
186                let mut d = t & 0x00249249;
187                d += (t >> 1) & 0x00249249;
188                d += (t >> 2) & 0x00249249;
189                for j in 0..4usize {
190                    let a = ((d >> (6 * j)) & 0x7) as i16;
191                    let b = ((d >> (6 * j + eta as usize)) & 0x7) as i16;
192                    f[4 * i + j] = a - b;
193
194                    // β–· 0 ≀ 𝑓[𝑖] ≀ πœ‚ or π‘ž βˆ’ πœ‚ ≀ 𝑓[𝑖] ≀ π‘ž βˆ’ 1
195                    //  this version is in [-eta, eta] instead of [0..eta] \U [q-eta..q-1]
196                    debug_assert!(-eta <= f[4 * i + j] && f[4 * i + j] <= eta);
197                }
198            }
199        }
200        _ => unreachable!("Wrong Eta"),
201    }
202
203    f
204}
205
206/// SamplePolyCBDπœ‚1(PRFπœ‚1 (𝜎, 𝑁 ))
207/// Performs both the PRF and SamplePolyCBD steps
208pub(crate) fn sample_poly_CBD<const eta: i16>(b: &[u8; 32], n: u8) -> Polynomial {
209    // Alg 13: 9: 𝐬[𝑖] ← SamplePolyCBDπœ‚1(PRFπœ‚1 (𝜎, 𝑁 ))
210    //  β–· 𝐬[𝑖] ∈ β„€256 sampled from CBD
211    match eta {
212        2 => {
213            let buf = {
214                let mut xof = SHAKE256::new();
215                xof.absorb(b);
216                xof.absorb(&n.to_le_bytes());
217
218                let mut buf = [0u8; 2 * 64];
219                xof.squeeze_out(&mut buf);
220                buf
221            };
222
223            sample_poly_cbd::<eta>(&buf)
224        }
225        3 => {
226            let buf = {
227                let mut xof = SHAKE256::new();
228                xof.absorb(b);
229                xof.absorb(&n.to_le_bytes());
230                let mut buf = [0u8; 3 * 64];
231                xof.squeeze_out(&mut buf);
232                buf
233            };
234
235            sample_poly_cbd::<eta>(&buf)
236        }
237        _ => unreachable!(),
238    }
239}
240
241/// Internal helper for keygen since both s_hat and e_hat have identical sampling code
242pub(crate) fn sample_vector_CBD<const k: usize, const eta: i16>(
243    b: &[u8; 32],
244    mut n: u8,
245) -> Vector<k> {
246    let mut v = Vector::<k>::new();
247
248    for i in 0..k {
249        v[i] = sample_poly_CBD::<eta>(b, n);
250
251        // Alg 13: 10: 𝑁 ← 𝑁 + 1
252        n += 1;
253    }
254
255    v
256}
257
/// Reads the three bytes `bs[off..off + 3]` as a little-endian 24-bit integer.
///
/// Panics if fewer than three bytes are available at `off`.
fn little_endian_to_u24(bs: &[u8], off: usize) -> u32 {
    let low = u32::from(bs[off]);
    let mid = u32::from(bs[off + 1]);
    let high = u32::from(bs[off + 2]);

    low | (mid << 8) | (high << 16)
}
263
/// Table of 128 twiddle factors for the forward NTT, stored as non-negative `i16` values.
///
/// NOTE(review): assuming q = 3329, the first entry 2285 equals 2^16 mod q, which suggests
/// the entries are kept in Montgomery form (pre-multiplied by R = 2^16) so they can be fed
/// straight into `mul_mont` -- confirm against the NTT code that consumes this table.
pub(crate) const ZETAS: [i16; 128] = [
    2285, 2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468, 573,
    2004, 264, 383, 2500, 1458, 1727, 3199, 2648, 1017, 732, 608, 1787, 411, 3124, 1758, 1223, 652,
    2777, 1015, 2036, 1491, 3047, 1785, 516, 3321, 3009, 2663, 1711, 2167, 126, 1469, 2476, 3239,
    3058, 830, 107, 1908, 3082, 2378, 2931, 961, 1821, 2604, 448, 2264, 677, 2054, 2226, 430, 555,
    843, 2078, 871, 1550, 105, 422, 587, 177, 3094, 3038, 2869, 1574, 1653, 3083, 778, 1159, 3182,
    2552, 1483, 2727, 1119, 1739, 644, 2457, 349, 418, 329, 3173, 3254, 817, 1097, 603, 610, 1322,
    2044, 1864, 384, 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, 2144, 1799, 2051, 794, 1819,
    2475, 2459, 478, 3221, 3021, 996, 991, 958, 1869, 1522, 1628,
];
274
/// Table of 128 twiddle factors for the inverse NTT, stored as non-negative `i16` values.
///
/// NOTE(review): presumably the counterpart of `ZETAS` in the same (Montgomery) scaling,
/// with the final entry doubling as the scaling factor applied at the end of the inverse
/// transform -- confirm against the inverse-NTT code that consumes this table.
pub(crate) const ZETAS_INV: [i16; 128] = [
    1701, 1807, 1460, 2371, 2338, 2333, 308, 108, 2851, 870, 854, 1510, 2535, 1278, 1530, 1185,
    1659, 1187, 3109, 874, 1335, 2111, 136, 1215, 2945, 1465, 1285, 2007, 2719, 2726, 2232, 2512,
    75, 156, 3000, 2911, 2980, 872, 2685, 1590, 2210, 602, 1846, 777, 147, 2170, 2551, 246, 1676,
    1755, 460, 291, 235, 3152, 2742, 2907, 3224, 1779, 2458, 1251, 2486, 2774, 2899, 1103, 1275,
    2652, 1065, 2881, 725, 1508, 2368, 398, 951, 247, 1421, 3222, 2499, 271, 90, 853, 1860, 3203,
    1162, 1618, 666, 320, 8, 2813, 1544, 282, 1838, 1293, 2314, 552, 2677, 2106, 1571, 205, 2918,
    1542, 2721, 2597, 2312, 681, 130, 1602, 1871, 829, 2946, 3065, 1325, 2756, 1861, 1474, 1202,
    2367, 3147, 1752, 2707, 171, 3127, 3042, 1907, 1836, 1517, 359, 758, 1441,
];
285
286pub(crate) fn mul_mont(a: i16, b: i16) -> i16 {
287    montgomery_reduce((a as i32) * (b as i32))
288}
289
290pub(crate) fn montgomery_reduce(a: i32) -> i16 {
291    let u = a.wrapping_mul(q_inv) as i16;
292    let mut t = (u as i32) * q as i32;
293    t = a - t;
294    t >>= 16;
295    t as i16
296}
297
298pub(crate) fn barrett_reduce(a: i16) -> i16 {
299    let v = (((1u32 << 26) + ((q / 2) as u32)) / (q as u32)) as i16;
300    let t = (((v as i32) * (a as i32)) >> 26) as i16;
301    a - (((t as i32) * q as i32) as i16)
302}
303
304pub(super) fn cond_sub_q(a: i16) -> i16 {
305    let tmp = a - q;
306    tmp + ((tmp >> 15) & q)
307}
308
309/// Multiplication of polynomials in Zq\[X]/(X^2-zeta)
310/// used for multiplication of elements in Rq in NTT domain
311///
312/// Borrowed from:
313/// https://github.com/pq-crystals/kyber/blob/main/ref/ntt.c#L139
314pub(crate) fn ntt_base_mult(
315    r: &mut [i16],
316    off: usize,
317    a0: i16,
318    a1: i16,
319    b0: i16,
320    b1: i16,
321    zeta: i16,
322) {
323    let mut out_val0 = mul_mont(a1, b1);
324    out_val0 = mul_mont(out_val0, zeta);
325    out_val0 += mul_mont(a0, b0);
326    r[off] = out_val0;
327
328    let mut out_val1 = mul_mont(a0, b1);
329    out_val1 += mul_mont(a1, b0);
330    r[off + 1] = out_val1;
331}
332
333pub(crate) fn pack_ciphertext<const k: usize, const CT_LEN: usize, const du: i16, const dv: i16>(
334    u: &Vector<k>,
335    v: &Polynomial,
336) -> [u8; CT_LEN] {
337    let mut out = [0u8; CT_LEN];
338
339    // each of the N i16's will take du bits, so a polynomial takes N * du bits, then we have k of them
340    let lim: usize = k * (N * (du as usize) / 8);
341
342    u.compress_pol_vec::<du>(&mut out[..lim]);
343    v.compress_poly::<dv>(&mut out[lim..]);
344    out
345}
346
347pub(crate) fn unpack_ciphertext_u<
348    const k: usize,
349    const CT_LEN: usize,
350    const du: i16,
351    const dv: i16,
352>(
353    c: &[u8; CT_LEN],
354) -> Vector<k> {
355    // each of the N i16's will take du bits, so a polynomial takes N * du bits, then we have k of them
356    let lim: usize = k * (N * (du as usize) / 8);
357
358    let u = Vector::<k>::decompress_pol_vec::<du>(&c[..lim]);
359
360    u
361}
362
363pub(crate) fn unpack_ciphertext_v<
364    const k: usize,
365    const CT_LEN: usize,
366    const du: i16,
367    const dv: i16,
368>(
369    c: &[u8; CT_LEN],
370) -> Polynomial {
371    // each of the N i16's will take du bits, so a polynomial takes N * du bits, then we have k of them
372    let lim: usize = k * (N * (du as usize) / 8);
373
374    let v = Polynomial::decompress_poly::<dv>(&c[lim..]);
375
376    v
377}