1use crate::aux_functions::{byte_decode, byte_encode, expandA};
2use crate::matrix::{Matrix, Vector};
3use crate::mlkem::{POLY_BYTES, H, q};
4use crate::{ML_KEM_512_NAME, ML_KEM_768_NAME, ML_KEM_1024_NAME};
5use crate::mlkem::{MLKEM512_k, MLKEM512_PK_LEN, MLKEM512_SK_LEN};
6use crate::mlkem::{MLKEM768_k, MLKEM768_PK_LEN, MLKEM768_SK_LEN};
7use crate::mlkem::{MLKEM1024_k, MLKEM1024_PK_LEN, MLKEM1024_SK_LEN};
8use bouncycastle_core::key_material::{KeyMaterialTrait, KeyMaterial, KeyType};
9use bouncycastle_core::traits::{Hash, KEMPrivateKey, KEMPublicKey, Secret, SecurityStrength};
10use bouncycastle_core::errors::KEMError;
11use core::fmt;
12use core::fmt::{Debug, Display, Formatter};
13use bouncycastle_sha3::SHA3_256;
14
15
16#[allow(unused_imports)]
18use crate::mlkem::MLKEMTrait;
19
20
21
/// ML-KEM-512 public (encapsulation) key.
pub type MLKEM512PublicKey = MLKEMPublicKey<MLKEM512_k, MLKEM512_PK_LEN>;
/// ML-KEM-512 private (decapsulation) key, embedding an [`MLKEM512PublicKey`].
pub type MLKEM512PrivateKey = MLKEMPrivateKey<MLKEM512_k, MLKEM512PublicKey, MLKEM512_SK_LEN, MLKEM512_PK_LEN>;
/// ML-KEM-768 public (encapsulation) key.
pub type MLKEM768PublicKey = MLKEMPublicKey<MLKEM768_k, MLKEM768_PK_LEN>;
/// ML-KEM-768 private (decapsulation) key, embedding an [`MLKEM768PublicKey`].
pub type MLKEM768PrivateKey = MLKEMPrivateKey<MLKEM768_k, MLKEM768PublicKey, MLKEM768_SK_LEN, MLKEM768_PK_LEN>;
/// ML-KEM-1024 public (encapsulation) key.
pub type MLKEM1024PublicKey = MLKEMPublicKey<MLKEM1024_k, MLKEM1024_PK_LEN>;
/// ML-KEM-1024 private (decapsulation) key, embedding an [`MLKEM1024PublicKey`].
pub type MLKEM1024PrivateKey = MLKEMPrivateKey<MLKEM1024_k, MLKEM1024PublicKey, MLKEM1024_SK_LEN, MLKEM1024_PK_LEN>;


