Skip to main content

bouncycastle_mlkem_lowmemory/
mlkem.rs

1//! There are no advanced features in this low memory crate that are not already documented in the standard \[bouncycastle_mlkem] crate.
2
3use crate::aux_functions::sample_poly_CBD;
4use crate::low_memory_helpers::{
5    compress_u_row, compute_A_hat_dot_y_hat, compute_t_hat_dot_y_hat_row, unpack_ciphertext_u_row,
6    unpack_ciphertext_v, unpack_t_hat_row,
7};
8use crate::mlkem_keys::{
9    MLKEM512PrivateKey, MLKEM512PublicKey, MLKEM768PrivateKey, MLKEM768PublicKey,
10    MLKEM1024PrivateKey, MLKEM1024PublicKey,
11};
12use crate::mlkem_keys::{MLKEMPrivateKeyInternalTrait, MLKEMPrivateKeyTrait};
13use crate::mlkem_keys::{MLKEMPublicKeyInternalTrait, MLKEMPublicKeyTrait};
14use crate::polynomial::Polynomial;
15use bouncycastle_core::errors::KEMError;
16use bouncycastle_core::key_material::{KeyMaterial, KeyMaterialTrait, KeyType};
17use bouncycastle_core::traits::{Algorithm, Hash, KEM, RNG, SecurityStrength, XOF};
18use bouncycastle_rng::HashDRBG_SHA512;
19use bouncycastle_sha3::{SHA3_256, SHA3_512, SHAKE256};
20use bouncycastle_utils::ct::{conditional_copy_bytes, ct_eq_bytes};
21use core::marker::PhantomData;
22
/*** Constants ***/

/// Canonical algorithm name of the ML-KEM-512 parameter set.
pub const ML_KEM_512_NAME: &str = "ML-KEM-512";
/// Canonical algorithm name of the ML-KEM-768 parameter set.
pub const ML_KEM_768_NAME: &str = "ML-KEM-768";
/// Canonical algorithm name of the ML-KEM-1024 parameter set.
pub const ML_KEM_1024_NAME: &str = "ML-KEM-1024";

// From FIPS 203 Table 2 and Table 3

// Constants that are the same for all parameter sets
/// Length of the \[u8] holding an ML-KEM seed value.
pub const MLKEM_SEED_LEN: usize = 64;
/// Length of the \[u8] holding an ML-KEM encaps random value, also sometimes called the message `m`
pub const MLKEM_RND_LEN: usize = 32;
/// Size in bytes of an ML-KEM shared secret key.
pub const MLKEM_SS_LEN: usize = 32;
/// Number of coefficients in one polynomial.
pub(crate) const N: usize = 256;
/// The ML-KEM prime modulus.
pub(crate) const q: i16 = 3329;
/// Multiplicative inverse of `q` modulo 2^16 (presumably consumed by the Montgomery reduction
/// in the polynomial arithmetic -- confirm against the `polynomial` module).
pub(crate) const q_inv: i32 = 62209;
/// The eta2 noise parameter; identical for all three parameter sets.
pub(crate) const ETA2: i16 = 2;
/// Length in bytes of one polynomial packed at 12 bits per coefficient.
pub(crate) const POLY_BYTES: usize = 384;

// Compile-time sanity checks on the shared constants.
const _: () = assert!(POLY_BYTES == 12 * N / 8);
const _: () = assert!(((q as i32) * q_inv) & 0xFFFF == 1);

/* ML-KEM-512 params */

/// Length of the \[u8] holding a ML-KEM-512 public key.
pub const MLKEM512_PK_LEN: usize = 800;
/// Length of the \[u8] holding a ML-KEM-512 seed-based private key.
pub const MLKEM512_SK_LEN: usize = MLKEM_SEED_LEN;
/// Length of the \[u8] holding a full ML-KEM-512 private key in the NIST encoding.
pub const MLKEM512_FULL_SK_LEN: usize = 1632;
/// Length of the \[u8] holding a ML-KEM-512 ciphertext.
pub const MLKEM512_CT_LEN: usize = 768;
pub(crate) const MLKEM512_k: usize = 2;
pub(crate) const MLKEM512_ETA1: i16 = 3;
pub(crate) const MLKEM512_DU: i16 = 10;
pub(crate) const MLKEM512_DV: i16 = 4;
/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
pub(crate) const MLKEM512_LAMBDA: i16 = 128;

// internal derived values
/// Length of the packed `t_hat` vector: `k` polynomials at 12 bits per coefficient.
pub(crate) const MLKEM512_T_PACKED_LEN: usize = 12 * MLKEM512_k * 32;

