1use crate::aux_functions::{sample_poly_CBD};
2use crate::mlkem::{POLY_BYTES, H, G};
3use crate::{ML_KEM_512_NAME, ML_KEM_768_NAME, ML_KEM_1024_NAME};
4use crate::mlkem::{MLKEM512_k, MLKEM512_ETA1, MLKEM512_LAMBDA, MLKEM512_PK_LEN, MLKEM512_SK_LEN, MLKEM512_FULL_SK_LEN, MLKEM512_T_PACKED_LEN};
5use crate::mlkem::{MLKEM768_k, MLKEM768_ETA1, MLKEM768_LAMBDA, MLKEM768_PK_LEN, MLKEM768_SK_LEN, MLKEM768_FULL_SK_LEN, MLKEM768_T_PACKED_LEN};
6use crate::mlkem::{MLKEM1024_k, MLKEM1024_ETA1, MLKEM1024_LAMBDA, MLKEM1024_PK_LEN, MLKEM1024_SK_LEN, MLKEM1024_FULL_SK_LEN, MLKEM1024_T_PACKED_LEN};
7use bouncycastle_core::key_material::{KeyMaterialTrait, KeyMaterial, KeyType};
8use bouncycastle_core::traits::{Hash, KEMPrivateKey, KEMPublicKey, Secret, SecurityStrength};
9use bouncycastle_core::errors::{KEMError};
10use core::fmt;
11use core::fmt::{Debug, Display, Formatter};
12use bouncycastle_sha3::SHA3_256;
13use crate::low_memory_helpers::{compute_A_hat_dot_s_hat, pack_s_hat_row, pack_t_hat_row};
14use crate::polynomial::{Polynomial};
15
16pub type MLKEM512PublicKey = MLKEMPublicKey<MLKEM512_k, MLKEM512_PK_LEN, MLKEM512_T_PACKED_LEN>;
22pub type MLKEM512PrivateKey = MLKEMSeedPrivateKey<MLKEM512_k, MLKEM512_ETA1, MLKEM512_LAMBDA, MLKEM512_SK_LEN, MLKEM512_FULL_SK_LEN, MLKEM512_PK_LEN, MLKEM512_T_PACKED_LEN>;
24pub type MLKEM768PublicKey = MLKEMPublicKey<MLKEM768_k, MLKEM768_PK_LEN, MLKEM768_T_PACKED_LEN>;
26pub type MLKEM768PrivateKey = MLKEMSeedPrivateKey<MLKEM768_k, MLKEM768_ETA1, MLKEM768_LAMBDA, MLKEM768_SK_LEN, MLKEM768_FULL_SK_LEN, MLKEM768_PK_LEN, MLKEM768_T_PACKED_LEN>;
28pub type MLKEM1024PublicKey = MLKEMPublicKey<MLKEM1024_k, MLKEM1024_PK_LEN, MLKEM1024_T_PACKED_LEN>;
30pub type MLKEM1024PrivateKey = MLKEMSeedPrivateKey<MLKEM1024_k, MLKEM1024_ETA1, MLKEM1024_LAMBDA, MLKEM1024_SK_LEN, MLKEM1024_FULL_SK_LEN, MLKEM1024_PK_LEN, MLKEM1024_T_PACKED_LEN>;
32
33
/// An ML-KEM public key, held as its two encoded components.
///
/// The byte encoding is `t_hat_packed || rho`.
///
/// * `k` – module rank of the parameter set (2, 3 or 4 for ML-KEM-512/768/1024).
/// * `PK_LEN` – length in bytes of the full encoding (`T_PACKED_LEN + 32`).
/// * `T_PACKED_LEN` – length in bytes of the packed `t_hat` vector.
#[derive(Clone)]
pub struct MLKEMPublicKey<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize> {
    /// 32-byte public seed; the last 32 bytes of the encoding.
    pub(crate) rho: [u8; 32],
    /// Packed `t_hat` vector; the leading `T_PACKED_LEN` bytes of the encoding.
    pub(crate) t_hat_packed: [u8; T_PACKED_LEN],
}
40
41pub trait MLKEMPublicKeyTrait<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize> : KEMPublicKey<PK_LEN> {
43 fn pk_decode(pk: &[u8; PK_LEN]) -> Self;
48
49 fn t_hat_packed(&self) -> &[u8; T_PACKED_LEN];
51
52 fn rho(&self) -> &[u8; 32];
54
55 fn compute_hash(&self) -> [u8; 32];
57}
58
59pub(crate) trait MLKEMPublicKeyInternalTrait<
60 const k: usize,
61 const T_PACKED_LEN: usize,
62 const PK_LEN: usize
63> : MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN> {
64 fn new(t_hat: [u8; T_PACKED_LEN], rho: [u8; 32]) -> Self;
67}
68
69impl<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize>
70MLKEMPublicKeyTrait<k, PK_LEN, T_PACKED_LEN> for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
71 fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
72 Self::new(pk[..T_PACKED_LEN].try_into().unwrap(), pk[T_PACKED_LEN..].try_into().unwrap())
73 }
74
75 fn t_hat_packed(&self) -> &[u8; T_PACKED_LEN] {
76 &self.t_hat_packed
77 }
78
79 fn rho(&self) -> &[u8; 32] {
80 &self.rho
81 }
82
83 fn compute_hash(&self) -> [u8; 32] {
84 let mut out = [0u8; 32];
87 let mut h = H::default();
88 h.do_update(&self.t_hat_packed);
89 h.do_update(&self.rho);
90 let bytes_written = h.do_final_out(&mut out);
91 debug_assert_eq!(bytes_written, 32);
92 out
93 }
94}
95
96impl<const k: usize, const T_PACKED_LEN: usize, const PK_LEN: usize>
97MLKEMPublicKeyInternalTrait<k, T_PACKED_LEN, PK_LEN> for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
98 fn new(t_hat_packed: [u8; T_PACKED_LEN], rho: [u8; 32]) -> Self {
99 Self { rho, t_hat_packed }
100 }
101}
102
103impl<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize>
104KEMPublicKey<PK_LEN> for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
105 fn encode(&self) -> [u8; PK_LEN] {
110 debug_assert_eq!(PK_LEN, 32 + 12*k*32);
111 let mut pk = [0u8; PK_LEN];
112 self.encode_out(&mut pk);
113
114 pk
115 }
116
117 fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
118 debug_assert_eq!(self.t_hat_packed.len(), T_PACKED_LEN);
119
120 out[.. T_PACKED_LEN].copy_from_slice(&self.t_hat_packed);
121 debug_assert_eq!(out[T_PACKED_LEN..].len(), 32);
122 out[T_PACKED_LEN..].copy_from_slice(&self.rho);
123
124 PK_LEN
125 }
126
127 fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
128 if bytes.len() != PK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
129 let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
130 Ok(Self::pk_decode(&bytes_sized))
131 }
132}
133
134impl<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize>
135Eq for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> { }
136
137impl<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize>
138PartialEq for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
139 fn eq(&self, other: &Self) -> bool {
140 bouncycastle_utils::ct::ct_eq_bytes(&self.encode(), &other.encode())
141 }
142}
143
144impl<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize>
145Debug for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
146 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
147 let alg = match k {
148 2 => ML_KEM_512_NAME,
149 3 => ML_KEM_768_NAME,
150 4 => ML_KEM_1024_NAME,
151 _ => panic!("Unsupported key length"),
152 };
153 let hash = SHA3_256::new().hash(&self.encode());
154 write!(f, "MLKEMPublicKey {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
155 }
156}
157
158impl<const k: usize, const PK_LEN: usize, const T_PACKED_LEN: usize>
159Display for MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
160 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
161 let alg = match k {
162 2 => ML_KEM_512_NAME,
163 3 => ML_KEM_768_NAME,
164 4 => ML_KEM_1024_NAME,
165 _ => panic!("Unsupported key length"),
166 };
167 let hash = SHA3_256::new().hash(&self.encode());
168 write!(f, "MLKEMPublicKey {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
169 }
170}
171
172
173
174
175
/// An ML-KEM private key stored in compact seed form.
///
/// Only the 64-byte seed (`seed_d || z`) and the two values derived from `seed_d` are kept.
/// The secret vector `s_hat` and the public key are recomputed when needed.
///
/// * `k` – module rank (2, 3 or 4).
/// * `eta1` – CBD parameter passed to `sample_poly_CBD`.
/// * `LAMBDA` – minimum seed security strength in bits, enforced by `new`.
/// * `SK_LEN` – compact encoding length (64 bytes: `seed_d || z`).
/// * `FULL_SK_LEN` – length of the expanded private-key encoding.
/// * `PK_LEN`, `T_PACKED_LEN` – sizes of the matching public key.
#[derive(Clone)]
pub struct MLKEMSeedPrivateKey<
    const k: usize,
    const eta1: i16,
    const LAMBDA: i16,
    const SK_LEN: usize,
    const FULL_SK_LEN: usize,
    const PK_LEN: usize,
    const T_PACKED_LEN: usize,