/// ML-KEM-512 public key that also caches its expanded matrix `A_hat`.
pub type MLKEM512PublicKeyExpanded = MLKEMPublicKeyExpanded<MLKEM512_k, MLKEM512PublicKey, MLKEM512_PK_LEN>;
/// ML-KEM-512 private key that also caches the expanded matrix `A_hat` of its public key.
pub type MLKEM512PrivateKeyExpanded = MLKEMPrivateKeyExpanded<MLKEM512_k, MLKEM512PublicKey, MLKEM512PrivateKey, MLKEM512_SK_LEN, MLKEM512_PK_LEN>;
/// ML-KEM-768 public key that also caches its expanded matrix `A_hat`.
pub type MLKEM768PublicKeyExpanded = MLKEMPublicKeyExpanded<MLKEM768_k, MLKEM768PublicKey, MLKEM768_PK_LEN>;
/// ML-KEM-768 private key that also caches the expanded matrix `A_hat` of its public key.
pub type MLKEM768PrivateKeyExpanded = MLKEMPrivateKeyExpanded<MLKEM768_k, MLKEM768PublicKey, MLKEM768PrivateKey, MLKEM768_SK_LEN, MLKEM768_PK_LEN>;
/// ML-KEM-1024 public key that also caches its expanded matrix `A_hat`.
pub type MLKEM1024PublicKeyExpanded = MLKEMPublicKeyExpanded<MLKEM1024_k, MLKEM1024PublicKey, MLKEM1024_PK_LEN>;
/// ML-KEM-1024 private key that also caches the expanded matrix `A_hat` of its public key.
pub type MLKEM1024PrivateKeyExpanded = MLKEMPrivateKeyExpanded<MLKEM1024_k, MLKEM1024PublicKey,MLKEM1024PrivateKey, MLKEM1024_SK_LEN, MLKEM1024_PK_LEN>;
52
/// An ML-KEM encapsulation (public) key, held in decoded form.
///
/// `k` is the module rank of the parameter set. `PK_LEN` is the byte length of
/// the encoding produced by `encode_out`, which asserts it equals `12 * k * 32 + 32`.
#[derive(Clone)]
pub struct MLKEMPublicKey<const k: usize, const PK_LEN: usize> {
    /// The public vector `t_hat`. It is serialised as the first `k * POLY_BYTES`
    /// bytes of the encoding (one 12-bit `byte_encode` per polynomial).
    t_hat: Vector<k>,
    /// 32-byte seed from which `A_hat()` regenerates the public matrix via `expandA`.
    /// It forms the last 32 bytes of the encoding.
    rho: [u8; 32],
}
59
/// ML-KEM specific operations on public keys, on top of [`KEMPublicKey`].
pub trait MLKEMPublicKeyTrait<const k: usize, const PK_LEN: usize> : KEMPublicKey<PK_LEN> {
    /// Decodes an encoded public key of exactly `PK_LEN` bytes.
    ///
    /// # Errors
    /// The implementations in this file return `KEMError::DecodingError` when a
    /// decoded `t_hat` coefficient is not in `[0, q)`.
    fn pk_decode(pk: &[u8; PK_LEN]) -> Result<Self, KEMError>;
    /// Returns the public matrix `A_hat`. Implementations either regenerate it
    /// from `rho` or return a cached copy.
    fn A_hat(&self) -> Matrix<k, k>;
    /// Returns the 32-byte hash `H` of this key's encoding.
    fn compute_hash(&self) -> [u8; 32];
}
72
/// Crate-internal constructor and accessor for ML-KEM public keys.
pub(crate) trait MLKEMPublicKeyInternalTrait<const k: usize, const PK_LEN: usize> : MLKEMPublicKeyTrait<k, PK_LEN> {
    /// Builds a key from a decoded `t_hat` and the 32-byte seed `rho`.
    /// The implementations in this file perform no validation here.
    fn new(t_hat: Vector<k>, rho: [u8; 32], ) -> Self;