// Compile-time cross-checks of the hard-coded lengths against the size formulas
// (ek: 384k+32, dk: 768k+96, c: 32(du*k+dv)).
const _: () = assert!(MLKEM512_PK_LEN == POLY_BYTES * MLKEM512_k + 32);
const _: () = assert!(MLKEM512_FULL_SK_LEN == 2 * POLY_BYTES * MLKEM512_k + 96);
const _: () = assert!(
    MLKEM512_CT_LEN == 32 * ((MLKEM512_DU as usize) * MLKEM512_k + (MLKEM512_DV as usize))
);
const _: () = assert!(MLKEM512_T_PACKED_LEN == POLY_BYTES * MLKEM512_k);

/* ML-KEM-768 params */

/// Length of the \[u8] holding a ML-KEM-768 public key.
pub const MLKEM768_PK_LEN: usize = 1184;
/// Length of the \[u8] holding a ML-KEM-768 seed-based private key.
pub const MLKEM768_SK_LEN: usize = MLKEM_SEED_LEN;
/// Length of the \[u8] holding a full ML-KEM-768 private key in the NIST encoding.
pub const MLKEM768_FULL_SK_LEN: usize = 2400;
/// Length of the \[u8] holding a ML-KEM-768 ciphertext.
pub const MLKEM768_CT_LEN: usize = 1088;
pub(crate) const MLKEM768_k: usize = 3;
pub(crate) const MLKEM768_ETA1: i16 = 2;
pub(crate) const MLKEM768_DU: i16 = 10;
pub(crate) const MLKEM768_DV: i16 = 4;
/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
pub(crate) const MLKEM768_LAMBDA: i16 = 192;

// internal derived values
/// Length of the packed `t_hat` vector: `k` polynomials at 12 bits per coefficient.
pub(crate) const MLKEM768_T_PACKED_LEN: usize = 12 * MLKEM768_k * 32;

// Compile-time cross-checks of the hard-coded lengths against the size formulas.
const _: () = assert!(MLKEM768_PK_LEN == POLY_BYTES * MLKEM768_k + 32);
const _: () = assert!(MLKEM768_FULL_SK_LEN == 2 * POLY_BYTES * MLKEM768_k + 96);
const _: () = assert!(
    MLKEM768_CT_LEN == 32 * ((MLKEM768_DU as usize) * MLKEM768_k + (MLKEM768_DV as usize))
);
const _: () = assert!(MLKEM768_T_PACKED_LEN == POLY_BYTES * MLKEM768_k);

/* ML-KEM-1024 params */

/// Length of the \[u8] holding a ML-KEM-1024 public key.
pub const MLKEM1024_PK_LEN: usize = 1568;
/// Length of the \[u8] holding a ML-KEM-1024 seed-based private key.
pub const MLKEM1024_SK_LEN: usize = MLKEM_SEED_LEN;
/// Length of the \[u8] holding a full ML-KEM-1024 private key in the NIST encoding.
pub const MLKEM1024_FULL_SK_LEN: usize = 3168;
/// Length of the \[u8] holding a ML-KEM-1024 ciphertext.
pub const MLKEM1024_CT_LEN: usize = 1568;
pub(crate) const MLKEM1024_k: usize = 4;
pub(crate) const MLKEM1024_ETA1: i16 = 2;
pub(crate) const MLKEM1024_DU: i16 = 11;
pub(crate) const MLKEM1024_DV: i16 = 5;
/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
pub(crate) const MLKEM1024_LAMBDA: i16 = 256;

// internal derived values
/// Length of the packed `t_hat` vector: `k` polynomials at 12 bits per coefficient.
pub(crate) const MLKEM1024_T_PACKED_LEN: usize = 12 * MLKEM1024_k * 32;

// Compile-time cross-checks of the hard-coded lengths against the size formulas.
const _: () = assert!(MLKEM1024_PK_LEN == POLY_BYTES * MLKEM1024_k + 32);
const _: () = assert!(MLKEM1024_FULL_SK_LEN == 2 * POLY_BYTES * MLKEM1024_k + 96);
const _: () = assert!(
    MLKEM1024_CT_LEN == 32 * ((MLKEM1024_DU as usize) * MLKEM1024_k + (MLKEM1024_DV as usize))
);
const _: () = assert!(MLKEM1024_T_PACKED_LEN == POLY_BYTES * MLKEM1024_k);
106
// Typedefs just to make the algorithms look more like the FIPS 203 pseudocode.
// `G` is the hash used to derive (K, r) in encaps/decaps.
pub(crate) type G = SHA3_512;
// `H` is aliased here for parity with the FIPS naming; it is not called directly in this file.
pub(crate) type H = SHA3_256;
// `J` is the XOF used to derive the implicit-rejection key in decaps.
pub(crate) type J = SHAKE256;
111
112/*** Pub Types ***/
113
/// The ML-KEM-512 algorithm.
///
/// This is [MLKEM] instantiated with the ML-KEM-512 lengths, key types and
/// parameters (`k`, `eta1`, `du`, `dv`, lambda) declared in this module.
pub type MLKEM512 = MLKEM<
    MLKEM512_PK_LEN,
    MLKEM512_SK_LEN,
    MLKEM512_FULL_SK_LEN,
    MLKEM512_CT_LEN,
    MLKEM_SS_LEN,
    MLKEM512PublicKey,
    MLKEM512PrivateKey,
    MLKEM512_k,
    MLKEM512_ETA1,
    MLKEM512_DU,
    MLKEM512_DV,
    MLKEM512_LAMBDA,
    MLKEM512_T_PACKED_LEN,
