Skip to main content

bouncycastle_mlkem/
mlkem.rs

1//! This page documents advanced features of the Module Lattice Key-Encapsulation Algorithm (ML-KEM)
2//! available in this crate.
3//!
4//! # Pre-expanding the public key for repeated use
5//!
6//! Within the usual ML-KEM public key representation, the public matrix A is stored as a seed rho, which
//! means that both the ML-KEM.encaps() and ML-KEM.decaps() operations need to expand it into a full matrix
8//! before performing the matrix multiplication.
9//! We offer a version of the public and private key structs that pre-expand the public matrix for repeated use.
10//!
11//! When done as part of the keygen, expansion of the public matrix accounts for roughly 25% of the keygen time,
12//! however it accounts for roughly 35% / 60% / 80% of an encaps and 30% / 45% / 65% of a decaps
13//! for MLKEM512 / MLKEM768 / MLKEM1024.
14//!
15//! Most often, ML-KEM is used in an ephemeral mode where a key pair is generated, used for a single encaps
//! and decaps and then discarded. In this mode, it makes no performance difference whether the
17//! public matrix A is expanded as part of keygen or as part of encaps / decaps, but it does make both
18//! the public and private key take up more space in memory, so the default ML-KEM public and private key
19//! objects defer expansion until it is needed.
20//!
21//! However, in non-ephemeral uses where many encaps or decaps operations are performed against the same
22//! key pair in quick succession, there can be substantial performance improvements to pre-computing
23//! this and holding on to a larger key object.
24//! This is accomplished via constructing a [MLKEMPublicKeyExpanded] or [MLKEMPrivateKeyExpanded] object
25//! of the appropriate parameter set from the original key, and then using this with [MLKEM::encaps_for_expanded_key]
26//! or [MLKEM::decaps_with_expanded_key].
27//! Both [MLKEMPublicKeyExpanded] and [MLKEMPrivateKeyExpanded] implement the same traits
28//! and therefore behave the same as their non-expanded counterparts in most regards.
29//!
30//! ```rust
31//! use bouncycastle_mlkem::{MLKEM768, MLKEMTrait};
32//! use bouncycastle_mlkem::{MLKEM768PublicKeyExpanded, MLKEM768PrivateKeyExpanded};
33//! use bouncycastle_core::traits::KEM;
34//! use bouncycastle_core::errors::KEMError;
35//!
36//! let (pk, sk) = MLKEM768::keygen().unwrap();
37//!
//! // Pre-expanding the public key uses more memory, but has performance
39//! // improvements if doing multiple encapsulations for the same key
40//! let pk_expanded = MLKEM768PublicKeyExpanded::from(&pk);
41//! let (ss, ct) = MLKEM768::encaps_for_expanded_key(&pk_expanded).unwrap();
42//!
43//! // Pre-expand the private key, which uses more memory, but has performance
44//! // improvements if doing multiple decapsulations with the same key
45//! let sk_expanded = MLKEM768PrivateKeyExpanded::from(&sk);
46//! let ss1 = match MLKEM768::decaps_with_expanded_key(&sk_expanded, &ct) {
47//!     Err(KEMError) => panic!("Error decapsulating"),
48//!     Ok(ss) => ss,
49//! };
50//!
51//! assert_eq!(ss, ss1);
52//! ```
53//!
54//! # decaps_from_seed
55//!
56//! This mode is intended for users who want the simplicity of storing only the seed form of the private key.
//! This is merely a convenience function that calls [MLKEM::keygen_from_seed] before performing a decapsulation.
58//!
59//! Example usage:
60//!
61//! ```rust
62//! use bouncycastle_mlkem::{MLKEM768, MLKEMTrait};
63//! use bouncycastle_core::traits::KEM;
64//! use bouncycastle_core::errors::KEMError;
65//! use bouncycastle_core::key_material::{KeyMaterial512, KeyType};
66//! use bouncycastle_hex as hex;
67//!
68//! let seed = KeyMaterial512::from_bytes_as_type(
69//!     &hex::decode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f
70//!                   202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f").unwrap(),
71//!     KeyType::Seed,
72//! ).unwrap();
73//!
//! // For this demo, we only need to run keygen to get the public key
75//! let (pk, _sk) = MLKEM768::keygen_from_seed(&seed).unwrap();
76//!
77//! // Create the shared secret and ciphertext using the public key
78//! let (ss, ct) = MLKEM768::encaps(&pk).unwrap();
79//!
80//! // Recover the shared secret using the private key seed
81//! let ss1 = match MLKEM768::decaps_from_seed(&seed, &ct) {
82//!     Err(KEMError) => panic!("Error decapsulating"),
83//!     Ok(ss) => ss,
84//! };
85//!
86//! assert_eq!(ss, ss1);
87//! ```
88//!
//! While this is currently only supported when operating from a seed-based private key, something analogous
//! could be done that merges the sk_decode() and decaps() routines when working with the standardized
//! private key encoding (which is often called the "semi-expanded format" since the in-memory representation
//! is still larger).
//! Contact us if you need such a thing implemented.
//!
//! # Deterministic encapsulation
95//!
96//! This section pertains to [MLKEM::encaps_internal] which allows you to pass in the encapsulation randomness
97//! and thus obtain a deterministic encapsulation.
98//!
99//! The only good reasons for doing this are: A) testing if you need reproducible results, or
100//! B) if you want to use your own source of randomness, such as a hardware RNG, instead of the library's
101//! default RNG.
//! If you think you will invent some clever cryptographic scheme by making clever use of this parameter:
103//! don't; you will almost certainly end up with something completely insecure.
104//!
105//! ```rust
106//! use bouncycastle_mlkem::{MLKEM768, MLKEMTrait};
107//! use bouncycastle_core::traits::{KEM};
108//! use bouncycastle_core::errors::KEMError;
109//! use bouncycastle_core::key_material::KeyMaterialTrait;
110//!
111//! let (pk, sk) = MLKEM768::keygen().unwrap();
112//! // note: totally insecure and for demonstration purposes only.
113//! //       The message `m` needs to be sourced from a cryptographically-secure RNG.
114//! let m: [u8; 32] = [0; 32];
115//!
116//! // Create the shared secret and ciphertext using the public key and the random message `m`
117//! let (ss, ct) = MLKEM768::encaps_internal(&pk, None, m);
118//!
//! // Recover the shared secret using the private key
120//! let ss1 = match MLKEM768::decaps(&sk, &ct) {
121//!     Err(KEMError) => panic!("Error decapsulating"),
122//!     Ok(ss) => ss,
123//! };
124//!
125//! assert_eq!(ss, ss1.ref_to_bytes());
126//! ```
127
128use crate::MLKEMPublicKeyExpanded;
129use crate::aux_functions::{
130    expandA, pack_ciphertext, sample_poly_CBD, sample_vector_CBD, unpack_ciphertext_u,
131    unpack_ciphertext_v,
132};
133use crate::matrix::{Matrix, Vector};
134use crate::mlkem_keys::{
135    MLKEM512PrivateKey, MLKEM512PublicKey, MLKEM768PrivateKey, MLKEM768PublicKey,
136    MLKEM1024PrivateKey, MLKEM1024PublicKey,
137};
138use crate::mlkem_keys::{
139    MLKEMPrivateKeyExpanded, MLKEMPublicKeyInternalTrait, MLKEMPublicKeyTrait,
140};
141use crate::mlkem_keys::{MLKEMPrivateKeyInternalTrait, MLKEMPrivateKeyTrait};
142use crate::polynomial::Polynomial;
143use bouncycastle_core::errors::KEMError;
144use bouncycastle_core::key_material::{KeyMaterial, KeyMaterialTrait, KeyType};
145use bouncycastle_core::traits::{Algorithm, Hash, KEM, RNG, SecurityStrength, XOF};
146use bouncycastle_rng::HashDRBG_SHA512;
147use bouncycastle_sha3::{SHA3_256, SHA3_512, SHAKE256};
148use bouncycastle_utils::ct::{conditional_copy_bytes, ct_eq_bytes};
149use core::marker::PhantomData;
150
/*** Constants ***/