> {
    /// First half of the seed; `rho` and `sigma` are derived from it.
    seed_d: [u8; 32],
    /// Second half of the seed; also the final 32 bytes of the expanded encoding.
    z: [u8; 32],
    /// Public seed derived from `seed_d`.
    rho: [u8; 32],
    /// Secret seed derived from `seed_d`; used to sample the secret and error polynomials.
    sigma: [u8; 32],
    /// Lazily computed public-key hash; `None` until `pk_hash` is first called.
    pk_hash: Option<[u8; 32]>,
}
193
194impl<
195 const k: usize,
196 const eta1: i16,
197 const LAMBDA: i16,
198 const SK_LEN: usize,
199 const FULL_SK_LEN: usize,
200 const PK_LEN: usize,
201 const T_PACKED_LEN: usize
202> MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> {
203 pub fn new(seed: &KeyMaterial<64>) -> Result<Self, KEMError> {
206 if !(seed.key_type() == KeyType::Seed || seed.key_type() == KeyType::BytesFullEntropy)
207 || seed.key_len() != 64
208 {
209 return Err(KEMError::KeyGenError(
210 "Seed must be 64 bytes and KeyType::Seed or KeyType::BytesFullEntropy.",
211 ));
212 }
213
214 if seed.security_strength() < SecurityStrength::from_bits(LAMBDA as usize) {
215 return Err(KEMError::KeyGenError("SecurityStrength"));
216 }
217
218 let seed_d: [u8; 32] = seed.ref_to_bytes()[..32].try_into().unwrap();
219 let z: [u8; 32] = seed.ref_to_bytes()[32..].try_into().unwrap();
220
221 let (rho, sigma) = Self::compute_rho_and_sigma(&seed_d);
222
223 Ok(Self { rho, sigma, pk_hash: None, z, seed_d })
226 }
227 fn compute_rho_and_sigma(seed_d: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
233 let mut g = G::new();
234 g.do_update(seed_d);
235 g.do_update(&[k as u8]);
236 let mut buf = [0u8; 64];
237 let bytes_written = g.do_final_out(&mut buf);
238 debug_assert_eq!(bytes_written, 64);
239
240 (buf[..32].try_into().unwrap(), buf[32..64].try_into().unwrap())
241 }
242}
243
244
245pub trait MLKEMPrivateKeyTrait<
247 const k: usize,
248 const SK_LEN: usize,
249 const FULL_SK_LEN: usize,
250 const PK_LEN: usize,
251 const T_PACKED_LEN: usize,
252> : KEMPrivateKey<SK_LEN> {
253 fn from_keymaterial(seed: &KeyMaterial<64>) -> Result<Self, KEMError>;
255 fn seed(&self) -> Option<KeyMaterial<64>>;
258 fn pk(&self) -> MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN>;
261 fn pk_hash(&mut self) -> &[u8; 32];
267 fn encode_full_sk(&self) -> [u8; FULL_SK_LEN];
276 fn full_sk_encode_out(&self, out: &mut [u8; FULL_SK_LEN]) -> usize;
285 fn sk_decode(sk: &[u8; SK_LEN]) -> Self;
287}
288
289pub(crate) trait MLKEMPrivateKeyInternalTrait<const k: usize, const SK_LEN: usize, const PK_LEN: usize,
290 const T_PACKED_LEN: usize,> {
291
292 fn z(&self) -> &[u8; 32];
293
294 fn compute_s_hat_row(&self, idx: usize) -> Polynomial;
295
296 fn rho(&self) -> &[u8; 32];
297
298 fn t_hat_packed(&self) -> [u8; T_PACKED_LEN];
300}
301
302
303impl<
304 const k: usize,
305 const eta1: i16,
306 const LAMBDA: i16,
307 const SK_LEN: usize,
308 const FULL_SK_LEN: usize,
309 const PK_LEN: usize,
310 const T_PACKED_LEN: usize
311> MLKEMPrivateKeyTrait<k, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> {
312 fn from_keymaterial(seed: &KeyMaterial<64>) -> Result<Self, KEMError> {
313 Self::new(seed)
314 }
315 fn seed(&self) -> Option<KeyMaterial<64>> {
316 let mut tmp = [0u8; 64];
317 tmp[..32].copy_from_slice(&self.seed_d);
318 tmp[32..].copy_from_slice(&self.z);
319 let mut seed = KeyMaterial::<64>::from_bytes_as_type(&tmp, KeyType::Seed).unwrap();
320 seed.allow_hazardous_operations();
321 seed.set_security_strength( match k {
322 2 => SecurityStrength::_128bit,
323 3 => SecurityStrength::_192bit,
324 4 => SecurityStrength::_256bit,
325 _ => unreachable!("Invalid mlkem param set"),
326 }).unwrap();
327 seed.drop_hazardous_operations();
328
329 Some(seed)
330 }
331 fn pk(&self) -> MLKEMPublicKey<k, PK_LEN, T_PACKED_LEN> {
332 MLKEMPublicKey::<k, PK_LEN, T_PACKED_LEN>::new(self.t_hat_packed(), self.rho)
333 }
334 fn pk_hash(&mut self) -> &[u8; 32] {
335 if self.pk_hash.is_none() {
336 self.pk_hash = Some(self.pk().compute_hash().clone());
337 }
338
339 &self.pk_hash.as_ref().unwrap()
340 }
341 fn encode_full_sk(&self) -> [u8; FULL_SK_LEN] {
350 let mut out = [0u8; FULL_SK_LEN];
351 self.full_sk_encode_out(&mut out);
352
353 out
354 }
355 fn full_sk_encode_out(&self, out: &mut [u8; FULL_SK_LEN]) -> usize {
363 out.fill(0);
364
365 let mut pos = 0usize;
366
367 for i in 0..k {
370 pack_s_hat_row::<k>(&self.compute_s_hat_row(i), i, out);
374 }
375 pos += k * POLY_BYTES;
376
377 let pk = self.pk();
380 out[pos .. pos + PK_LEN].copy_from_slice(&pk.encode());
381 pos += PK_LEN;
382
383 out[pos .. pos + 32].copy_from_slice(&pk.compute_hash());
385 pos += 32;
386
387 out[pos .. pos + 32].copy_from_slice(&self.z);
389
390
391 FULL_SK_LEN
392 }
393 fn sk_decode(sk: &[u8; SK_LEN]) -> Self {
394 debug_assert_eq!(SK_LEN, 64);
395 Self::from_bytes(sk).unwrap()
396 }
397}
398
399impl<
400 const k: usize,
401 const eta1: i16,
402 const LAMBDA: i16,
403 const SK_LEN: usize,
404 const FULL_SK_LEN: usize,
405 const PK_LEN: usize,
406 const T_PACKED_LEN: usize
407> MLKEMPrivateKeyInternalTrait<k, SK_LEN, PK_LEN, T_PACKED_LEN> for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> {
408
409 fn z(&self) -> &[u8; 32] { &self.z }
410
411 fn compute_s_hat_row(&self, idx: usize) -> Polynomial {
412 debug_assert!(idx < k);
413
414 let mut s_i = sample_poly_CBD::<eta1>(&self.sigma, idx as u8);
422
423 s_i.ntt();
425 s_i
426 }
427
428 fn rho(&self) -> &[u8; 32] {
429 &self.rho
430 }
431 fn t_hat_packed(&self) -> [u8; T_PACKED_LEN] {
434 let mut t_hat_packed = [0u8; T_PACKED_LEN];
435
436 for i in 0 .. k {
437 let mut t_hat_i = compute_A_hat_dot_s_hat::<k, eta1>(&self.rho, &self.sigma, i);
440
441 {
444 let mut e_i = sample_poly_CBD::<eta1>(&self.sigma, (k + i) as u8);
451
452 e_i.ntt(); t_hat_i.add(&e_i);
454 }
455 t_hat_i.poly_reduce();
456
457 pack_t_hat_row::<T_PACKED_LEN>(&t_hat_i, i, &mut t_hat_packed);
458 }
459
460 t_hat_packed
461 }
462}
463
464impl<
465 const k: usize,
466 const eta1: i16,
467 const LAMBDA: i16,
468 const SK_LEN: usize,
469 const FULL_SK_LEN: usize,
470 const PK_LEN: usize,
471 const T_PACKED_LEN: usize,
472> KEMPrivateKey<SK_LEN> for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> {
473 fn encode(&self) -> [u8; SK_LEN] {
475 let mut sk = [0u8; SK_LEN];
476 self.encode_out(&mut sk);
477
478 sk
479 }
480
481 fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
482 debug_assert_eq!(SK_LEN, 64);
483
484 out[..32].copy_from_slice(&self.seed_d);
485 out[32..].copy_from_slice(&self.z);
486
487 SK_LEN
488 }
489
490 fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
491 if bytes.len() != 64 {
492 return Err(KEMError::DecodingError("Invalid seed length"));
493 }
494 let mut keymat = KeyMaterial::<64>::from_bytes(bytes)?;
495 keymat.allow_hazardous_operations();
496 keymat.set_key_type(KeyType::Seed)?;
497 keymat.set_security_strength(SecurityStrength::_256bit)?;
498 keymat.drop_hazardous_operations();
499
500 Self::new(&keymat)
501 }
502}
503
504impl<
505 const k: usize,
506 const eta1: i16,
507 const LAMBDA: i16,
508 const SK_LEN: usize,
509 const FULL_SK_LEN: usize,
510 const PK_LEN: usize,
511 const T_PACKED_LEN: usize,
512> Eq for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> {}
513
514impl<
515 const k: usize,
516 const eta1: i16,
517 const LAMBDA: i16,
518 const SK_LEN: usize,
519 const FULL_SK_LEN: usize,
520 const PK_LEN: usize,
521 const T_PACKED_LEN: usize,
522> PartialEq for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
523{
524 fn eq(&self, other: &Self) -> bool {
525 let self_encoded = self.encode();
526 let other_encoded = other.encode();
527 bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
528 }
529}
530
531impl<
532 const k: usize,
533 const eta1: i16,
534 const LAMBDA: i16,
535 const SK_LEN: usize,
536 const FULL_SK_LEN: usize,
537 const PK_LEN: usize,
538 const T_PACKED_LEN: usize,
539> Secret for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN> {}
540
541impl<
543 const k: usize,
544 const eta1: i16,
545 const LAMBDA: i16,
546 const SK_LEN: usize,
547 const FULL_SK_LEN: usize,
548 const PK_LEN: usize,
549 const T_PACKED_LEN: usize,
550> fmt::Debug for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
551{
552 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
553 let alg = match k {
554 2 => ML_KEM_512_NAME,
555 3 => ML_KEM_768_NAME,
556 4 => ML_KEM_1024_NAME,
557 _ => panic!("Unsupported key length"),
558 };
559 let pk_hash = self.pk().compute_hash();
560 write!(
561 f,
562 "MLKEMSeedPrivateKey {{ alg: {}, pub_key_hash: {:x?} }}",
563 alg,
564 &pk_hash,
565 )
566 }
567}
568
569impl<
571 const k: usize,
572 const eta1: i16,
573 const LAMBDA: i16,
574 const SK_LEN: usize,
575 const FULL_SK_LEN: usize,
576 const PK_LEN: usize,
577 const T_PACKED_LEN: usize,
578> Display for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
579{
580 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
581 let alg = match k {
582 2 => ML_KEM_512_NAME,
583 3 => ML_KEM_768_NAME,
584 4 => ML_KEM_1024_NAME,
585 _ => panic!("Unsupported key length"),
586 };
587 let pk_hash = self.pk().compute_hash();
588 write!(
589 f,
590 "MLKEMSeedPrivateKey {{ alg: {}, pub_key_hash: {:x?} }}",
591 alg,
592 &pk_hash,
593 )
594 }
595}
596
597impl<
599 const k: usize,
600 const eta1: i16,
601 const LAMBDA: i16,
602 const SK_LEN: usize,
603 const FULL_SK_LEN: usize,
604 const PK_LEN: usize,
605 const T_PACKED_LEN: usize,
606> Drop for MLKEMSeedPrivateKey<k, eta1, LAMBDA, SK_LEN, FULL_SK_LEN, PK_LEN, T_PACKED_LEN>
607{
608 fn drop(&mut self) {
609 self.rho.fill(0u8);
610 self.sigma.fill(0u8);
611 self.z.fill(0u8);
612 self.seed_d.fill(0u8);
613 }
614}