>;

// Algorithm metadata for ML-KEM-512: its name and its maximum security strength (128 bits,
// matching MLKEM512_LAMBDA).
impl Algorithm for MLKEM512 {
    const ALG_NAME: &'static str = ML_KEM_512_NAME;
    const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_128bit;
}
135
/// The ML-KEM-768 algorithm.
///
/// This is [MLKEM] instantiated with the ML-KEM-768 lengths, key types and
/// parameters (`k`, `eta1`, `du`, `dv`, lambda) declared in this module.
pub type MLKEM768 = MLKEM<
    MLKEM768_PK_LEN,
    MLKEM768_SK_LEN,
    MLKEM768_FULL_SK_LEN,
    MLKEM768_CT_LEN,
    MLKEM_SS_LEN,
    MLKEM768PublicKey,
    MLKEM768PrivateKey,
    MLKEM768_k,
    MLKEM768_ETA1,
    MLKEM768_DU,
    MLKEM768_DV,
    MLKEM768_LAMBDA,
    MLKEM768_T_PACKED_LEN,
>;

// Algorithm metadata for ML-KEM-768: its name and its maximum security strength (192 bits,
// matching MLKEM768_LAMBDA).
impl Algorithm for MLKEM768 {
    const ALG_NAME: &'static str = ML_KEM_768_NAME;
    const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_192bit;
}
157
/// The ML-KEM-1024 algorithm.
///
/// This is [MLKEM] instantiated with the ML-KEM-1024 lengths, key types and
/// parameters (`k`, `eta1`, `du`, `dv`, lambda) declared in this module.
pub type MLKEM1024 = MLKEM<
    MLKEM1024_PK_LEN,
    MLKEM1024_SK_LEN,
    MLKEM1024_FULL_SK_LEN,
    MLKEM1024_CT_LEN,
    MLKEM_SS_LEN,
    MLKEM1024PublicKey,
    MLKEM1024PrivateKey,
    MLKEM1024_k,
    MLKEM1024_ETA1,
    MLKEM1024_DU,
    MLKEM1024_DV,
    MLKEM1024_LAMBDA,
    MLKEM1024_T_PACKED_LEN,
>;

// Algorithm metadata for ML-KEM-1024: its name and its maximum security strength (256 bits,
// matching MLKEM1024_LAMBDA).
impl Algorithm for MLKEM1024 {
    const ALG_NAME: &'static str = ML_KEM_1024_NAME;
    const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_256bit;
}
179
/// The core internal implementation of the ML-KEM algorithm.
/// This needs to be public for the compiler to be able to find it, but you shouldn't ever
/// need to use this directly. Please use the named public types.
///
/// Generic parameters, in declaration order:
/// * `PK_LEN` - byte length of an encoded public key.
/// * `SK_LEN` - byte length of the seed-based private key.
/// * `FULL_SK_LEN` - byte length of a full private key in the NIST encoding.
/// * `CT_LEN` - byte length of a ciphertext.
/// * `SS_LEN` - byte length of the shared secret.
/// * `PK` / `SK` - the public / private key types for the parameter set.
/// * `k`, `eta1`, `du`, `dv` - the FIPS 203 parameters of the same names.
/// * `LAMBDA` - security strength in bits that is stamped on derived shared secrets.
/// * `T_PACKED_LEN` - byte length of the packed `t_hat` vector.
pub struct MLKEM<
    const PK_LEN: usize,
    const SK_LEN: usize,
    const FULL_SK_LEN: usize,
    const CT_LEN: usize,
    const SS_LEN: usize,
    PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
        + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
    SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
        + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
    const k: usize,
    const eta1: i16,
    const du: i16,
    const dv: i16,
    const LAMBDA: i16,
    const T_PACKED_LEN: usize,