/// Algorithm name string for ML-KEM-512.
pub const ML_KEM_512_NAME: &str = "ML-KEM-512";
/// Algorithm name string for ML-KEM-768.
pub const ML_KEM_768_NAME: &str = "ML-KEM-768";
/// Algorithm name string for ML-KEM-1024.
pub const ML_KEM_1024_NAME: &str = "ML-KEM-1024";

// From FIPS 203 Table 2 and Table 3

// Constants that are the same for all parameter sets
/// Length of the \[u8] holding an ML-KEM seed value (the two 32-byte halves `d || z`).
pub const MLKEM_SEED_LEN: usize = 64;
/// Length of the \[u8] holding an ML-KEM encaps random value, also sometimes called the message `m`.
pub const MLKEM_RND_LEN: usize = 32;
/// Size in bytes of an ML-KEM shared secret key.
pub const MLKEM_SS_LEN: usize = 32;
/// Number of coefficients in each polynomial (`n` in FIPS 203).
pub(crate) const N: usize = 256;
/// The ML-KEM modulus `q`.
pub(crate) const q: i16 = 3329;
/// `q^-1 mod 2^16` (checked at compile time below); presumably used for Montgomery reduction.
pub(crate) const q_inv: i32 = 62209;
/// The CBD parameter `eta2`, which is the same for every ML-KEM parameter set.
pub(crate) const ETA2: i16 = 2;
/// Bytes in one ByteEncode12-encoded polynomial: `N` coefficients of 12 bits each.
pub(crate) const POLY_BYTES: usize = 384;

// Compile-time sanity checks on the shared constants.
const _: () = assert!(((q as i32 * q_inv) & 0xFFFF) == 1);
const _: () = assert!(POLY_BYTES == N * 12 / 8);
174
/* ML-KEM-512 params */

/// Length of the \[u8] holding an ML-KEM-512 public key (`384k + 32` bytes).
pub const MLKEM512_PK_LEN: usize = 800;
/// Length of the \[u8] holding an ML-KEM-512 private key (`768k + 96` bytes).
pub const MLKEM512_SK_LEN: usize = 1632;
/// Length of the \[u8] holding an ML-KEM-512 ciphertext (`32 * (du * k + dv)` bytes).
pub const MLKEM512_CT_LEN: usize = 768;
/// Module rank `k` for ML-KEM-512.
pub(crate) const MLKEM512_k: usize = 2;
/// CBD parameter `eta1` for ML-KEM-512.
pub(crate) const MLKEM512_ETA1: i16 = 3;
/// Compression bit width `du` for ML-KEM-512.
pub(crate) const MLKEM512_DU: i16 = 10;
/// Compression bit width `dv` for ML-KEM-512.
pub(crate) const MLKEM512_DV: i16 = 4;
/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
pub(crate) const MLKEM512_LAMBDA: i16 = 128;

// Compile-time check that the ML-KEM-512 sizes agree with the FIPS 203 size formulas.
const _: () = {
    assert!(MLKEM512_PK_LEN == 384 * MLKEM512_k + 32);
    assert!(MLKEM512_SK_LEN == 768 * MLKEM512_k + 96);
    assert!(MLKEM512_CT_LEN == 32 * (MLKEM512_DU as usize * MLKEM512_k + MLKEM512_DV as usize));
};

/* ML-KEM-768 params */

/// Length of the \[u8] holding an ML-KEM-768 public key (`384k + 32` bytes).
pub const MLKEM768_PK_LEN: usize = 1184;
/// Length of the \[u8] holding an ML-KEM-768 private key (`768k + 96` bytes).
pub const MLKEM768_SK_LEN: usize = 2400;
/// Length of the \[u8] holding an ML-KEM-768 ciphertext (`32 * (du * k + dv)` bytes).
pub const MLKEM768_CT_LEN: usize = 1088;
/// Module rank `k` for ML-KEM-768.
pub(crate) const MLKEM768_k: usize = 3;
/// CBD parameter `eta1` for ML-KEM-768.
pub(crate) const MLKEM768_ETA1: i16 = 2;
/// Compression bit width `du` for ML-KEM-768.
pub(crate) const MLKEM768_DU: i16 = 10;
/// Compression bit width `dv` for ML-KEM-768.
pub(crate) const MLKEM768_DV: i16 = 4;
/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
pub(crate) const MLKEM768_LAMBDA: i16 = 192;

