bouncycastle_mlkem_lowmemory/mlkem.rs
1//! There are no advanced features in this low memory crate that are not already documented in the standard \[bouncycastle_mlkem] crate.
2
3use crate::aux_functions::sample_poly_CBD;
4use crate::low_memory_helpers::{
5 compress_u_row, compute_A_hat_dot_y_hat, compute_t_hat_dot_y_hat_row, unpack_ciphertext_u_row,
6 unpack_ciphertext_v, unpack_t_hat_row,
7};
8use crate::mlkem_keys::{
9 MLKEM512PrivateKey, MLKEM512PublicKey, MLKEM768PrivateKey, MLKEM768PublicKey,
10 MLKEM1024PrivateKey, MLKEM1024PublicKey,
11};
12use crate::mlkem_keys::{MLKEMPrivateKeyInternalTrait, MLKEMPrivateKeyTrait};
13use crate::mlkem_keys::{MLKEMPublicKeyInternalTrait, MLKEMPublicKeyTrait};
14use crate::polynomial::Polynomial;
15use bouncycastle_core::errors::KEMError;
16use bouncycastle_core::key_material::{KeyMaterial, KeyMaterialTrait, KeyType};
17use bouncycastle_core::traits::{Algorithm, Hash, KEM, RNG, SecurityStrength, XOF};
18use bouncycastle_rng::HashDRBG_SHA512;
19use bouncycastle_sha3::{SHA3_256, SHA3_512, SHAKE256};
20use bouncycastle_utils::ct::{conditional_copy_bytes, ct_eq_bytes};
21use core::marker::PhantomData;
22
23/*** Constants ***/
24
/// Algorithm name of the ML-KEM-512 parameter set, as reported by [Algorithm::ALG_NAME].
pub const ML_KEM_512_NAME: &str = "ML-KEM-512";
/// Algorithm name of the ML-KEM-768 parameter set, as reported by [Algorithm::ALG_NAME].
pub const ML_KEM_768_NAME: &str = "ML-KEM-768";
/// Algorithm name of the ML-KEM-1024 parameter set, as reported by [Algorithm::ALG_NAME].
pub const ML_KEM_1024_NAME: &str = "ML-KEM-1024";
31
// Parameter values from FIPS 203, Tables 2 and 3.