> {
    // Zero-sized marker: the struct stores no data, it only ties the `PK` and `SK` type
    // parameters to the type so that they count as used.
    _phantom: PhantomData<(PK, SK)>,
}
202
203impl<
204    const PK_LEN: usize,
205    const SK_LEN: usize,
206    const FULL_SK_LEN: usize,
207    const CT_LEN: usize,
208    const SS_LEN: usize,
209    PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
210        + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
211    SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
212        + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
213    const k: usize,
214    const eta1: i16,
215    const du: i16,
216    const dv: i16,
217    const LAMBDA: i16,
218    const T_PACKED_LEN: usize,
219>
220    MLKEM<
221        PK_LEN,
222        SK_LEN,
223        FULL_SK_LEN,
224        CT_LEN,
225        SS_LEN,
226        PK,
227        SK,
228        k,
229        eta1,
230        du,
231        dv,
232        LAMBDA,
233        T_PACKED_LEN,
234    >
235{
236    /// Should still be ok in FIPS mode
237    pub fn keygen_from_os_rng() -> Result<(PK, SK), KEMError> {
238        let mut seed = KeyMaterial::<64>::new();
239        HashDRBG_SHA512::new_from_os().fill_keymaterial_out(&mut seed)?;
240        // Self::keygen_internal(&seed)
241        Self::keygen_internal(&seed)
242    }
243    /// Performs the first step of key generation to transform the single provided seed into a set of internal intermediate seeds.
244    ///
245    /// Unlike other interfaces across the library that take an &impl KeyMaterial, this one
246    /// specifically takes a 64-byte [KeyMaterial512] and checks that it has [KeyType::Seed] and
247    /// the appropriate [SecurityStrength] for the requested ML-KEM parameter set.
248    /// If you happen to have your seed in a larger KeyMaterial, you'll have to copy it using
249    /// [KeyMaterial::from_key].
250    pub(crate) fn keygen_internal(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError> {
251        let sk = SK::from_keymaterial(seed)?;
252        let pk = sk.pk();
253        let pk = PK::new(pk.t_hat_packed, pk.rho); // stupid conversion, but it gets around these overly-generified rust types
254        Ok((pk, sk))
255    }
256
257    /// Algorithm 14 K-PKE.Encrypt(ekPKE, ๐‘š, ๐‘Ÿ)
258    /// Uses the encryption key to encrypt a plaintext message using the randomness ๐‘Ÿ.
259    /// Input: encryption key ekPKE โˆˆ ๐”น384๐‘˜+32 .
260    /// Input: message ๐‘š โˆˆ ๐”น32 .
261    /// Input: randomness ๐‘Ÿ โˆˆ ๐”น32 .
262    /// Output: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
263    fn pke_encrypt(
264        t_hat_packed: &[u8; T_PACKED_LEN],
265        rho: &[u8; 32],
266        m: [u8; 32],
267        r: &[u8; 32],
268    ) -> [u8; CT_LEN] {
269        let mut ct = [0u8; CT_LEN];
270
271        // 1: ๐‘ โ† 0
272        //  since the number of loops here is static; we can hard-code the N values rather than using a counter
273
274        // 2: ๐ญ โ† ByteDecode12(ekPKE[0 โˆถ 384๐‘˜])
275        // 3: ๐œŒ โ† ekPKE[384๐‘˜ โˆถ 384๐‘˜ + 32]
276        // not necessary here because ek is already decoded
277
278        // 19: ๐ฎ โ† NTTโˆ’1(๐€_hat^โŠบ โˆ˜ ๐ฒ_hat) + ๐ž1
279        // 22: ๐‘1 โ† ByteEncode_๐‘‘๐‘ข(Compress_๐‘‘๐‘ข(๐ฎ))
280
281        // Note: you need y_hat twice: once here at line 19, and again at line 21.
282        //  We'll just generate it twice to save the memory of holding on to it.
283        for i in 0..k {
284            let mut u_i = compute_A_hat_dot_y_hat::<k, eta1>(rho, &r, i);
285
286            let e1_i = sample_poly_CBD::<ETA2>(&r, (k + i) as u8);
287            u_i.add(&e1_i);
288            u_i.poly_reduce();
289
290            compress_u_row::<du, CT_LEN>(u_i, i, &mut ct);
291        }
292
293        // 17: ๐‘’2 โ† SamplePolyCBD_๐œ‚2(PRF๐œ‚2 (๐‘Ÿ, ๐‘))
294        // 20: ๐œ‡ โ† Decompress1(ByteDecode1(๐‘š))
295        // 21: ๐‘ฃ โ† NTTโˆ’1(๐ญ_hat_T โˆ˜ ๐ฒ_hat) + ๐‘’2 + ๐œ‡
296        // 23: ๐‘2 โ† ByteEncode_๐‘‘๐‘ฃ(Compress_๐‘‘๐‘ฃ(๐‘ฃ))
297        {
298            // compute v, which is a single polynomial, but requires iterating over the vectors t_hat and y_hat
299            let mut v = compute_t_hat_dot_y_hat_row::<k, eta1>(
300                &r,
301                &unpack_t_hat_row(t_hat_packed, 0),
302                /*row*/ 0,
303            );
304
305            for i in 1..k {
306                let v_i = compute_t_hat_dot_y_hat_row::<k, eta1>(
307                    &r,
308                    &unpack_t_hat_row(t_hat_packed, i),
309                    /*row*/ i,
310                );
311                v.add(&v_i);
312            }
313
314            // perform polynomial addition
315            let e2 = sample_poly_CBD::<ETA2>(&r, 2 * k as u8);
316            v.add(&e2);
317
318            let mu = Polynomial::from_msg(m);
319            v.add(&mu);
320
321            v.poly_reduce();
322
323            v.compress_poly::<dv>(&mut ct[CT_LEN - (N * (dv as usize) / 8)..]);
324        }
325
326        ct
327    }
328
329    /// Algorithm 17 ML-KEM.Encaps_internal(ek, ๐‘š)
330    /// Uses the encapsulation key and randomness to generate a key and an associated ciphertext.
331    /// Input: encapsulation key ek โˆˆ ๐”น384๐‘˜+32 .
332    /// Input: randomness ๐‘š โˆˆ ๐”น32 .
333    /// Output: shared secret key ๐พ โˆˆ ๐”น32 .
334    /// Output: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
335    ///
336    /// Unlike the more public function exposed by [KEM::encaps], this returns the shared secret as raw bytes
337    /// instead of wrapped in an appropriately-set [KeyMaterialTrait], so you're on your own for handling it properly.
338    ///
339    /// Note: this is an internal function that allows the caller to specify the encapsulation
340    /// randomness (which is the message `m` to be encrypted by the underlying PKE scheme).
341    /// This function should not be used directly unless you really have a
342    /// good reason. [KEM::encaps] should be used in 99.9% of cases.
343    /// The reason this is exposed publicly is: A) for unit testing that requires access
344    /// to the deterministically reproducible function, and B) for operational environments
345    /// that wish to provide randomness from their own source instead of the built-in RNG in bc-rust.
346    /// If you think you will be clever and invent some scheme that uses a deterministic KEM,
347    /// then you will almost certainly end up with security problems. Please don't do this.
348    pub fn encaps_internal(ek: &PK, m: [u8; 32]) -> ([u8; 32], [u8; CT_LEN]) {
349        debug_assert_eq!(CT_LEN, 32 * ((du as usize) * k + (dv as usize)));
350
351        // 1: (๐พ, ๐‘Ÿ) โ† G(๐‘šโ€–H(ek))
352        //  โ–ท derive shared secret key ๐พ and randomness ๐‘Ÿ
353        let K: [u8; MLKEM_SS_LEN];
354        let r: [u8; 32];
355        (K, r) = {
356            let mut g = G::new();
357            g.do_update(&m);
358            g.do_update(&ek.compute_hash());
359            let mut buf = [0u8; 64];
360            let bytes_written = g.do_final_out(&mut buf);
361            debug_assert_eq!(bytes_written, 64);
362
363            (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
364        };
365
366        // 2: ๐‘ โ† K-PKE.Encrypt(ek, ๐‘š, ๐‘Ÿ)
367        //  โ–ท encrypt ๐‘š using K-PKE with randomness ๐‘Ÿ
368        // deviation from FIPS:
369        let ct = Self::pke_encrypt(ek.t_hat_packed(), ek.rho(), m, &r);
370
371        (K, ct)
372    }
373
    /// Algorithm 15 K-PKE.Decrypt(dkPKE, c)
    /// Uses the decryption key to decrypt a ciphertext
    /// Input: decryption key dkPKE ∈ B^{384k}.
    /// Input: ciphertext c ∈ B^{32(du*k+dv)}.
    /// Output: message m ∈ B^{32}.
    ///
    /// The FIPS steps are interleaved row-by-row so that only one row of `s_hat` and one row
    /// of `u'` is alive at a time; the comments carry the FIPS line numbers.
    fn pke_decrypt(dk: &SK, ct: [u8; CT_LEN]) -> [u8; 32] {
        // 1: c1 ← c[0 : 32*du*k]
        // 3: u' ← Decompress_du(ByteDecode_du(c1))
        //   Done one row at a time inside the blocks below.

        // 5: s_hat ← ByteDecode12(dkPKE)
        //   Unnecessary here because we're gonna re-compute them row-by-row

        // first half of
        // 6: w ← v' − NTT^-1(s_hat^T ∘ NTT(u'))
        // v1 accumulates the sum over all rows i of NTT^-1(s_hat[i] ∘ NTT(u'[i])).
        let v1 = {
            // i = 0 case
            let mut v1 = {
                let mut s_hat_i = dk.compute_s_hat_row(0);
                {
                    // Inner scope so that u'[0] is dropped before the inverse NTT runs.
                    let mut u_prime_i = unpack_ciphertext_u_row::<du, CT_LEN>(0, &ct);
                    u_prime_i.ntt();
                    s_hat_i.base_mult_montgomery(&u_prime_i);
                }
                s_hat_i.inv_ntt();

                s_hat_i
            };

            // Remaining rows: same computation, added onto the accumulator.
            for i in 1..k {
                let mut s_hat_i = dk.compute_s_hat_row(i);
                {
                    let mut u_prime_i = unpack_ciphertext_u_row::<du, CT_LEN>(i, &ct);
                    u_prime_i.ntt();
                    s_hat_i.base_mult_montgomery(&u_prime_i);
                }
                s_hat_i.inv_ntt();
                v1.add(&s_hat_i);
            }

            v1
        };

        // 2: c2 ← c[32*du*k : 32(du*k + dv)]
        // 4: v' ← Decompress_dv(ByteDecode_dv(c2))
        let w = {
            // second half of
            // 6: w ← v' − NTT^-1(s_hat^T ∘ NTT(u'))
            let mut v_prime = unpack_ciphertext_v::<k, CT_LEN, du, dv>(&ct);

            v_prime.sub(&v1);
            v_prime.poly_reduce();

            v_prime // rename to w
        };

        // 7: m ← ByteEncode1(Compress1(w))
        //   ▷ decode plaintext m from polynomial w
        w.to_msg()
    }
433
    /// Algorithm 18 ML-KEM.Decaps_internal(dk, c)
    /// Uses the decapsulation key to produce a shared secret key from a ciphertext.
    /// Input: decapsulation key dk ∈ B^{768k+96}.
    /// Input: ciphertext c ∈ B^{32(du*k+dv)}.
    /// Output: shared secret key K ∈ B^{32}.
    fn decaps_internal(dk: &SK, c: [u8; CT_LEN]) -> [u8; MLKEM_SS_LEN] {
        // I have tried to keep this as clean as possible for correspondence with the FIPS,
        // but I have moved things around so that I can use unnamed scopes to limit how many
        // stack variables are alive at the same time.

        // 1: dkPKE ← dk[0 : 384k] ▷ extract (from KEM decaps key) the PKE decryption key
        // 2: ekPKE ← dk[384k : 768k + 32] ▷ extract PKE encryption key
        // 3: h ← dk[768k + 32 : 768k + 64] ▷ extract hash of PKE encryption key
        // 4: z ← dk[768k + 64 : 768k + 96] ▷ extract implicit rejection value
        // Nothing to do since dk is already decoded.

        // 5: m' ← K-PKE.Decrypt(dkPKE, c)
        let m_prime = Self::pke_decrypt(&dk, c);

        // Compute the trial shared secret key
        // 6: (K', r') ← G(m'‖h)
        //   Here h is recomputed from the public key via dk.pk().compute_hash().
        let K_prime: [u8; MLKEM_SS_LEN];
        let r_prime: [u8; 32];
        (K_prime, r_prime) = {
            let mut g = G::new();
            g.do_update(&m_prime);
            g.do_update(&dk.pk().compute_hash());
            let mut buf = [0u8; 64];
            let bytes_written = g.do_final_out(&mut buf);
            debug_assert_eq!(bytes_written, 64);

            // First half of the digest is K', second half is r'.
            (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
        };

        // 7: K_bar ← J(z‖c)
        //   Compute the rejection sampling key.
        //   Note to future optimizers: this needs to be computed outside of the if at line 9 below
        //   because if its computation is conditional on the Fujisaki-Okamoto check failing, then
        //   you'll have a timing difference between success and failure.

        let K_bar: [u8; MLKEM_SS_LEN];
        K_bar = {
            let mut j = J::new();
            j.absorb(dk.z());
            j.absorb(&c);
            let mut buf = [0u8; MLKEM_SS_LEN];
            let bytes_written = j.squeeze_out(&mut buf);
            debug_assert_eq!(bytes_written, MLKEM_SS_LEN);

            buf
        };

        // 8: c' ← K-PKE.Encrypt(ekPKE, m', r')
        //   ▷ re-encrypt using the derived randomness r'
        let c_prime = Self::pke_encrypt(&dk.t_hat_packed(), dk.rho(), m_prime, &r_prime);

        // 9: if c ≠ c' then
        // 10: K' ← K_bar
        //  ▷ if ciphertexts do not match, "implicitly reject"
        // NOTE(review): presumably this copies K_prime into K_out when the ciphertexts compare
        // equal and K_bar otherwise, without branching on the comparison result -- confirm the
        // argument order against bouncycastle_utils::ct.
        let mut K_out = [0u8; MLKEM_SS_LEN];
        conditional_copy_bytes(&K_prime, &K_bar, &mut K_out, ct_eq_bytes(&c, &c_prime));

        K_out
    }
498
499    /// Alternative initialization of the streaming signer where you have your private key
500    /// as a seed and you want to delay its expansion as late as possible for memory-usage reasons.
501    pub fn decaps_from_seed(
502        seed: &KeyMaterial<64>,
503        ct: &[u8],
504    ) -> Result<KeyMaterial<SS_LEN>, KEMError> {
505        let sk = SK::from_keymaterial(seed)?;
506
507        Self::decaps(&sk, ct)
508    }
509}
510
511impl<
512    const PK_LEN: usize,
513    const SK_LEN: usize,
514    const FULL_SK_LEN: usize,
515    const CT_LEN: usize,
516    const SS_LEN: usize,
517    PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
518        + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
519    SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
520        + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
521    const k: usize,
522    const eta1: i16,
523    const du: i16,
524    const dv: i16,
525    const LAMBDA: i16,
526    const T_PACKED_LEN: usize,
527>
528    MLKEMTrait<
529        PK_LEN,
530        SK_LEN,
531        FULL_SK_LEN,
532        CT_LEN,
533        SS_LEN,
534        PK,
535        SK,
536        k,
537        eta1,
538        du,
539        dv,
540        LAMBDA,
541        T_PACKED_LEN,
542    >
543    for MLKEM<
544        PK_LEN,
545        SK_LEN,
546        FULL_SK_LEN,
547        CT_LEN,
548        SS_LEN,
549        PK,
550        SK,
551        k,
552        eta1,
553        du,
554        dv,
555        LAMBDA,
556        T_PACKED_LEN,
557    >
558{
    /// Imports a secret key from a seed.
    ///
    /// Returns the (public key, private key) pair for `seed`; this is a thin wrapper that
    /// forwards to `keygen_internal` and propagates its error unchanged.
    fn keygen_from_seed(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError> {
        Self::keygen_internal(seed)
    }
563    /// Imports a secret key from both a seed and an encoded_sk.
564    ///
565    /// This is a convenience function to expand the key from seed and compare it against
566    /// the provided `encoded_sk` using a constant-time equality check.
567    /// If everything checks out, the secret key is returned fully populated with pk and seed.
568    /// If the provided key and derived key don't match, an error is returned.
569    fn keygen_from_seed_and_encoded(
570        seed: &KeyMaterial<64>,
571        encoded_sk: &[u8; SK_LEN],
572    ) -> Result<(PK, SK), KEMError> {
573        let (pk, sk) = Self::keygen_internal(seed)?;
574
575        let sk_from_bytes = SK::sk_decode(encoded_sk);
576
577        // MLKEMPrivateKey impls PartialEq with a constant-time equality check.
578        if sk != sk_from_bytes {
579            return Err(KEMError::KeyGenError("Encoded key does not match generated key"));
580        }
581
582        Ok((pk, sk))
583    }
584    /// Given a public key and a secret key, check that the public key matches the secret key.
585    /// This is a sanity check that the public key was generated correctly from the secret key.
586    ///
587    /// At the current time, this is only possible if `sk` either contains a public key (in which case
588    /// the two pk's are encoded and compared for byte equality), or if `sk` contains a seed
589    /// (in which case a keygen_from_seed is run and then the pk's compared).
590    ///
591    /// Returns either `()` or [KEMError::ConsistencyCheckFailed].
592    fn keypair_consistency_check(pk: &PK, sk: &SK) -> Result<(), KEMError> {
593        let derived_pk = sk.pk();
594        if derived_pk.compute_hash() == pk.compute_hash() {
595            Ok(())
596        } else {
597            Err(KEMError::ConsistencyCheckFailed(""))
598        }
599    }
600}
601
/// Trait for all three of the ML-KEM algorithm variants.
///
/// It adds the seed-based key generation and key pair consistency operations on top of the
/// generic [KEM] interface. The generic parameters mirror those of [MLKEM].
pub trait MLKEMTrait<
    const PK_LEN: usize,
    const SK_LEN: usize,
    const FULL_SK_LEN: usize,
    const CT_LEN: usize,
    const SS_LEN: usize,
    PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
        + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
    SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
        + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
    const k: usize,
    const eta: i16,
    const du: i16,
    const dv: i16,
    const LAMBDA: i16,
    const T_PACKED_LEN: usize,
>: Sized
{
    /// Imports a secret key from a seed.
    fn keygen_from_seed(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError>;
    /// Imports a secret key from both a seed and an encoded_sk.
    ///
    /// This is a convenience function to expand the key from seed and compare it against
    /// the provided `encoded_sk` using a constant-time equality check.
    /// If everything checks out, the secret key is returned fully populated with pk and seed.
    /// If the provided key and derived key don't match, an error is returned.
    fn keygen_from_seed_and_encoded(
        seed: &KeyMaterial<64>,
        encoded_sk: &[u8; SK_LEN],
    ) -> Result<(PK, SK), KEMError>;
    /// Given a public key and a secret key, check that the public key matches the secret key.
    /// This is a sanity check that the public key was generated correctly from the secret key.
    ///
    /// NOTE(review): the implementation for [MLKEM] in this file obtains the public key from
    /// `sk` via `sk.pk()` and compares the `compute_hash()` outputs of the two public keys.
    ///
    /// Returns either `()` or [KEMError::ConsistencyCheckFailed].
    fn keypair_consistency_check(pk: &PK, sk: &SK) -> Result<(), KEMError>;
}
643
644impl<
645    const PK_LEN: usize,
646    const SK_LEN: usize,
647    const FULL_SK_LEN: usize,
648    const CT_LEN: usize,
649    const SS_LEN: usize,
650    PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
651        + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
652    SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
653        + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
654    const k: usize,
655    const eta: i16,
656    const du: i16,
657    const dv: i16,
658    const LAMBDA: i16,
659    const T_PACKED_LEN: usize,
660> KEM<PK, SK, PK_LEN, SK_LEN, CT_LEN, SS_LEN>
661    for MLKEM<
662        PK_LEN,
663        SK_LEN,
664        FULL_SK_LEN,
665        CT_LEN,
666        SS_LEN,
667        PK,
668        SK,
669        k,
670        eta,
671        du,
672        dv,
673        LAMBDA,
674        T_PACKED_LEN,
675    >
676{
    /// Generates a fresh key pair.
    ///
    /// Thin wrapper that forwards to `keygen_from_os_rng` and propagates its error unchanged.
    fn keygen() -> Result<(PK, SK), KEMError> {
        Self::keygen_from_os_rng()
    }
681
682    fn encaps(pk: &PK) -> Result<(KeyMaterial<SS_LEN>, [u8; CT_LEN]), KEMError> {
683        let mut m = [0u8; 32];
684        HashDRBG_SHA512::new_from_os().next_bytes_out(&mut m)?;
685
686        let (ss_bytes, ct) = Self::encaps_internal(pk, m);
687
688        let mut ss_keymaterial =
689            KeyMaterial::<SS_LEN>::from_bytes_as_type(&ss_bytes, KeyType::BytesFullEntropy)?;
690        ss_keymaterial.allow_hazardous_operations();
691        ss_keymaterial.set_security_strength(SecurityStrength::from_bits(LAMBDA as usize))?;
692        ss_keymaterial.drop_hazardous_operations();
693
694        Ok((ss_keymaterial, ct))
695    }
696    /// Performs a decapsulation of the given ciphertext.
697    /// Returns the shared secret key.
698    /// The derived shared secret key is returned as a KeyMaterial with the SecurityStrength set to
699    /// the security level of the ML-KEM parameter set.
700    /// As ML-KEM is an implicitly-rejecting KEM, this returns an error only if the ciphertext is invalid (ie the wrong length).
701    fn decaps(sk: &SK, ct: &[u8]) -> Result<KeyMaterial<SS_LEN>, KEMError> {
702        if ct.len() != CT_LEN {
703            return Err(KEMError::LengthError("Invalid ciphertext length"));
704        }
705
706        let ss_bytes = Self::decaps_internal(sk, ct.try_into().unwrap());
707
708        let mut ss_keymaterial =
709            KeyMaterial::<SS_LEN>::from_bytes_as_type(&ss_bytes, KeyType::BytesFullEntropy)?;
710        ss_keymaterial.allow_hazardous_operations();
711        ss_keymaterial.set_security_strength(SecurityStrength::from_bits(LAMBDA as usize))?;
712        ss_keymaterial.drop_hazardous_operations();
713
714        Ok(ss_keymaterial)
715    }
716}