    /// Borrows the key's `t_hat` vector.
    fn t_hat(&self) -> &Vector<k>;
}
81
82impl<const k: usize, const PK_LEN: usize> MLKEMPublicKeyTrait<k, PK_LEN> for MLKEMPublicKey<k, PK_LEN> {
83 fn pk_decode(pk: &[u8; PK_LEN]) -> Result<Self, KEMError> {
84 let (pk_chunks, last_chunk) = pk.as_chunks::<POLY_BYTES>();
85
86 debug_assert_eq!(pk_chunks.len(), k);
88 debug_assert_eq!(last_chunk.len(), 32);
89
90 let t_hat = {
91 let mut t_hat = Vector::<k>::new();
92
93 for (t_i, pk_chunk) in t_hat.vec.iter_mut().zip(pk_chunks) {
94 t_i.coeffs.copy_from_slice(&byte_decode::<12, POLY_BYTES>(pk_chunk).coeffs);
95
96 for coeff in t_i.coeffs.iter() {
103 if *coeff < 0 || *coeff >= q {
104 return Err(KEMError::DecodingError("Invalid or corrupted key"));
105 }
106 }
107
108 }
109
110 t_hat
111 };
112 let rho = last_chunk.try_into().unwrap();
113
114 Ok(Self::new(t_hat, rho))
115 }
116
117 fn A_hat(&self) -> Matrix<k, k> {
118 expandA(&self.rho)
119 }
120
121 fn compute_hash(&self) -> [u8; 32] {
122 let mut out = [0u8; 32];
123 let bytes_written = H::default().hash_out(&self.encode(), &mut out);
124 debug_assert_eq!(bytes_written, 32);
125 out
126 }
127}
128
129impl<const k: usize, const PK_LEN: usize> MLKEMPublicKeyInternalTrait<k, PK_LEN> for MLKEMPublicKey<k, PK_LEN> {
130 fn new(t_hat: Vector<k>, rho: [u8; 32]) -> Self {
131 Self { rho, t_hat }
132 }
133
134 fn t_hat(&self) -> &Vector<k> { &self.t_hat }
135}
136
137impl<const k: usize, const PK_LEN: usize> KEMPublicKey<PK_LEN> for MLKEMPublicKey<k, PK_LEN> {
138 fn encode(&self) -> [u8; PK_LEN] {
141 let mut pk = [0u8; PK_LEN];
142 self.encode_out(&mut pk);
143
144 pk
145 }
146 fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
149 debug_assert_eq!(PK_LEN, 12*k*32 + 32);
150 debug_assert_eq!(POLY_BYTES, 12*32);
151
152 let (pk_chunks, last_chunk) = out.as_chunks_mut::<POLY_BYTES>();
153
154 debug_assert_eq!(pk_chunks.len(), k);
156 debug_assert_eq!(last_chunk.len(), 32);
157
158 for (pk_chunk, t_i) in pk_chunks.into_iter().zip(&self.t_hat.vec) {
159 pk_chunk.copy_from_slice(&byte_encode::<12, POLY_BYTES>(t_i));
160 }
161 last_chunk.copy_from_slice(&self.rho);
162
163 PK_LEN
164 }
165
166 fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
167 if bytes.len() != PK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
168 let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
169 Self::pk_decode(&bytes_sized)
170 }
171}
172
173impl<const k: usize, const PK_LEN: usize> Eq for MLKEMPublicKey<k, PK_LEN> { }
174
175impl<const k: usize, const PK_LEN: usize> PartialEq for MLKEMPublicKey<k, PK_LEN> {
176 fn eq(&self, other: &Self) -> bool {
177 bouncycastle_utils::ct::ct_eq_bytes(&self.encode(), &other.encode())
178 }
179}
180
181impl<const k: usize, const PK_LEN: usize> Debug for MLKEMPublicKey<k, PK_LEN> {
182 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
183 let alg = match k {
184 2 => ML_KEM_512_NAME,
185 3 => ML_KEM_768_NAME,
186 4 => ML_KEM_1024_NAME,
187 _ => panic!("Unsupported key length"),
188 };
189 let hash = SHA3_256::new().hash(&self.encode());
190 write!(f, "MLKEMPublicKey {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
191 }
192}
193
194impl<const k: usize, const PK_LEN: usize> Display for MLKEMPublicKey<k, PK_LEN> {
195 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196 let alg = match k {
197 2 => ML_KEM_512_NAME,
198 3 => ML_KEM_768_NAME,
199 4 => ML_KEM_1024_NAME,
200 _ => panic!("Unsupported key length"),
201 };
202 let hash = SHA3_256::new().hash(&self.encode());
203 write!(f, "MLKEMPublicKey {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
204 }
205}
206
/// An ML-KEM public key bundled with its public matrix `A_hat`.
///
/// `A_hat` is obtained once from the inner key's `A_hat()` when the value is
/// constructed. This is presumably so that repeated operations need not re-run
/// `expandA`.
#[derive(Clone)]
pub struct MLKEMPublicKeyExpanded<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize> {
    /// The wrapped public key. Encoding, hashing and `t_hat` delegate to it.
    pub(crate) ek: PK,
    /// Cached copy of `ek.A_hat()`.
    pub(crate) A_hat: Matrix<k, k>,
}
215
216impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
217MLKEMPublicKeyInternalTrait<k, PK_LEN> for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
218 fn new(t_hat: Vector<k>, rho: [u8; 32]) -> Self {
219 let ek = PK::new(t_hat, rho);
220 let A_hat = ek.A_hat();
221
222 Self {
223 ek,
224 A_hat,
225 }
226 }
227
228 fn t_hat(&self) -> &Vector<k> {
229 self.ek.t_hat()
230 }
231}
232
233impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
234KEMPublicKey<PK_LEN> for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
235 fn encode(&self) -> [u8; PK_LEN] {
236 let mut pk = [0u8; PK_LEN];
237 self.encode_out(&mut pk);
238
239 pk
240 }
241
242 fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
243 self.ek.encode_out(out)
244 }
245
246 fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
247 if bytes.len() != PK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
248 let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
249 Self::pk_decode(&bytes_sized)
250 }
251}
252
253impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
254PartialEq for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
255 fn eq(&self, other: &Self) -> bool {
256 self.encode() == other.encode()
257 }
258}
259
260impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
261Eq for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {}
262
263impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
264Debug for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
265 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
266 let alg = match k {
267 2 => ML_KEM_512_NAME,
268 3 => ML_KEM_768_NAME,
269 4 => ML_KEM_1024_NAME,
270 _ => panic!("Unsupported key length"),
271 };
272 let hash = SHA3_256::new().hash(&self.encode());
273 write!(f, "MLKEMPublicKeyExpanded {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
274 }
275}
276
277impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
278Display for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
279 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
280 let alg = match k {
281 2 => ML_KEM_512_NAME,
282 3 => ML_KEM_768_NAME,
283 4 => ML_KEM_1024_NAME,
284 _ => panic!("Unsupported key length"),
285 };
286 let hash = SHA3_256::new().hash(&self.encode());
287 write!(f, "MLKEMPublicKeyExpanded {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
288 }
289}
290
291impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
292MLKEMPublicKeyTrait<k, PK_LEN> for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
293 fn pk_decode(pk: &[u8; PK_LEN]) -> Result<Self, KEMError> {
294 let ek = PK::pk_decode(pk)?;
295 let A_hat = ek.A_hat();
296 Ok(Self { ek, A_hat })
297 }
298
299 fn A_hat(&self) -> Matrix<k, k> {
300 self.A_hat.clone()
301 }
302
303 fn compute_hash(&self) -> [u8; 32] {
304 self.ek.compute_hash()
305 }
306}
307
308impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize> From<&PK>
309for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
310 fn from(ek: &PK) -> Self {
313 let A_hat = ek.A_hat();
314
315 Self {
316 ek: ek.clone(),
317 A_hat,
318 }
319 }
320}
321
322
323
324
325
/// An ML-KEM decapsulation (private) key, held in decoded form.
///
/// `sk_encode_out` serialises it to `SK_LEN` bytes as
/// `ByteEncode12(s_hat) || ek || pk_hash || z`.
#[derive(Clone)]
pub struct MLKEMPrivateKey<
    const k: usize,
    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
    const SK_LEN: usize,
    const PK_LEN: usize,