// Constants that are the same for all parameter sets.
/// Length in bytes of an ML-KEM key-generation seed.
pub const MLKEM_SEED_LEN: usize = 64;
/// Length in bytes of the ML-KEM encapsulation randomness, also called the message `m`.
pub const MLKEM_RND_LEN: usize = 32;
/// Length in bytes of an ML-KEM shared secret key.
pub const MLKEM_SS_LEN: usize = 32;
/// Number of coefficients in each polynomial (`n` in FIPS 203).
pub(crate) const N: usize = 256;
/// The ML-KEM modulus `q`.
pub(crate) const q: i16 = 3329;
/// `q^-1 mod 2^16`, i.e. `q * q_inv ≡ 1 (mod 2^16)`.
/// Presumably consumed by the Montgomery arithmetic in the polynomial module.
pub(crate) const q_inv: i32 = 62209;
/// CBD parameter `eta2`; it is 2 for every parameter set.
pub(crate) const ETA2: i16 = 2;
/// Bytes needed to encode one polynomial at 12 bits per coefficient (`12 * N / 8`).
pub(crate) const POLY_BYTES: usize = 384;
46
47/* ML-KEM-512 params */
48
49/// Length of the \[u8] holding a ML-KEM-512 public key.
50pub const MLKEM512_PK_LEN: usize = 800;
51/// Length of the \[u8] holding a ML-KEM-512 seed-based private key.
52pub const MLKEM512_SK_LEN: usize = MLKEM_SEED_LEN;
53/// Length of the \[u8] holding a full ML-KEM-512 private key in the NIST encoding.
54pub const MLKEM512_FULL_SK_LEN: usize = 1632;
55/// Length of the \[u8] holding a ML-KEM-512 ciphertext.
56pub const MLKEM512_CT_LEN: usize = 768;
57pub(crate) const MLKEM512_k: usize = 2;
58pub(crate) const MLKEM512_ETA1: i16 = 3;
59pub(crate) const MLKEM512_DU: i16 = 10;
60pub(crate) const MLKEM512_DV: i16 = 4;
61/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
62pub(crate) const MLKEM512_LAMBDA: i16 = 128;
63
64// internal derived values
65pub(crate) const MLKEM512_T_PACKED_LEN: usize = 12 * MLKEM512_k * 32;
66
67/* ML-KEM-768 params */
68
69/// Length of the \[u8] holding a ML-KEM-768 public key.
70pub const MLKEM768_PK_LEN: usize = 1184;
71/// Length of the \[u8] holding a ML-KEM-768 seed-based private key.
72pub const MLKEM768_SK_LEN: usize = MLKEM_SEED_LEN;
73/// Length of the \[u8] holding a full ML-KEM-768 private key in the NIST encoding.
74pub const MLKEM768_FULL_SK_LEN: usize = 2400;
75/// Length of the \[u8] holding a ML-KEM-768 ciphertext.
76pub const MLKEM768_CT_LEN: usize = 1088;
77pub(crate) const MLKEM768_k: usize = 3;
78pub(crate) const MLKEM768_ETA1: i16 = 2;
79pub(crate) const MLKEM768_DU: i16 = 10;
80pub(crate) const MLKEM768_DV: i16 = 4;
81/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
82pub(crate) const MLKEM768_LAMBDA: i16 = 192;
83
84// internal derived values
85pub(crate) const MLKEM768_T_PACKED_LEN: usize = 12 * MLKEM768_k * 32;
86
87/* ML-KEM-1024 params */
88
89/// Length of the \[u8] holding a ML-KEM-1024 public key.
90pub const MLKEM1024_PK_LEN: usize = 1568;
91/// Length of the \[u8] holding a ML-KEM-512 seed-based private key.
92pub const MLKEM1024_SK_LEN: usize = MLKEM_SEED_LEN;
93/// Length of the \[u8] holding a full ML-KEM-512 private key in the NIST encoding.
94pub const MLKEM1024_FULL_SK_LEN: usize = 3168;
95/// Length of the \[u8] holding a ML-KEM-1024 ciphertext.
96pub const MLKEM1024_CT_LEN: usize = 1568;
97pub(crate) const MLKEM1024_k: usize = 4;
98pub(crate) const MLKEM1024_ETA1: i16 = 2;
99pub(crate) const MLKEM1024_DU: i16 = 11;
100pub(crate) const MLKEM1024_DV: i16 = 5;
101/// Maps to "required RBG strength (bits)" in FIPS 203 Table 2
102pub(crate) const MLKEM1024_LAMBDA: i16 = 256;
103
104// internal derived values
105pub(crate) const MLKEM1024_T_PACKED_LEN: usize = 12 * MLKEM1024_k * 32;
106
107// Typedefs just to make the algorithms look more like the FIPS 204 sample code.
108pub(crate) type G = SHA3_512;
109pub(crate) type H = SHA3_256;
110pub(crate) type J = SHAKE256;
111
112/*** Pub Types ***/
113
114/// The ML-KEM-512 algorithm.
115pub type MLKEM512 = MLKEM<
116 MLKEM512_PK_LEN,
117 MLKEM512_SK_LEN,
118 MLKEM512_FULL_SK_LEN,
119 MLKEM512_CT_LEN,
120 MLKEM_SS_LEN,
121 MLKEM512PublicKey,
122 MLKEM512PrivateKey,
123 MLKEM512_k,
124 MLKEM512_ETA1,
125 MLKEM512_DU,
126 MLKEM512_DV,
127 MLKEM512_LAMBDA,
128 MLKEM512_T_PACKED_LEN,
129>;
130
131impl Algorithm for MLKEM512 {
132 const ALG_NAME: &'static str = ML_KEM_512_NAME;
133 const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_128bit;
134}
135
136/// The ML-KEM-768 algorithm.
137pub type MLKEM768 = MLKEM<
138 MLKEM768_PK_LEN,
139 MLKEM768_SK_LEN,
140 MLKEM768_FULL_SK_LEN,
141 MLKEM768_CT_LEN,
142 MLKEM_SS_LEN,
143 MLKEM768PublicKey,
144 MLKEM768PrivateKey,
145 MLKEM768_k,
146 MLKEM768_ETA1,
147 MLKEM768_DU,
148 MLKEM768_DV,
149 MLKEM768_LAMBDA,
150 MLKEM768_T_PACKED_LEN,
151>;
152
153impl Algorithm for MLKEM768 {
154 const ALG_NAME: &'static str = ML_KEM_768_NAME;
155 const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_192bit;
156}
157
158/// The ML-KEM-1024 algorithm.
159pub type MLKEM1024 = MLKEM<
160 MLKEM1024_PK_LEN,
161 MLKEM1024_SK_LEN,
162 MLKEM1024_FULL_SK_LEN,
163 MLKEM1024_CT_LEN,
164 MLKEM_SS_LEN,
165 MLKEM1024PublicKey,
166 MLKEM1024PrivateKey,
167 MLKEM1024_k,
168 MLKEM1024_ETA1,
169 MLKEM1024_DU,
170 MLKEM1024_DV,
171 MLKEM1024_LAMBDA,
172 MLKEM1024_T_PACKED_LEN,
173>;
174
175impl Algorithm for MLKEM1024 {
176 const ALG_NAME: &'static str = ML_KEM_1024_NAME;
177 const MAX_SECURITY_STRENGTH: SecurityStrength = SecurityStrength::_256bit;
178}
179
180/// The core internal implementation of the ML-KEM algorithm.
181/// This needs to be public for the compiler to be able to find it, but you shouldn't ever
182/// need to use this directly. Please use the named public types.
183pub struct MLKEM<
184 const PK_LEN: usize,
185 const SK_LEN: usize,
186 const FULL_SK_LEN: usize,
187 const CT_LEN: usize,
188 const SS_LEN: usize,
189 PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
190 + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
191 SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
192 + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
193 const k: usize,
194 const eta1: i16,
195 const du: i16,
196 const dv: i16,
197 const LAMBDA: i16,
198 const T_PACKED_LEN: usize,
199> {
200 _phantom: PhantomData<(PK, SK)>,
201}
202
203impl<
204 const PK_LEN: usize,
205 const SK_LEN: usize,
206 const FULL_SK_LEN: usize,
207 const CT_LEN: usize,
208 const SS_LEN: usize,
209 PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
210 + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
211 SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
212 + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
213 const k: usize,
214 const eta1: i16,
215 const du: i16,
216 const dv: i16,
217 const LAMBDA: i16,
218 const T_PACKED_LEN: usize,
219>
220 MLKEM<
221 PK_LEN,
222 SK_LEN,
223 FULL_SK_LEN,
224 CT_LEN,
225 SS_LEN,
226 PK,
227 SK,
228 k,
229 eta1,
230 du,
231 dv,
232 LAMBDA,
233 T_PACKED_LEN,
234 >
235{
236 /// Should still be ok in FIPS mode
237 pub fn keygen_from_os_rng() -> Result<(PK, SK), KEMError> {
238 let mut seed = KeyMaterial::<64>::new();
239 HashDRBG_SHA512::new_from_os().fill_keymaterial_out(&mut seed)?;
240 // Self::keygen_internal(&seed)
241 Self::keygen_internal(&seed)
242 }
243 /// Performs the first step of key generation to transform the single provided seed into a set of internal intermediate seeds.
244 ///
245 /// Unlike other interfaces across the library that take an &impl KeyMaterial, this one
246 /// specifically takes a 64-byte [KeyMaterial512] and checks that it has [KeyType::Seed] and
247 /// the appropriate [SecurityStrength] for the requested ML-KEM parameter set.
248 /// If you happen to have your seed in a larger KeyMaterial, you'll have to copy it using
249 /// [KeyMaterial::from_key].
250 pub(crate) fn keygen_internal(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError> {
251 let sk = SK::from_keymaterial(seed)?;
252 let pk = sk.pk();
253 let pk = PK::new(pk.t_hat_packed, pk.rho); // stupid conversion, but it gets around these overly-generified rust types
254 Ok((pk, sk))
255 }
256
257 /// Algorithm 14 K-PKE.Encrypt(ekPKE, ๐, ๐)
258 /// Uses the encryption key to encrypt a plaintext message using the randomness ๐.
259 /// Input: encryption key ekPKE โ ๐น384๐+32 .
260 /// Input: message ๐ โ ๐น32 .
261 /// Input: randomness ๐ โ ๐น32 .
262 /// Output: ciphertext ๐ โ ๐น32(๐๐ข๐+๐๐ฃ).
263 fn pke_encrypt(
264 t_hat_packed: &[u8; T_PACKED_LEN],
265 rho: &[u8; 32],
266 m: [u8; 32],
267 r: &[u8; 32],
268 ) -> [u8; CT_LEN] {
269 let mut ct = [0u8; CT_LEN];
270
271 // 1: ๐ โ 0
272 // since the number of loops here is static; we can hard-code the N values rather than using a counter
273
274 // 2: ๐ญ โ ByteDecode12(ekPKE[0 โถ 384๐])
275 // 3: ๐ โ ekPKE[384๐ โถ 384๐ + 32]
276 // not necessary here because ek is already decoded
277
278 // 19: ๐ฎ โ NTTโ1(๐_hat^โบ โ ๐ฒ_hat) + ๐1
279 // 22: ๐1 โ ByteEncode_๐๐ข(Compress_๐๐ข(๐ฎ))
280
281 // Note: you need y_hat twice: once here at line 19, and again at line 21.
282 // We'll just generate it twice to save the memory of holding on to it.
283 for i in 0..k {
284 let mut u_i = compute_A_hat_dot_y_hat::<k, eta1>(rho, &r, i);
285
286 let e1_i = sample_poly_CBD::<ETA2>(&r, (k + i) as u8);
287 u_i.add(&e1_i);
288 u_i.poly_reduce();
289
290 compress_u_row::<du, CT_LEN>(u_i, i, &mut ct);
291 }
292
293 // 17: ๐2 โ SamplePolyCBD_๐2(PRF๐2 (๐, ๐))
294 // 20: ๐ โ Decompress1(ByteDecode1(๐))
295 // 21: ๐ฃ โ NTTโ1(๐ญ_hat_T โ ๐ฒ_hat) + ๐2 + ๐
296 // 23: ๐2 โ ByteEncode_๐๐ฃ(Compress_๐๐ฃ(๐ฃ))
297 {
298 // compute v, which is a single polynomial, but requires iterating over the vectors t_hat and y_hat
299 let mut v = compute_t_hat_dot_y_hat_row::<k, eta1>(
300 &r,
301 &unpack_t_hat_row(t_hat_packed, 0),
302 /*row*/ 0,
303 );
304
305 for i in 1..k {
306 let v_i = compute_t_hat_dot_y_hat_row::<k, eta1>(
307 &r,
308 &unpack_t_hat_row(t_hat_packed, i),
309 /*row*/ i,
310 );
311 v.add(&v_i);
312 }
313
314 // perform polynomial addition
315 let e2 = sample_poly_CBD::<ETA2>(&r, 2 * k as u8);
316 v.add(&e2);
317
318 let mu = Polynomial::from_msg(m);
319 v.add(&mu);
320
321 v.poly_reduce();
322
323 v.compress_poly::<dv>(&mut ct[CT_LEN - (N * (dv as usize) / 8)..]);
324 }
325
326 ct
327 }
328
329 /// Algorithm 17 ML-KEM.Encaps_internal(ek, ๐)
330 /// Uses the encapsulation key and randomness to generate a key and an associated ciphertext.
331 /// Input: encapsulation key ek โ ๐น384๐+32 .
332 /// Input: randomness ๐ โ ๐น32 .
333 /// Output: shared secret key ๐พ โ ๐น32 .
334 /// Output: ciphertext ๐ โ ๐น32(๐๐ข๐+๐๐ฃ).
335 ///
336 /// Unlike the more public function exposed by [KEM::encaps], this returns the shared secret as raw bytes
337 /// instead of wrapped in an appropriately-set [KeyMaterialTrait], so you're on your own for handling it properly.
338 ///
339 /// Note: this is an internal function that allows the caller to specify the encapsulation
340 /// randomness (which is the message `m` to be encrypted by the underlying PKE scheme).
341 /// This function should not be used directly unless you really have a
342 /// good reason. [KEM::encaps] should be used in 99.9% of cases.
343 /// The reason this is exposed publicly is: A) for unit testing that requires access
344 /// to the deterministically reproducible function, and B) for operational environments
345 /// that wish to provide randomness from their own source instead of the built-in RNG in bc-rust.
346 /// If you think you will be clever and invent some scheme that uses a deterministic KEM,
347 /// then you will almost certainly end up with security problems. Please don't do this.
348 pub fn encaps_internal(ek: &PK, m: [u8; 32]) -> ([u8; 32], [u8; CT_LEN]) {
349 debug_assert_eq!(CT_LEN, 32 * ((du as usize) * k + (dv as usize)));
350
351 // 1: (๐พ, ๐) โ G(๐โH(ek))
352 // โท derive shared secret key ๐พ and randomness ๐
353 let K: [u8; MLKEM_SS_LEN];
354 let r: [u8; 32];
355 (K, r) = {
356 let mut g = G::new();
357 g.do_update(&m);
358 g.do_update(&ek.compute_hash());
359 let mut buf = [0u8; 64];
360 let bytes_written = g.do_final_out(&mut buf);
361 debug_assert_eq!(bytes_written, 64);
362
363 (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
364 };
365
366 // 2: ๐ โ K-PKE.Encrypt(ek, ๐, ๐)
367 // โท encrypt ๐ using K-PKE with randomness ๐
368 // deviation from FIPS:
369 let ct = Self::pke_encrypt(ek.t_hat_packed(), ek.rho(), m, &r);
370
371 (K, ct)
372 }
373
374 /// Algorithm 15 K-PKE.Decrypt(dkPKE, ๐)
375 /// Uses the decryption key to decrypt a ciphertext
376 /// Input: decryption key dkPKE โ ๐น384๐.
377 /// Input: ciphertext ๐ โ ๐น32(๐๐ข๐+๐๐ฃ).
378 /// Output: message ๐ โ ๐น32 .
379 fn pke_decrypt(dk: &SK, ct: [u8; CT_LEN]) -> [u8; 32] {
380 // 1: ๐1 โ ๐[0 โถ 32๐๐ข๐]
381 // 3: ๐ฎโฒ โ Decompress_๐๐ข(ByteDecode_๐๐ข(๐1))
382
383 // 5: ๐ฌ_hat โ ByteDecode12(dkPKE)
384 // Unnecessary here because we're gonna re-compute them row-by-row
385
386 // first half of
387 // 6: ๐ค โ ๐ฃโฒ โ NTTโ1(๐ฌ_hat^T โ NTT(๐ฎโฒ))
388 let v1 = {
389 // i = 0 case
390 let mut v1 = {
391 let mut s_hat_i = dk.compute_s_hat_row(0);
392 {
393 let mut u_prime_i = unpack_ciphertext_u_row::<du, CT_LEN>(0, &ct);
394 u_prime_i.ntt();
395 s_hat_i.base_mult_montgomery(&u_prime_i);
396 }
397 s_hat_i.inv_ntt();
398
399 s_hat_i
400 };
401
402 for i in 1..k {
403 let mut s_hat_i = dk.compute_s_hat_row(i);
404 {
405 let mut u_prime_i = unpack_ciphertext_u_row::<du, CT_LEN>(i, &ct);
406 u_prime_i.ntt();
407 s_hat_i.base_mult_montgomery(&u_prime_i);
408 }
409 s_hat_i.inv_ntt();
410 v1.add(&s_hat_i);
411 }
412
413 v1
414 };
415
416 // 2: ๐2 โ ๐[32๐๐ข๐ โถ 32(๐๐ข๐ + ๐๐ฃ)]
417 // 4: ๐ฃโฒ โ Decompress_๐๐ฃ(ByteDecode_๐๐ฃ(๐2))
418 let w = {
419 // second half of
420 // 6: ๐ค โ ๐ฃโฒ โ NTTโ1(๐ฌ_hat^T โ NTT(๐ฎโฒ))
421 let mut v_prime = unpack_ciphertext_v::<k, CT_LEN, du, dv>(&ct);
422
423 v_prime.sub(&v1);
424 v_prime.poly_reduce();
425
426 v_prime // rename to w
427 };
428
429 // 7: ๐ โ ByteEncode1(Compress1(๐ค))
430 // โท decode plaintext ๐ from polynomial ๐ค
431 w.to_msg()
432 }
433
434 /// Algorithm 18 ML-KEM.Decaps_internal(dk, ๐)
435 /// Uses the decapsulation key to produce a shared secret key from a ciphertext.
436 /// Input: decapsulation key dk โ ๐น768๐+96 .
437 /// Input: ciphertext ๐ โ ๐น32(๐๐ข๐+๐๐ฃ).
438 /// Output: shared secret key ๐พ โ ๐น32 .
439 fn decaps_internal(dk: &SK, c: [u8; CT_LEN]) -> [u8; MLKEM_SS_LEN] {
440 // I have tried to keep this as clean as possible for correspondence with the FIPS,
441 // but I have moved things around so that I can use unnamed scopes to limit how many
442 // stack variables are alive at the same time.
443
444 // 1: dkPKE โ dk[0 โถ 384๐] โท extract (from KEM decaps key) the PKE decryption key
445 // 2: ekPKE โ dk[384๐ โถ 768๐ + 32] โท extract PKE encryption key
446 // 3: โ โ dk[768๐ + 32 โถ 768๐ + 64] โท extract hash of PKE encryption key
447 // 4: ๐ง โ dk[768๐ + 64 โถ 768๐ + 96] โท extract implicit rejection value
448 // Nothing to do since dk is already decoded.
449
450 // 5: ๐โฒ โ K-PKE.Decrypt(dkPKE, ๐)
451 let m_prime = Self::pke_decrypt(&dk, c);
452
453 // Compute the trial shared secret key
454 // 6: (๐พโฒ, ๐โฒ) โ G(๐โฒโโ)ฬ
455 let K_prime: [u8; MLKEM_SS_LEN];
456 let r_prime: [u8; 32];
457 (K_prime, r_prime) = {
458 let mut g = G::new();
459 g.do_update(&m_prime);
460 g.do_update(&dk.pk().compute_hash());
461 let mut buf = [0u8; 64];
462 let bytes_written = g.do_final_out(&mut buf);
463 debug_assert_eq!(bytes_written, 64);
464
465 (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
466 };
467
468 // 7: ๐พ_bar โ J(๐งโ๐)
469 // Compute the rejection sampling key.
470 // Note to future optimizers: this needs to be computed outside of the if at line 9 below
471 // because if its computation is conditional on the Fujisaki-Okamoto check failing, then
472 // you'll have a timing difference between success and failure.
473
474 let K_bar: [u8; MLKEM_SS_LEN];
475 K_bar = {
476 let mut j = J::new();
477 j.absorb(dk.z());
478 j.absorb(&c);
479 let mut buf = [0u8; MLKEM_SS_LEN];
480 let bytes_written = j.squeeze_out(&mut buf);
481 debug_assert_eq!(bytes_written, MLKEM_SS_LEN);
482
483 buf
484 };
485
486 // 8: ๐โฒ โ K-PKE.Encrypt(ekPKE, ๐โฒ, ๐โฒ)
487 // โท re-encrypt using the derived randomness ๐โฒ
488 let c_prime = Self::pke_encrypt(&dk.t_hat_packed(), dk.rho(), m_prime, &r_prime);
489
490 // 9: if ๐ โ ๐โฒ then
491 // 10: ๐พโฒ โ ๐พ_bar
492 // โท if ciphertexts do not match, โimplicitly reject"
493 let mut K_out = [0u8; MLKEM_SS_LEN];
494 conditional_copy_bytes(&K_prime, &K_bar, &mut K_out, ct_eq_bytes(&c, &c_prime));
495
496 K_out
497 }
498
499 /// Alternative initialization of the streaming signer where you have your private key
500 /// as a seed and you want to delay its expansion as late as possible for memory-usage reasons.
501 pub fn decaps_from_seed(
502 seed: &KeyMaterial<64>,
503 ct: &[u8],
504 ) -> Result<KeyMaterial<SS_LEN>, KEMError> {
505 let sk = SK::from_keymaterial(seed)?;
506
507 Self::decaps(&sk, ct)
508 }
509}
510
511impl<
512 const PK_LEN: usize,
513 const SK_LEN: usize,
514 const FULL_SK_LEN: usize,
515 const CT_LEN: usize,
516 const SS_LEN: usize,
517 PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
518 + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
519 SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
520 + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
521 const k: usize,
522 const eta1: i16,
523 const du: i16,
524 const dv: i16,
525 const LAMBDA: i16,
526 const T_PACKED_LEN: usize,
527>
528 MLKEMTrait<
529 PK_LEN,
530 SK_LEN,
531 FULL_SK_LEN,
532 CT_LEN,
533 SS_LEN,
534 PK,
535 SK,
536 k,
537 eta1,
538 du,
539 dv,
540 LAMBDA,
541 T_PACKED_LEN,
542 >
543 for MLKEM<
544 PK_LEN,
545 SK_LEN,
546 FULL_SK_LEN,
547 CT_LEN,
548 SS_LEN,
549 PK,
550 SK,
551 k,
552 eta1,
553 du,
554 dv,
555 LAMBDA,
556 T_PACKED_LEN,
557 >
558{
559 /// Imports a secret key from a seed.
560 fn keygen_from_seed(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError> {
561 Self::keygen_internal(seed)
562 }
563 /// Imports a secret key from both a seed and an encoded_sk.
564 ///
565 /// This is a convenience function to expand the key from seed and compare it against
566 /// the provided `encoded_sk` using a constant-time equality check.
567 /// If everything checks out, the secret key is returned fully populated with pk and seed.
568 /// If the provided key and derived key don't match, an error is returned.
569 fn keygen_from_seed_and_encoded(
570 seed: &KeyMaterial<64>,
571 encoded_sk: &[u8; SK_LEN],
572 ) -> Result<(PK, SK), KEMError> {
573 let (pk, sk) = Self::keygen_internal(seed)?;
574
575 let sk_from_bytes = SK::sk_decode(encoded_sk);
576
577 // MLKEMPrivateKey impls PartialEq with a constant-time equality check.
578 if sk != sk_from_bytes {
579 return Err(KEMError::KeyGenError("Encoded key does not match generated key"));
580 }
581
582 Ok((pk, sk))
583 }
584 /// Given a public key and a secret key, check that the public key matches the secret key.
585 /// This is a sanity check that the public key was generated correctly from the secret key.
586 ///
587 /// At the current time, this is only possible if `sk` either contains a public key (in which case
588 /// the two pk's are encoded and compared for byte equality), or if `sk` contains a seed
589 /// (in which case a keygen_from_seed is run and then the pk's compared).
590 ///
591 /// Returns either `()` or [KEMError::ConsistencyCheckFailed].
592 fn keypair_consistency_check(pk: &PK, sk: &SK) -> Result<(), KEMError> {
593 let derived_pk = sk.pk();
594 if derived_pk.compute_hash() == pk.compute_hash() {
595 Ok(())
596 } else {
597 Err(KEMError::ConsistencyCheckFailed(""))
598 }
599 }
600}
601
602/// Trait for all three of the ML-DSA algorithm variants.
603pub trait MLKEMTrait<
604 const PK_LEN: usize,
605 const SK_LEN: usize,
606 const FULL_SK_LEN: usize,
607 const CT_LEN: usize,
608 const SS_LEN: usize,
609 PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
610 + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
611 SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
612 + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
613 const k: usize,
614 const eta: i16,
615 const du: i16,
616 const dv: i16,
617 const LAMBDA: i16,
618 const T_PACKED_LEN: usize,
619>: Sized
620{
621 /// Imports a secret key from a seed.
622 fn keygen_from_seed(seed: &KeyMaterial<64>) -> Result<(PK, SK), KEMError>;
623 /// Imports a secret key from both a seed and an encoded_sk.
624 ///
625 /// This is a convenience function to expand the key from seed and compare it against
626 /// the provided `encoded_sk` using a constant-time equality check.
627 /// If everything checks out, the secret key is returned fully populated with pk and seed.
628 /// If the provided key and derived key don't match, an error is returned.
629 fn keygen_from_seed_and_encoded(
630 seed: &KeyMaterial<64>,
631 encoded_sk: &[u8; SK_LEN],
632 ) -> Result<(PK, SK), KEMError>;
633 /// Given a public key and a secret key, check that the public key matches the secret key.
634 /// This is a sanity check that the public key was generated correctly from the secret key.
635 ///
636 /// At the current time, this is only possible if `sk` either contains a public key (in which case
637 /// the two pk's are encoded and compared for byte equality), or if `sk` contains a seed
638 /// (in which case a keygen_from_seed is run and then the pk's compared).
639 ///
640 /// Returns either `()` or [KEMError::ConsistencyCheckFailed].
641 fn keypair_consistency_check(pk: &PK, sk: &SK) -> Result<(), KEMError>;
642}
643
644impl<
645 const PK_LEN: usize,
646 const SK_LEN: usize,
647 const FULL_SK_LEN: usize,
648 const CT_LEN: usize,
649 const SS_LEN: usize,
650 PK: MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN>
651 + MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN>,
652 SK: MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
653 + MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN>,
654 const k: usize,
655 const eta: i16,
656 const du: i16,
657 const dv: i16,
658 const LAMBDA: i16,
659 const T_PACKED_LEN: usize,
660> KEM<PK, SK, PK_LEN, SK_LEN, CT_LEN, SS_LEN>
661 for MLKEM<
662 PK_LEN,
663 SK_LEN,
664 FULL_SK_LEN,
665 CT_LEN,
666 SS_LEN,
667 PK,
668 SK,
669 k,
670 eta,
671 du,
672 dv,
673 LAMBDA,
674 T_PACKED_LEN,
675 >
676{
677 /// Generates a fresh key pair.
678 fn keygen() -> Result<(PK, SK), KEMError> {
679 Self::keygen_from_os_rng()
680 }
681
682 fn encaps(pk: &PK) -> Result<(KeyMaterial<SS_LEN>, [u8; CT_LEN]), KEMError> {
683 let mut m = [0u8; 32];
684 HashDRBG_SHA512::new_from_os().next_bytes_out(&mut m)?;
685
686 let (ss_bytes, ct) = Self::encaps_internal(pk, m);
687
688 let mut ss_keymaterial =
689 KeyMaterial::<SS_LEN>::from_bytes_as_type(&ss_bytes, KeyType::BytesFullEntropy)?;
690 ss_keymaterial.allow_hazardous_operations();
691 ss_keymaterial.set_security_strength(SecurityStrength::from_bits(LAMBDA as usize))?;
692 ss_keymaterial.drop_hazardous_operations();
693
694 Ok((ss_keymaterial, ct))
695 }
696 /// Performs a decapsulation of the given ciphertext.
697 /// Returns the shared secret key.
698 /// The derived shared secret key is returned as a KeyMaterial with the SecurityStrength set to
699 /// the security level of the ML-KEM parameter set.
700 /// As ML-KEM is an implicitly-rejecting KEM, this returns an error only if the ciphertext is invalid (ie the wrong length).
701 fn decaps(sk: &SK, ct: &[u8]) -> Result<KeyMaterial<SS_LEN>, KEMError> {
702 if ct.len() != CT_LEN {
703 return Err(KEMError::LengthError("Invalid ciphertext length"));
704 }
705
706 let ss_bytes = Self::decaps_internal(sk, ct.try_into().unwrap());
707
708 let mut ss_keymaterial =
709 KeyMaterial::<SS_LEN>::from_bytes_as_type(&ss_bytes, KeyType::BytesFullEntropy)?;
710 ss_keymaterial.allow_hazardous_operations();
711 ss_keymaterial.set_security_strength(SecurityStrength::from_bits(LAMBDA as usize))?;
712 ss_keymaterial.drop_hazardous_operations();
713
714 Ok(ss_keymaterial)
715 }
716}