// Compile-time check that the ML-KEM-768 sizes agree with the FIPS 203 size formulas.
const _: () = {
    assert!(MLKEM768_PK_LEN == 384 * MLKEM768_k + 32);
    assert!(MLKEM768_SK_LEN == 768 * MLKEM768_k + 96);
    assert!(MLKEM768_CT_LEN == 32 * (MLKEM768_DU as usize * MLKEM768_k + MLKEM768_DV as usize));
};

/* ML-KEM-1024 params */

/// Length of the \[u8] holding an ML-KEM-1024 public key (`384k + 32` bytes).
pub const MLKEM1024_PK_LEN: usize = 1568;
/// Length of the \[u8] holding an ML-KEM-1024 private key (`768k + 96` bytes).
pub const MLKEM1024_SK_LEN: usize = 3168;
/// Length of the \[u8] holding an ML-KEM-1024 ciphertext (`32 * (du * k + dv)` bytes).
pub const MLKEM1024_CT_LEN: usize = 1568;
/// Module rank `k` for ML-KEM-1024.
pub(crate) const MLKEM1024_k: usize = 4;
/// CBD parameter `eta1` for ML-KEM-1024.
pub(crate) const MLKEM1024_ETA1: i16 = 2;
/// Compression bit width `du` for ML-KEM-1024.
pub(crate) const MLKEM1024_DU: i16 = 11;
/// Compression bit width `dv` for ML-KEM-1024.
pub(crate) const MLKEM1024_DV: i16 = 5;
/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
pub(crate) const MLKEM1024_LAMBDA: i16 = 256;