> {
    /// The secret vector `s_hat`, encoded with a 12-bit `byte_encode` per polynomial.
    s_hat: Vector<k>,
    /// The matching public (encapsulation) key.
    ek: PK,
    /// 32-byte hash of the encoded `ek`. `sk_decode` checks it against `ek.compute_hash()`.
    pk_hash: [u8; 32],
    /// 32-byte secret `z`. It is also the second half of the seed returned by `seed()`.
    z: [u8; 32],
    /// First half `d` of the 64-byte seed, when known. `sk_decode` always sets `None`.
    seed_d: Option<[u8; 32]>,
}
340
341impl<
342 const k: usize,
343 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
344 const SK_LEN: usize,
345 const PK_LEN: usize,
346> MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
347 fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
350 debug_assert_eq!(SK_LEN, 12*k*32 + PK_LEN + 32 + 32);
351
352 let mut pos = 0usize;
353
354 for i in 0..k {
357 out[i*POLY_BYTES .. (i+1)*POLY_BYTES].copy_from_slice(&byte_encode::<12, POLY_BYTES>(
358 &self.s_hat[i]
359 ));
360 }
361 pos += k * POLY_BYTES;
362
363 debug_assert_eq!(self.ek.encode().len(), PK_LEN);
366 out[pos .. pos + PK_LEN].copy_from_slice(&self.ek.encode());
367 pos += PK_LEN;
368
369 out[pos .. pos + 32].copy_from_slice(&self.pk_hash);
371 pos += 32;
372
373 out[pos .. pos + 32].copy_from_slice(&self.z);
375
376 debug_assert_eq!(pos + 32, SK_LEN);
377 SK_LEN
378 }
379}
380
/// ML-KEM specific operations on private keys, on top of [`KEMPrivateKey`].
pub trait MLKEMPrivateKeyTrait<
    const k: usize,
    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
    const SK_LEN: usize,
    const PK_LEN: usize> : KEMPrivateKey<SK_LEN> {
    /// Returns the 64-byte key-generation seed, or `None` when it is not available.
    /// `MLKEMPrivateKey` returns `None` after `sk_decode`.
    fn seed(&self) -> Option<KeyMaterial<64>>;

    /// Borrows the matching public (encapsulation) key.
    fn pk(&self) -> &PK;
    /// Borrows the stored 32-byte hash of the encoded public key.
    fn pk_hash(&self) -> &[u8; 32];
    /// Decodes a private key of exactly `SK_LEN` bytes.
    ///
    /// # Errors
    /// The implementations in this file return `KEMError::DecodingError` or
    /// `KEMError::ConsistencyCheckFailed` for malformed input.
    fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, KEMError>;
}
397
/// Crate-internal constructor and accessors for ML-KEM private keys.
pub(crate) trait MLKEMPrivateKeyInternalTrait<const k: usize, PK: MLKEMPublicKeyTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize> {
    /// Assembles a private key from its parts.
    ///
    /// `h` is the 32-byte hash of the encoded `ek`. `seed_d` is the first half
    /// of the key-generation seed, if known.
    fn new(
        s_hat: Vector<k>,
        ek: PK,
        h: [u8; 32],
        z: [u8; 32],
        seed_d: Option<[u8; 32]>,
    ) -> Self;

    /// Borrows the secret vector `s_hat`.
    fn s_hat(&self) -> &Vector<k>;

    /// Borrows the 32-byte secret `z`.
    fn z(&self) -> &[u8; 32];
}
414
415
416impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
417 MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
418 fn seed(&self) -> Option<KeyMaterial<64>> {
419 if self.seed_d.is_none() {
420 None
421 } else {
422 let mut tmp = [0u8; 64];
423 tmp[..32].copy_from_slice(&self.seed_d.unwrap());
424 tmp[32..].copy_from_slice(&self.z);
425 let mut seed = KeyMaterial::<64>::from_bytes_as_type(&tmp, KeyType::Seed).unwrap();
426 seed.allow_hazardous_operations();
427 seed.set_security_strength( match k {
428 2 => SecurityStrength::_128bit,
429 3 => SecurityStrength::_192bit,
430 4 => SecurityStrength::_256bit,
431 _ => unreachable!("Invalid mlkem param set"),
432 }).unwrap();
433 seed.drop_hazardous_operations();
434 Some(seed)
435 }
436 }
437
438 fn pk(&self) -> &PK {
439 &self.ek
440 }
441
442 fn pk_hash(&self) -> &[u8; 32] {
443 &self.pk_hash
444 }
445
446 fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, KEMError> {
447 debug_assert_eq!(SK_LEN, 12*k*32 + PK_LEN + 32 + 32);
448
449 let mut pos = 0usize;
450
451 let mut s_hat = Vector::<k>::new();
453 for i in 0..k {
455 s_hat[i] = byte_decode::<12, POLY_BYTES>(
456 sk[i*POLY_BYTES .. (i+1)*POLY_BYTES].try_into().unwrap()
457 );
458
459 for coeff in s_hat[i].coeffs.iter() {
466 if *coeff < -q || *coeff >= q {
467 return Err(KEMError::DecodingError("Invalid or corrupted key"));
468 }
469 }
470 }
471 pos += k * POLY_BYTES;
472
473 let ek = PK::pk_decode(sk[pos .. pos + PK_LEN].try_into().unwrap())?;
475 pos += PK_LEN;
476
477 let h_pk: [u8; 32] = sk[pos .. pos + 32].try_into().unwrap();
479 pos += 32;
480
481 if h_pk != ek.compute_hash() {
485 return Err(KEMError::ConsistencyCheckFailed("Corrupted private key: computed hash of ek != h_ek stored in private key"));
486 }
487
488 let z: [u8; 32] = sk[pos .. pos + 32].try_into().unwrap();
490
491 Ok(Self::new(s_hat, ek, h_pk, z, None))
492 }
493}
494
495impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
496 MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN> for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
497 fn new(
499 s_hat: Vector<k>,
500 ek: PK,
501 pk_hash: [u8; 32],
502 z: [u8; 32],
503 seed_d: Option<[u8; 32]>,
504 ) -> Self {
505 Self {
506 s_hat,
507 ek,
508 pk_hash,
509 z,
510 seed_d: seed_d.clone(),
511 }
512 }
513
514 fn s_hat(&self) -> &Vector<k> { &self.s_hat }
515
516 fn z(&self) -> &[u8; 32] { &self.z }
517}
518
519impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize
520> KEMPrivateKey<SK_LEN> for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
521 fn encode(&self) -> [u8; SK_LEN] {
522 let mut out = [0u8; SK_LEN];
523 self.encode_out(&mut out);
524
525 out
526 }
527
528 fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
529 self.sk_encode_out(out)
530 }
531
532 fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
533 if bytes.len() != SK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
534 let bytes_sized: [u8; SK_LEN] = bytes[..SK_LEN].try_into().unwrap();
535
536 Self::sk_decode(&bytes_sized)
537 }
538}
539
540impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
541 Eq for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {}
542
543impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
544 PartialEq for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
545{
546 fn eq(&self, other: &Self) -> bool {
547 let self_encoded = self.encode();
548 let other_encoded = other.encode();
549 bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
550 }
551}
552
553impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
554Secret for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {}
555
556impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
558 fmt::Debug for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
559{
560 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
561 let alg = match k {
562 2 => ML_KEM_512_NAME,
563 3 => ML_KEM_768_NAME,
564 4 => ML_KEM_1024_NAME,
565 _ => panic!("Unsupported key length"),
566 };
567 write!(
568 f,
569 "MLKEMPrivateKey {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
570 alg,
571 self.pk_hash,
572 self.seed_d.is_some(),
573 )
574 }
575}
576
577impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
579 Display for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
580{
581 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
582 let alg = match k {
583 2 => ML_KEM_512_NAME,
584 3 => ML_KEM_768_NAME,
585 4 => ML_KEM_1024_NAME,
586 _ => panic!("Unsupported key length"),
587 };
588 write!(
589 f,
590 "MLKEMPrivateKey {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
591 alg,
592 self.pk_hash,
593 self.seed_d.is_some(),
594 )
595 }
596}
597
598impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
600Drop for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
601{
602 fn drop(&mut self) {
603 self.pk_hash.fill(0u8);
605 self.z.fill(0u8);
606 if self.seed_d.is_some() { self.seed_d.as_mut().unwrap().fill(0u8); }
607 }
608}
609
610
611
/// An ML-KEM private key bundled with the public matrix `A_hat` of its public key.
///
/// `A_hat` is obtained from `dk.pk().A_hat()` when the value is constructed.
#[derive(Clone)]
pub struct MLKEMPrivateKeyExpanded<
    const k: usize,
    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
    const SK_LEN: usize,
    const PK_LEN: usize