// Compile-time check that the ML-KEM-1024 sizes agree with the FIPS 203 size formulas.
const _: () = {
    assert!(MLKEM1024_PK_LEN == 384 * MLKEM1024_k + 32);
    assert!(MLKEM1024_SK_LEN == 768 * MLKEM1024_k + 96);
    assert!(
        MLKEM1024_CT_LEN == 32 * (MLKEM1024_DU as usize * MLKEM1024_k + MLKEM1024_DV as usize)
    );
};
219
220// Typedefs just to make the algorithms look more like the FIPS 204 sample code.
221pub(crate) type G = SHA3_512;
222pub(crate) type H = SHA3_256;
223pub(crate) type J = SHAKE256;
224
225/*** Pub Types ***/
226
227/// The ML-KEM-512 algorithm.
228pub type MLKEM512 = MLKEM<
229    MLKEM512_PK_LEN,
230    MLKEM512_SK_LEN,
231    MLKEM512_CT_LEN,
232    MLKEM_SS_LEN,
233    MLKEM512PublicKey,
234    MLKEM512PrivateKey,
235    MLKEM512_k,
236    MLKEM512_ETA1,
237    MLKEM512_DU,
238    MLKEM512_DV,
239    MLKEM512_LAMBDA,
240>;
241
242impl Algorithm for MLKEM512 {
243    const ALG_NAME: &'static str = ML_KEM_512_NAME;
244    const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_128bit;
245}
246
247/// The ML-KEM-768 algorithm.
248pub type MLKEM768 = MLKEM<
249    MLKEM768_PK_LEN,
250    MLKEM768_SK_LEN,
251    MLKEM768_CT_LEN,
252    MLKEM_SS_LEN,
253    MLKEM768PublicKey,
254    MLKEM768PrivateKey,
255    MLKEM768_k,
256    MLKEM768_ETA1,
257    MLKEM768_DU,
258    MLKEM768_DV,
259    MLKEM768_LAMBDA,
260>;
261
262impl Algorithm for MLKEM768 {
263    const ALG_NAME: &'static str = ML_KEM_768_NAME;
264    const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_192bit;
265}
266
267/// The ML-KEM-1024 algorithm.
268pub type MLKEM1024 = MLKEM<
269    MLKEM1024_PK_LEN,
270    MLKEM1024_SK_LEN,
271    MLKEM1024_CT_LEN,
272    MLKEM_SS_LEN,
273    MLKEM1024PublicKey,
274    MLKEM1024PrivateKey,
275    MLKEM1024_k,
276    MLKEM1024_ETA1,
277    MLKEM1024_DU,
278    MLKEM1024_DV,
279    MLKEM1024_LAMBDA,
280>;
281
282impl Algorithm for MLKEM1024 {
283    const ALG_NAME: &'static str = ML_KEM_1024_NAME;
284    const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_256bit;
285}
286
287/// The core internal implementation of the ML-KEM algorithm.
288/// This needs to be public for the compiler to be able to find it, but you shouldn't ever
289/// need to use this directly. Please use the named public types.
290pub struct MLKEM<
291    const PK_LEN: usize,
292    const SK_LEN: usize,
293    const CT_LEN: usize,
294    const SS_LEN: usize,
295    PK: MLKEMPublicKeyTrait<k, PK_LEN> + MLKEMPublicKeyInternalTrait<k, PK_LEN>,
296    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN>
297        + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
298    const k: usize,
299    const eta: i16,
300    const du: i16,
301    const dv: i16,
302    const LAMBDA: i16,
303> {
304    _phantom: PhantomData<(PK, SK)>,
305}
306
307impl<
308    const PK_LEN: usize,
309    const SK_LEN: usize,
310    const CT_LEN: usize,
311    const SS_LEN: usize,
312    PK: MLKEMPublicKeyTrait<k, PK_LEN> + MLKEMPublicKeyInternalTrait<k, PK_LEN>,
313    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN>
314        + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
315    const k: usize,
316    const eta1: i16,
317    const du: i16,
318    const dv: i16,
319    const LAMBDA: i16,
320> MLKEM<PK_LEN, SK_LEN, CT_LEN, SS_LEN, PK, SK, k, eta1, du, dv, LAMBDA>
321{
322    /// Should still be ok in FIPS mode
323    pub fn keygen_from_os_rng() -> Result<(PK, SK), KEMError> {
324        let mut seed = KeyMaterial::<64>::new();
325        HashDRBG_SHA512::new_from_os().fill_keymaterial_out(&mut seed)?;
326        Self::keygen_internal(&seed)
327    }
328    /// Algorithm 16 ML-KEM.KeyGen_internal(๐‘‘, ๐‘ง)
329    /// Uses randomness to generate an encapsulation key and a corresponding decapsulation key.
330    /// Input: randomness ๐‘‘ โˆˆ ๐”น32 .
331    /// Input: randomness ๐‘ง โˆˆ ๐”น32 .
332    /// Output: encapsulation key ek โˆˆ ๐”น384๐‘˜+32 .
333    /// Output: decapsulation key dk โˆˆ ๐”น768๐‘˜+96 .
334    pub(crate) fn keygen_internal(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError> {
335        if !(seed.key_type() == KeyType::Seed || seed.key_type() == KeyType::BytesFullEntropy)
336            || seed.key_len() != 64
337        {
338            return Err(KEMError::KeyGenError(
339                "Seed must be 64 bytes and KeyType::Seed or KeyType::BytesFullEntropy.",
340            ));
341        }
342
343        if seed.security_strength() < SecurityStrength::from_bits(LAMBDA as usize) {
344            return Err(KEMError::KeyGenError(
345                "Seed SecurityStrength must match algorithm security strength",
346            ));
347        }
348
349        // 1: (ekPKE, dkPKE) โ† K-PKE.KeyGen(๐‘‘)
350        let (pk, s_hat) = Self::pke_keygen(&seed.ref_to_bytes()[..32].try_into().unwrap());
351
352        // 2: ek โ† ekPKE โ–ท KEM encaps key is just the PKE encryption key
353        // 3: dk โ† (dkPKEโ€–ekโ€–H(ek)โ€–๐‘ง) โ–ท KEM decaps key includes PKE decryption key
354        // 4: return (ek, dk)
355        let pk_hash = pk.compute_hash();
356        Ok((
357            pk.clone(),
358            SK::new(
359                s_hat,
360                pk,
361                pk_hash,
362                seed.ref_to_bytes()[32..].try_into().unwrap(),
363                Some(seed.ref_to_bytes()[..32].try_into().unwrap()),
364            ),
365        ))
366    }
367
368    /// Algorithm 13 K-PKE.KeyGen(๐‘‘)
369    /// Uses randomness to generate an encryption key and a corresponding decryption key.
370    /// Input: randomness ๐‘‘ โˆˆ ๐”น32 .
371    /// Output: encryption key ek_PKE โˆˆ ๐”น384๐‘˜+32.
372    /// Output: decryption key dk_PKE โˆˆ ๐”น384๐‘˜.
373    fn pke_keygen(d: &[u8; 32]) -> (PK, Vector<k>) {
374        // 1: (๐œŒ, ๐œŽ) โ† G(๐‘‘โ€–๐‘˜)
375        //  โ–ท expand 32+1 bytes to two pseudorandom 32-byte seeds1
376        // rho: public seed
377        // sigma: noise seed
378        let (rho, mut sigma) = {
379            let mut g = G::new();
380            g.do_update(d);
381            g.do_update(&[k as u8]);
382            let mut buf = [0u8; 64];
383            let bytes_written = g.do_final_out(&mut buf);
384            debug_assert_eq!(bytes_written, 64);
385
386            (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
387        };
388
389        // 2: ๐‘ โ† 0
390        //  Note: in the definition of PRF_eta on page 18, it's said to be one byte.
391        //  since the number of loops here is static; we can hard-code the N values rather than using a counter
392
393        // 8: for (๐‘– โ† 0; ๐‘– < ๐‘˜; ๐‘–++)
394        //  โ–ท generate ๐ฌ โˆˆ (โ„ค256)^k
395        // 9: ๐ฌ[๐‘–] โ† SamplePolyCBD๐œ‚1(PRF๐œ‚1 (๐œŽ, ๐‘ ))
396        //   โ–ท ๐ฌ[๐‘–] โˆˆ โ„ค256 sampled from CBD
397        // 10: ๐‘ โ† ๐‘ + 1
398        // Note: here n = 0
399        let s_hat = {
400            let mut s = sample_vector_CBD::<k, eta1>(&sigma, 0);
401
402            // 16: ๐ฌ_hat โ† NTT(๐ฌ)ฬ‚
403            s.ntt();
404            s.reduce();
405            s
406        };
407
408        // first half of
409        // 18: ๐ญ_hat โ† ๐€_hat โˆ˜ ๐ฌ_hat + ๐ž_hat
410        let mut t_hat = {
411            // 3: for (๐‘– โ† 0; ๐‘– < ๐‘˜; ๐‘–++)
412            //  โ–ท generate matrix A_hat โˆˆ (โ„ค256)^k x k
413            let A_hat = expandA(&rho);
414
415            A_hat.matrix_vector_ntt::<false>(&s_hat)
416        };
417
418        // second half of
419        // 18: ๐ญ_hat โ† ๐€_hat โˆ˜ ๐ฌ_hat + ๐ž_hat
420        {
421            // 12: for (๐‘– โ† 0; ๐‘– < ๐‘˜; ๐‘–++)
422            //  โ–ท generate ๐ž โˆˆ (โ„ค256)^k
423            // 13: ๐ž[๐‘–] โ† SamplePolyCBD๐œ‚1(PRF๐œ‚1 (๐œŽ, ๐‘ ))
424            //   โ–ท ๐ž[๐‘–] โˆˆ โ„ค256 sampled from CBD
425            // 14: ๐‘ โ† ๐‘ + 1
426            // Note: here n = k
427            let mut e = sample_vector_CBD::<k, eta1>(&sigma, k as u8);
428
429            e.ntt(); // technically now e_hat
430            e.reduce();
431            t_hat.add_vector_ntt(&e);
432        }
433
434        // Clear the secret data before returning memory to the OS
435        sigma.fill(0u8);
436
437        // 19: ekPKE โ† ByteEncode12(๐ญ)โ€–๐œŒ โ–ท run ByteEncode12 ๐‘˜ times, then append ๐€-seed
438        // 20: dkPKE โ† ByteEncode12(๐ฌ)ฬ‚ โ–ท run ByteEncode12 ๐‘˜ times
439        // Note: I'm skipping the encoding at this stage and leaving it expanded for future efficiency when it's used.
440        // 21: return (ekPKE, dkPKE)
441        (PK::new(t_hat, rho), s_hat)
442    }
443
444    /// Algorithm 14 K-PKE.Encrypt(ekPKE, ๐‘š, ๐‘Ÿ)
445    /// Uses the encryption key to encrypt a plaintext message using the randomness ๐‘Ÿ.
446    /// Input: encryption key ekPKE โˆˆ ๐”น384๐‘˜+32 .
447    /// Input: message ๐‘š โˆˆ ๐”น32 .
448    /// Input: randomness ๐‘Ÿ โˆˆ ๐”น32 .
449    /// Output: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
450    fn pke_encrypt(ek: &PK, A_hat: &Matrix<k, k>, m: [u8; 32], r: &[u8; 32]) -> [u8; CT_LEN] {
451        // 1: ๐‘ โ† 0
452        //  since the number of loops here is static; we can hard-code the N values rather than using a counter
453
454        // 2: ๐ญ โ† ByteDecode12(ekPKE[0 โˆถ 384๐‘˜])
455        // 3: ๐œŒ โ† ekPKE[384๐‘˜ โˆถ 384๐‘˜ + 32]
456        // not necessary here because ek is already decoded
457
458        // 4: for (๐‘– โ† 0; ๐‘– < ๐‘˜; ๐‘–++)
459        //   โ–ท re-generate matrix ๐€ โˆˆ (โ„ค256_๐‘ž )๐‘˜ร—๐‘˜ sampled in Alg. 13
460        // We're doing an optimization where the user can pre-expand A_hat within the
461        // public key object for faster repeated encapsulations against this public key.
462
463        // 9: for (๐‘– โ† 0; ๐‘– < ๐‘˜; ๐‘–++)
464        //  โ–ท generate ๐ฒ โˆˆ (โ„ค256_๐‘ž)k
465        // 10: ๐ฒ[๐‘–] โ† SamplePolyCBD๐œ‚1(PRF๐œ‚1 (๐‘Ÿ, ๐‘))
466        //   โ–ท ๐ฒ[๐‘–] โˆˆ โ„ค256 sampled from CBD
467        // 11: ๐‘ โ† ๐‘ + 1
468        // Note: here n = 0
469        let y_hat = {
470            let mut y = sample_vector_CBD::<k, eta1>(&r, 0);
471
472            // 18: ๐ฒ_hat โ† NTT(๐ฒ)
473            y.ntt();
474
475            y
476        };
477
478        // 19: ๐ฎ โ† NTTโˆ’1(๐€_hat^โŠบ โˆ˜ ๐ฒ_hat) + ๐ž
479        let mut u = A_hat.matrix_vector_ntt::<true>(&y_hat);
480        u.inv_ntt();
481        {
482            // 12: for (๐‘– โ† 0; ๐‘– < ๐‘˜; ๐‘–++)
483            //  โ–ท generate ๐ž โˆˆ (โ„ค256_๐‘ž)k
484            // 13: ๐ž[๐‘–] โ† SamplePolyCBD๐œ‚1(PRF๐œ‚1 (๐œŽ, ๐‘))
485            //  โ–ท ๐ž[๐‘–] โˆˆ โ„ค256 sampled from CBD๐‘ž
486            // 14: ๐‘ โ† ๐‘ + 1
487            // note: here n = k
488            let e1 = sample_vector_CBD::<k, ETA2>(&r, k as u8);
489
490            u.add_vector_ntt(&e1);
491        }
492        u.reduce();
493
494        // 20: ๐œ‡ โ† Decompress1(ByteDecode1(๐‘š))
495        // 21: ๐‘ฃ โ† NTTโˆ’1(๐ญ_hat^T โˆ˜ ๐ฒ_hat) + ๐‘’2 + ๐œ‡
496        //  โ–ท encode plaintext ๐‘š into polynomial ๐‘ฃ
497        let mut v = ek.t_hat().dot_product(&y_hat);
498        v.inv_ntt();
499
500        // 17: ๐‘’2 โ† SamplePolyCBD๐œ‚2(PRF๐œ‚2 (๐‘Ÿ, ๐‘))
501        //  โ–ท sample ๐‘’2 โˆˆ โ„ค256 from CBD
502        // note: here n = 2k
503        let e2 = sample_poly_CBD::<ETA2>(&r, 2 * k as u8);
504        v.add(&e2);
505
506        let mu = Polynomial::from_msg(m);
507        v.add(&mu);
508
509        v.poly_reduce();
510
511        pack_ciphertext::<k, CT_LEN, du, dv>(&u, &v)
512    }
513
514    /// Algorithm 17 ML-KEM.Encaps_internal(ek, ๐‘š)
515    /// Uses the encapsulation key and randomness to generate a key and an associated ciphertext.
516    /// Input: encapsulation key ek โˆˆ ๐”น384๐‘˜+32 .
517    /// Input: randomness ๐‘š โˆˆ ๐”น32 .
518    /// Output: shared secret key ๐พ โˆˆ ๐”น32 .
519    /// Output: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
520    ///
521    /// This function also takes an Option for the public matrix A.
522    /// If you don't know what it is, just provide None.
523    /// This is to enable performance
524    /// optimizations when the same public key is used for multiple encapsulations and the intermediate
525    /// value called the public matrix A_hat can be re-used for multiple encapsulations.
526    /// A_hat can be obtained from [MLKEMPublicKeyTrait::A_hat].
527    /// Alternatively, you can use a [MLKEMPublicKeyExpanded] with [MLKEM::encaps_for_expanded_key].
528    /// If you specify None, the function will compute A_hat internally and everything will work fine.
529    ///
530    /// Unlike the more public function exposed by [KEM::encaps], this returns the shared secret as raw bytes
531    /// instead of wrapped in an appropriately-set [KeyMaterialTrait], so you're on your own for handling it properly.
532    ///
533    /// Note: this is an internal function that allows the caller to specify the encapsulation
534    /// randomness (which is the message `m` to be encrypted by the underlying PKE scheme).
535    /// This function should not be used directly unless you really have a
536    /// good reason. [KEM::encaps] should be used in 99.9% of cases.
537    /// The reason this is exposed publicly is: A) for unit testing that requires access
538    /// to the deterministically reproducible function, and B) for operational environments
539    /// that wish to provide randomness from their own source instead of the built-in RNG in bc-rust.
540    /// If you think you will be clever and invent some scheme that uses a deterministic KEM,
541    /// then you will almost certainly end up with security problems. Please don't do this.
542    pub fn encaps_internal(
543        ek: &PK,
544        A_hat: Option<&Matrix<k, k>>,
545        m: [u8; 32],
546    ) -> ([u8; 32], [u8; CT_LEN]) {
547        debug_assert_eq!(CT_LEN, 32 * ((du as usize) * k + (dv as usize)));
548
549        // 1: (๐พ, ๐‘Ÿ) โ† G(๐‘šโ€–H(ek))
550        //  โ–ท derive shared secret key ๐พ and randomness ๐‘Ÿ
551        let K: [u8; MLKEM_SS_LEN];
552        let r: [u8; 32];
553        (K, r) = {
554            let mut g = G::new();
555            g.do_update(&m);
556            g.do_update(&ek.compute_hash());
557            let mut buf = [0u8; 64];
558            let bytes_written = g.do_final_out(&mut buf);
559            debug_assert_eq!(bytes_written, 64);
560
561            (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
562        };
563
564        // 2: ๐‘ โ† K-PKE.Encrypt(ek, ๐‘š, ๐‘Ÿ)
565        //  โ–ท encrypt ๐‘š using K-PKE with randomness ๐‘Ÿ
566        // deviation from FIPS:
567        //  To allow for pre-computing A_hat for multiple encapsulations, we will either take
568        // A_hat passed in, or compute it fresh.
569        let ct = match A_hat {
570            Some(A_hat) => Self::pke_encrypt(ek, A_hat, m, &r),
571            None => Self::pke_encrypt(ek, &ek.A_hat(), m, &r),
572        };
573
574        (K, ct)
575    }
576
577    /// Algorithm 15 K-PKE.Decrypt(dkPKE, ๐‘)
578    /// Uses the decryption key to decrypt a ciphertext.
579    /// Input: decryption key dkPKE  โˆˆ ๐”น384๐‘˜.
580    /// Input: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
581    /// Output: message ๐‘š โˆˆ ๐”น32 .
582    fn pke_decrypt(dk: &SK, ct: [u8; CT_LEN]) -> [u8; 32] {
583        // 1: ๐‘1 โ† ๐‘[0 โˆถ 32๐‘‘๐‘ข๐‘˜]
584        // 2: ๐‘2 โ† ๐‘[32๐‘‘๐‘ข๐‘˜ โˆถ 32(๐‘‘๐‘ข๐‘˜ + ๐‘‘๐‘ฃ)]
585        // 3: ๐ฎโ€ฒ โ† Decompress_๐‘‘๐‘ข(ByteDecode_๐‘‘๐‘ข(๐‘1))
586        // 4: ๐‘ฃโ€ฒ โ† Decompress_๐‘‘๐‘ฃ(ByteDecode_๐‘‘๐‘ฃ(๐‘2))
587        let v1 = {
588            let mut u_prime = unpack_ciphertext_u::<k, CT_LEN, du, dv>(&ct);
589
590            // 5: ๐ฌ_hat โ† ByteDecode12(dkPKE)
591            //   Unnecessary here because dk is already decoded
592
593            // 6: ๐‘ค โ† ๐‘ฃโ€ฒ โˆ’ NTTโˆ’1(๐ฌ_hat^T โˆ˜ NTT(๐ฎโ€ฒ))
594            u_prime.ntt();
595            let mut v1 = dk.s_hat().dot_product(&u_prime);
596            v1.inv_ntt();
597
598            v1
599        };
600
601        let w = {
602            let mut v_prime = unpack_ciphertext_v::<k, CT_LEN, du, dv>(&ct);
603
604            v_prime.sub(&v1);
605            v_prime.poly_reduce();
606            v_prime // rename to w
607        };
608
609        // 7: ๐‘š โ† ByteEncode1(Compress1(๐‘ค))
610        //   โ–ท decode plaintext ๐‘š from polynomial ๐‘ค
611        w.to_msg()
612    }
613
614    /// Algorithm 18 ML-KEM.Decaps_internal(dk, ๐‘)
615    /// Uses the decapsulation key to produce a shared secret key from a ciphertext.
616    /// Input: decapsulation key dk โˆˆ ๐”น768๐‘˜+96 .
617    /// Input: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
618    /// Output: shared secret key ๐พ โˆˆ ๐”น32 .
619    fn decaps_internal(
620        dk: &SK,
621        A_hat: Option<&Matrix<k, k>>,
622        c: [u8; CT_LEN],
623    ) -> [u8; MLKEM_SS_LEN] {
624        // I have tried to keep this as clean as possible for correspondence with the FIPS,
625        // but I have moved things around so that I can use unnamed scopes to limit how many
626        // stack variables are alive at the same time.
627
628        // 1: dkPKE โ† dk[0 โˆถ 384๐‘˜] โ–ท extract (from KEM decaps key) the PKE decryption key
629        // 2: ekPKE โ† dk[384๐‘˜ โˆถ 768๐‘˜ + 32] โ–ท extract PKE encryption key
630        // 3: โ„Ž โ† dk[768๐‘˜ + 32 โˆถ 768๐‘˜ + 64] โ–ท extract hash of PKE encryption key
631        // 4: ๐‘ง โ† dk[768๐‘˜ + 64 โˆถ 768๐‘˜ + 96] โ–ท extract implicit rejection value
632        // Nothing to do since dk is already decoded.
633
634        // 5: ๐‘šโ€ฒ โ† K-PKE.Decrypt(dkPKE, ๐‘)
635        let m_prime = Self::pke_decrypt(&dk, c);
636
637        // Compute the trial shared secret key
638        // 6: (๐พโ€ฒ, ๐‘Ÿโ€ฒ) โ† G(๐‘šโ€ฒโ€–โ„Ž)ฬ„
639        let K_prime: [u8; MLKEM_SS_LEN];
640        let r_prime: [u8; 32];
641        (K_prime, r_prime) = {
642            let mut g = G::new();
643            g.do_update(&m_prime);
644            g.do_update(&dk.pk().compute_hash());
645            let mut buf = [0u8; 64];
646            let bytes_written = g.do_final_out(&mut buf);
647            debug_assert_eq!(bytes_written, 64);
648
649            (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
650        };
651
652        // 7: ๐พ_bar โ† J(๐‘งโ€–๐‘)
653        //   Compute the rejection sampling key.
654        //   Note to future optimizers: this needs to be computed outside of the if at line 9 below
655        //   because if its computation is conditional on the Fujisaki-Okamoto check failing, then
656        //   you'll have a timing difference between success and failure.
657
658        let K_bar: [u8; MLKEM_SS_LEN];
659        K_bar = {
660            let mut j = J::new();
661            j.absorb(dk.z());
662            j.absorb(&c);
663            let mut buf = [0u8; MLKEM_SS_LEN];
664            let bytes_written = j.squeeze_out(&mut buf);
665            debug_assert_eq!(bytes_written, MLKEM_SS_LEN);
666
667            buf
668        };
669
670        // 8: ๐‘โ€ฒ โ† K-PKE.Encrypt(ekPKE, ๐‘šโ€ฒ, ๐‘Ÿโ€ฒ)
671        //   โ–ท re-encrypt using the derived randomness ๐‘Ÿโ€ฒ
672        // deviation from FIPS:
673        //  To allow for pre-computing A_hat for multiple encapsulations, we will either take
674        // A_hat passed in, or compute it fresh.
675        let c_prime = match A_hat {
676            Some(A_hat) => Self::pke_encrypt(dk.pk(), A_hat, m_prime, &r_prime),
677            None => Self::pke_encrypt(dk.pk(), &dk.pk().A_hat(), m_prime, &r_prime),
678        };
679
680        // 9: if ๐‘ โ‰  ๐‘โ€ฒ then
681        // 10: ๐พโ€ฒ โ† ๐พ_bar
682        //  โ–ท if ciphertexts do not match, โ€œimplicitly reject"
683        let mut K_out = [0u8; MLKEM_SS_LEN];
684        conditional_copy_bytes(&K_prime, &K_bar, &mut K_out, ct_eq_bytes(&c, &c_prime));
685
686        K_out
687    }
688
689    /// Alternative initialization of the streaming signer where you have your private key
690    /// as a seed and you want to delay its expansion as late as possible for memory-usage reasons.
691    // todo -- should we build a fully-stitched-together decaps-from-seed ... or not?
692    pub fn decaps_from_seed(
693        seed: &KeyMaterial<64>,
694        ct: &[u8],
695    ) -> Result<KeyMaterial<SS_LEN>, KEMError> {
696        let (_pk, sk) = Self::keygen_from_seed(seed)?;
697
698        Self::decaps(&sk, ct)
699    }
700}
701
702impl<
703    const PK_LEN: usize,
704    const SK_LEN: usize,
705    const CT_LEN: usize,
706    const SS_LEN: usize,
707    PK: MLKEMPublicKeyTrait<k, PK_LEN> + MLKEMPublicKeyInternalTrait<k, PK_LEN>,
708    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN>
709        + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
710    const k: usize,
711    const eta1: i16,
712    const du: i16,
713    const dv: i16,
714    const LAMBDA: i16,
715> MLKEMTrait<PK_LEN, SK_LEN, CT_LEN, SS_LEN, PK, SK, k, eta1, du, dv, LAMBDA>
716    for MLKEM<PK_LEN, SK_LEN, CT_LEN, SS_LEN, PK, SK, k, eta1, du, dv, LAMBDA>
717{
718    /// Imports a secret key from a seed.
719    fn keygen_from_seed(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError> {
720        Self::keygen_internal(seed)
721    }
722    /// Imports a secret key from both a seed and an encoded_sk.
723    ///
724    /// This is a convenience function to expand the key from seed and compare it against
725    /// the provided `encoded_sk` using a constant-time equality check.
726    /// If everything checks out, the secret key is returned fully populated with pk and seed.
727    /// If the provided key and derived key don't match, an error is returned.
728    fn keygen_from_seed_and_encoded(
729        seed: &KeyMaterial<64>,
730        encoded_sk: &[u8; SK_LEN],
731    ) -> Result<(PK, SK), KEMError> {
732        let (pk, sk) = Self::keygen_internal(seed)?;
733
734        let sk_from_bytes = SK::sk_decode(encoded_sk)?;
735
736        // MLKEMPrivateKey impls PartialEq with a constant-time equality check.
737        if sk != sk_from_bytes {
738            return Err(KEMError::KeyGenError("Encoded key does not match generated key"));
739        }
740
741        Ok((pk, sk))
742    }
743    /// Given a public key and a secret key, check that the public key matches the secret key.
744    /// This is a sanity check that the public key was generated correctly from the secret key.
745    ///
746    /// At the current time, this is only possible if `sk` either contains a public key (in which case
747    /// the two pk's are encoded and compared for byte equality), or if `sk` contains a seed
748    /// (in which case a keygen_from_seed is run and then the pk's compared).
749    ///
750    /// Returns either `()` or [KEMError::ConsistencyCheckFailed].
751    fn keypair_consistency_check(pk: &PK, sk: &SK) -> Result<(), KEMError> {
752        let derived_pk = sk.pk();
753        if derived_pk.compute_hash() == pk.compute_hash() {
754            Ok(())
755        } else {
756            Err(KEMError::ConsistencyCheckFailed(""))
757        }
758    }
759
760    fn encaps_for_expanded_key(
761        pk: &MLKEMPublicKeyExpanded<k, PK, PK_LEN>,
762    ) -> Result<(KeyMaterial<SS_LEN>, [u8; CT_LEN]), KEMError> {
763        let mut m = [0u8; 32];
764        HashDRBG_SHA512::new_from_os().next_bytes_out(&mut m)?;
765
766        let (ss, ct) = Self::encaps_internal(&pk.ek, Some(&pk.A_hat), m);
767
768        let mut key = KeyMaterial::<SS_LEN>::from_bytes_as_type(&ss, KeyType::BytesFullEntropy)?;
769        key.allow_hazardous_operations();
770        key.set_security_strength(SecurityStrength::from_bits(LAMBDA as usize))?;
771        key.drop_hazardous_operations();
772
773        Ok((key, ct))
774    }
775
776    fn decaps_with_expanded_key(
777        sk: &MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN>,
778        ct: &[u8],
779    ) -> Result<KeyMaterial<SS_LEN>, KEMError> {
780        /* decapsulation inputs checks described on FIPS 203 section 7.3 */
781        // 1. (Ciphertext type check) If ๐‘ is not a byte array of length 32(๐‘‘๐‘ข ๐‘˜ + ๐‘‘๐‘ฃ) for the values of ๐‘‘๐‘ข,
782        //     ๐‘‘๐‘ฃ, and ๐‘˜ specified by the relevant parameter set, then input checking has failed.
783        debug_assert_eq!(CT_LEN, 32 * ((du as usize) * k + (dv as usize)));
784
785        if ct.len() != CT_LEN {
786            return Err(KEMError::LengthError("Ciphertext has the incorrect length"));
787        }
788
789        // 2. (Decapsulation key type check) If dk is not a byte array of length 768๐‘˜ + 96 for the value of
790        //     ๐‘˜ specified by the relevant parameter set, then input checking has failed.
791        // This is handled at the time of loading dk into MLKEMPrivateKey
792
793        // 3. Check that the H(ek) stored in the private key matches the ek also stored in the private key.
794        // Again, this is handled by the MLKEMPrivateKey trait.
795
796        /* the actual decaps operation */
797        let K = Self::decaps_internal(&sk.dk, Some(&sk.A_hat), ct.try_into().unwrap());
798
799        let mut key = KeyMaterial::<SS_LEN>::from_bytes_as_type(&K, KeyType::BytesFullEntropy)?;
800        key.allow_hazardous_operations();
801        key.set_security_strength(SecurityStrength::from_bits(LAMBDA as usize))?;
802        key.drop_hazardous_operations();
803
804        Ok(key)
805    }
806}
807
808/// Trait for all three of the ML-DSA algorithm variants.
809pub trait MLKEMTrait<
810    const PK_LEN: usize,
811    const SK_LEN: usize,
812    const CT_LEN: usize,
813    const SS_LEN: usize,
814    PK: MLKEMPublicKeyTrait<k, PK_LEN> + MLKEMPublicKeyInternalTrait<k, PK_LEN>,
815    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN>
816        + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
817    const k: usize,
818    const eta: i16,
819    const du: i16,
820    const dv: i16,
821    const LAMBDA: i16,
822>: Sized
823{
824    /// Imports a secret key from a seed.
825    fn keygen_from_seed(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError>;
826    /// Imports a secret key from both a seed and an encoded_sk.
827    ///
828    /// This is a convenience function to expand the key from seed and compare it against
829    /// the provided `encoded_sk` using a constant-time equality check.
830    /// If everything checks out, the secret key is returned fully populated with pk and seed.
831    /// If the provided key and derived key don't match, an error is returned.
832    fn keygen_from_seed_and_encoded(
833        seed: &KeyMaterial<64>,
834        encoded_sk: &[u8; SK_LEN],
835    ) -> Result<(PK, SK), KEMError>;
836    /// Given a public key and a secret key, check that the public key matches the secret key.
837    /// This is a sanity check that the public key was generated correctly from the secret key.
838    ///
839    /// At the current time, this is only possible if `sk` either contains a public key (in which case
840    /// the two pk's are encoded and compared for byte equality), or if `sk` contains a seed
841    /// (in which case a keygen_from_seed is run and then the pk's compared).
842    ///
843    /// Returns either `()` or [KEMError::ConsistencyCheckFailed].
844    fn keypair_consistency_check(pk: &PK, sk: &SK) -> Result<(), KEMError>;
845
846    /// Same as [KEM::encaps], but acts on an [MLKEMPublicKeyExpanded].
847    fn encaps_for_expanded_key(
848        pk: &MLKEMPublicKeyExpanded<k, PK, PK_LEN>,
849    ) -> Result<(KeyMaterial<SS_LEN>, [u8; CT_LEN]), KEMError>;
850
851    /// Same as [KEM::decaps], but acts on an [MLKEMPrivateKeyExpanded].
852    fn decaps_with_expanded_key(
853        sk: &MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN>,
854        ct: &[u8],
855    ) -> Result<KeyMaterial<SS_LEN>, KEMError>;
856}
857
858impl<
859    const PK_LEN: usize,
860    const SK_LEN: usize,
861    const CT_LEN: usize,
862    const SS_LEN: usize,
863    PK: MLKEMPublicKeyTrait<k, PK_LEN> + MLKEMPublicKeyInternalTrait<k, PK_LEN>,
864    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN>
865        + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
866    const k: usize,
867    const eta: i16,
868    const du: i16,
869    const dv: i16,
870    const LAMBDA: i16,
871> KEM<PK, SK, PK_LEN, SK_LEN, CT_LEN, SS_LEN>
872    for MLKEM<PK_LEN, SK_LEN, CT_LEN, SS_LEN, PK, SK, k, eta, du, dv, LAMBDA>
873{
874    /// Generates a fresh key pair.
875    fn keygen() -> Result<(PK, SK), KEMError> {
876        Self::keygen_from_os_rng()
877    }
878
879    /// Performs an encapsulation against the given public key, using the library's default internal RNG.
880    /// Returns (shared_secret_key, ciphertext)
881    /// The derived shared secret key is returned as a KeyMaterial with the SecurityStrength set to
882    /// the security level of the ML-KEM parameter set.
883    ///
884    /// Algorithm 20 ML-KEM.Encaps(ek)
885    /// Uses the encapsulation key to generate a shared secret key and an associated ciphertext.
886    /// Checked input: encapsulation key ek โˆˆ ๐”น384๐‘˜+32 .
887    /// Output: shared secret key ๐พ โˆˆ ๐”น32 .
888    /// Output: ciphertext ๐‘ โˆˆ ๐”น32(๐‘‘๐‘ข๐‘˜+๐‘‘๐‘ฃ).
889    fn encaps(pk: &PK) -> Result<(KeyMaterial<SS_LEN>, [u8; CT_LEN]), KEMError> {
890        Self::encaps_for_expanded_key(&MLKEMPublicKeyExpanded::<k, PK, PK_LEN>::from(pk))
891    }
892
893    /// Performs a decapsulation of the given ciphertext.
894    /// Returns the shared secret key.
895    /// The derived shared secret key is returned as a KeyMaterial with the SecurityStrength set to
896    /// the security level of the ML-KEM parameter set.
897    /// As ML-KEM is an implicitly-rejecting KEM, this returns an error only if the ciphertext is invalid (ie the wrong length).
898    fn decaps(sk: &SK, ct: &[u8]) -> Result<KeyMaterial<SS_LEN>, KEMError> {
899        Self::decaps_with_expanded_key(
900            &MLKEMPrivateKeyExpanded::<k, PK, SK, SK_LEN, PK_LEN>::from(sk),
901            ct,
902        )
903    }
904}