> {
    // `PK` appears only in `SK`'s bounds, not in any field, so it must be marked as used.
    _phantom: core::marker::PhantomData<PK>,
    /// The wrapped private (decapsulation) key. Encoding, equality and the accessors delegate to it.
    pub(crate) dk: SK,
    /// Cached copy of `dk.pk().A_hat()`.
    pub(crate) A_hat: Matrix<k,k>,
}
627
628impl<
629 const k: usize,
630 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
631 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
632 const SK_LEN: usize,
633 const PK_LEN: usize
634> From<&SK>
635for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
636 fn from(dk: &SK) -> Self {
639 let A_hat = dk.pk().A_hat();
640
641 Self {
642 _phantom: core::marker::PhantomData,
643 dk: dk.clone(),
644 A_hat,
645 }
646 }
647}
648
649impl<
650 const k: usize,
651 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
652 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
653 const SK_LEN: usize,
654 const PK_LEN: usize
655> KEMPrivateKey<SK_LEN> for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
656 fn encode(&self) -> [u8; SK_LEN] {
657 self.dk.encode()
658 }
659
660 fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
661 self.dk.encode_out(out)
662 }
663
664 fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
665 Ok(Self::from(&SK::from_bytes(bytes)?))
666 }
667}
668
669impl<
670 const k: usize,
671 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
672 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
673 const SK_LEN: usize,
674 const PK_LEN: usize
675> PartialEq for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
676 fn eq(&self, other: &Self) -> bool {
677 self.dk.eq(&other.dk)
678 }
679}
680
681impl<
682 const k: usize,
683 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
684 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
685 const SK_LEN: usize,
686 const PK_LEN: usize
687> Eq for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {}
688
689impl<
690 const k: usize,
691 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
692 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
693 const SK_LEN: usize,
694 const PK_LEN: usize
695> Secret for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {}
696
697impl<
698 const k: usize,
699 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
700 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
701 const SK_LEN: usize,
702 const PK_LEN: usize
703> Drop for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
704 fn drop(&mut self) {
705 }
707}
708
709impl<
710 const k: usize,
711 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
712 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
713 const SK_LEN: usize,
714 const PK_LEN: usize
715> Debug for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
716 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
717 let alg = match k {
718 2 => ML_KEM_512_NAME,
719 3 => ML_KEM_768_NAME,
720 4 => ML_KEM_1024_NAME,
721 _ => panic!("Unsupported key length"),
722 };
723 write!(
724 f,
725 "MLKEMPrivateKeyExpanded {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
726 alg,
727 self.dk.pk().compute_hash(),
728 self.dk.seed().is_some(),
729 )
730 }
731}
732
733impl<
734 const k: usize,
735 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
736 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
737 const SK_LEN: usize,
738 const PK_LEN: usize
739> Display for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
740 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
741 let alg = match k {
742 2 => ML_KEM_512_NAME,
743 3 => ML_KEM_768_NAME,
744 4 => ML_KEM_1024_NAME,
745 _ => panic!("Unsupported key length"),
746 };
747 write!(
748 f,
749 "MLKEMPrivateKeyExpanded {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
750 alg,
751 self.dk.pk().compute_hash(),
752 self.dk.seed().is_some(),
753 )
754 }
755}
756
757impl<
758 const k: usize,
759 PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
760 SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
761 const SK_LEN: usize,
762 const PK_LEN: usize
763> MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
764 fn seed(&self) -> Option<KeyMaterial<64>> {
765 self.dk.seed()
766 }
767
768 fn pk(&self) -> &PK {
769 self.dk.pk()
770 }
771
772 fn pk_hash(&self) -> &[u8; 32] {
773 &self.dk.pk_hash()
774 }
775
776 fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, KEMError> {
777 let dk = SK::sk_decode(sk)?;
778 let A_hat = dk.pk().A_hat();
779
780 Ok(Self {
781 _phantom: core::marker::PhantomData,
782 dk: dk.clone(),
783 A_hat,
784 })
785 }
786}