Skip to main content

bouncycastle_mlkem/
mlkem_keys.rs

1use crate::aux_functions::{byte_decode, byte_encode, expandA};
2use crate::matrix::{Matrix, Vector};
3use crate::mlkem::{POLY_BYTES, H, q};
4use crate::{ML_KEM_512_NAME, ML_KEM_768_NAME, ML_KEM_1024_NAME};
5use crate::mlkem::{MLKEM512_k, MLKEM512_PK_LEN, MLKEM512_SK_LEN};
6use crate::mlkem::{MLKEM768_k, MLKEM768_PK_LEN, MLKEM768_SK_LEN};
7use crate::mlkem::{MLKEM1024_k, MLKEM1024_PK_LEN, MLKEM1024_SK_LEN};
8use bouncycastle_core::key_material::{KeyMaterialTrait, KeyMaterial, KeyType};
9use bouncycastle_core::traits::{Hash, KEMPrivateKey, KEMPublicKey, Secret, SecurityStrength};
10use bouncycastle_core::errors::KEMError;
11use core::fmt;
12use core::fmt::{Debug, Display, Formatter};
13use bouncycastle_sha3::SHA3_256;
14
15
16// imports just for docs
17#[allow(unused_imports)]
18use crate::mlkem::MLKEMTrait;
19
20
21
22/* Pub Types */
23
24/// ML-KEM-512 Public Key
25pub type MLKEM512PublicKey = MLKEMPublicKey<MLKEM512_k, MLKEM512_PK_LEN>;
26/// ML-KEM-512 Private Key
27pub type MLKEM512PrivateKey = MLKEMPrivateKey<MLKEM512_k, MLKEM512PublicKey, MLKEM512_SK_LEN, MLKEM512_PK_LEN>;
28/// ML-KEM-768 Public Key
29pub type MLKEM768PublicKey = MLKEMPublicKey<MLKEM768_k, MLKEM768_PK_LEN>;
30/// ML-KEM-768 Private Key
31pub type MLKEM768PrivateKey = MLKEMPrivateKey<MLKEM768_k, MLKEM768PublicKey, MLKEM768_SK_LEN, MLKEM768_PK_LEN>;
32/// ML-KEM-1024 Public Key
33pub type MLKEM1024PublicKey = MLKEMPublicKey<MLKEM1024_k, MLKEM1024_PK_LEN>;
34/// ML-KEM-1024 Private Key
35pub type MLKEM1024PrivateKey = MLKEMPrivateKey<MLKEM1024_k, MLKEM1024PublicKey, MLKEM1024_SK_LEN, MLKEM1024_PK_LEN>;
36
37
38/* Pre-expanded keys for repeated operations */
39
40/// ML-KEM-512 Public Key with a pre-expanded public matrix A for repeated encaps operations.
41pub type MLKEM512PublicKeyExpanded = MLKEMPublicKeyExpanded<MLKEM512_k, MLKEM512PublicKey, MLKEM512_PK_LEN>;
42/// ML-KEM-512 Private Key with a pre-expanded public matrix A for repeated decaps operations.
43pub type MLKEM512PrivateKeyExpanded = MLKEMPrivateKeyExpanded<MLKEM512_k, MLKEM512PublicKey, MLKEM512PrivateKey, MLKEM512_SK_LEN, MLKEM512_PK_LEN>;
44/// ML-KEM-768 Public Key with a pre-expanded public matrix A for repeated encaps operations.
45pub type MLKEM768PublicKeyExpanded = MLKEMPublicKeyExpanded<MLKEM768_k, MLKEM768PublicKey, MLKEM768_PK_LEN>;
46/// ML-KEM-768 Private Key with a pre-expanded public matrix A for repeated decaps operations.
47pub type MLKEM768PrivateKeyExpanded = MLKEMPrivateKeyExpanded<MLKEM768_k, MLKEM768PublicKey, MLKEM768PrivateKey, MLKEM768_SK_LEN, MLKEM768_PK_LEN>;
48/// ML-KEM-1024 Public Key with a pre-expanded public matrix A for repeated encaps operations.
49pub type MLKEM1024PublicKeyExpanded = MLKEMPublicKeyExpanded<MLKEM1024_k, MLKEM1024PublicKey, MLKEM1024_PK_LEN>;
50/// ML-KEM-1024 Private Key with a pre-expanded public matrix A for repeated decaps operations.
51pub type MLKEM1024PrivateKeyExpanded = MLKEMPrivateKeyExpanded<MLKEM1024_k, MLKEM1024PublicKey,MLKEM1024PrivateKey, MLKEM1024_SK_LEN, MLKEM1024_PK_LEN>;
52
53/// An ML-KEM public key.
54#[derive(Clone)]
55pub struct MLKEMPublicKey<const k: usize, const PK_LEN: usize> {
56    t_hat: Vector<k>,
57    rho: [u8; 32],
58}
59
60/// General trait for all ML-KEM public keys types.
61pub trait MLKEMPublicKeyTrait<const k: usize, const PK_LEN: usize> : KEMPublicKey<PK_LEN> {
62    /// Algorithm 23 pkDecode(𝑝𝑘)
63    /// Reverses the procedure pkEncode.
64    /// Input: Public key 𝑝𝑘 ∈ 𝔹32+32𝑘(bitlen (𝑞−1)−𝑑).
65    /// Output: 𝜌 ∈ 𝔹32, 𝐭1 ∈ 𝑅𝑘 with coefficients in [0, 2bitlen (𝑞−1)−𝑑 − 1].
66    fn pk_decode(pk: &[u8; PK_LEN]) -> Result<Self, KEMError>;
67    /// Get a copy of the expanded public matrix A_hat
68    fn A_hat(&self) -> Matrix<k, k>;
69    /// Get the hash of the public key
70    fn compute_hash(&self) -> [u8; 32];
71}
72
73pub(crate) trait MLKEMPublicKeyInternalTrait<const k: usize, const PK_LEN: usize> : MLKEMPublicKeyTrait<k, PK_LEN> {
74    /// Not exposing a constructor publicly because you should have to get an instance either by
75    /// running a keygen, or by decoding an existing key.
76    fn new(t_hat: Vector<k>, rho: [u8; 32], ) -> Self;
77
78    /// Get a ref to t1
79    fn t_hat(&self) -> &Vector<k>;
80}
81
82impl<const k: usize, const PK_LEN: usize> MLKEMPublicKeyTrait<k, PK_LEN> for MLKEMPublicKey<k, PK_LEN> {
83    fn pk_decode(pk: &[u8; PK_LEN]) -> Result<Self, KEMError> {
84        let (pk_chunks, last_chunk) = pk.as_chunks::<POLY_BYTES>();
85
86        // that should divide evenly the remainder of the array, leaving space for rho at the end
87        debug_assert_eq!(pk_chunks.len(), k);
88        debug_assert_eq!(last_chunk.len(), 32);
89
90        let t_hat = {
91            let mut  t_hat = Vector::<k>::new();
92
93            for (t_i, pk_chunk) in t_hat.vec.iter_mut().zip(pk_chunks) {
94                t_i.coeffs.copy_from_slice(&byte_decode::<12, POLY_BYTES>(pk_chunk).coeffs);
95
96                // FIPS 203 says:
97                //      "Specifically, ByteDecode12 converts each 12-bit
98                //      segment of its input into an integer modulo 212 = 4096 and then reduces the result
99                //      modulo 𝑞. This is no longer a one-to-one operation. Indeed, some 12-bit segments could
100                //      correspond to an integer greater than 𝑞 − 1 = 3328 but less than 4096."
101                //  Since we are here in the d=12 case, we can and should check that all coeffs are less than q-1
102                for coeff in t_i.coeffs.iter() {
103                    if *coeff < 0 || *coeff >= q {
104                        return Err(KEMError::DecodingError("Invalid or corrupted key"));
105                    }
106                }
107
108            }
109
110            t_hat
111        };
112        let rho = last_chunk.try_into().unwrap();
113
114        Ok(Self::new(t_hat, rho))
115    }
116
117    fn A_hat(&self) -> Matrix<k, k> {
118        expandA(&self.rho)
119    }
120
121    fn compute_hash(&self) -> [u8; 32] {
122        let mut out = [0u8; 32];
123        let bytes_written = H::default().hash_out(&self.encode(), &mut out);
124        debug_assert_eq!(bytes_written, 32);
125        out
126    }
127}
128
129impl<const k: usize, const PK_LEN: usize> MLKEMPublicKeyInternalTrait<k, PK_LEN> for MLKEMPublicKey<k, PK_LEN> {
130    fn new(t_hat: Vector<k>, rho: [u8; 32]) -> Self {
131        Self { rho, t_hat }
132    }
133
134    fn t_hat(&self) -> &Vector<k> { &self.t_hat }
135}
136
137impl<const k: usize, const PK_LEN: usize>  KEMPublicKey<PK_LEN> for MLKEMPublicKey<k, PK_LEN> {
138    /// Encodes the public key as per FIPS 203 Algorithm 13
139    /// 19: ekPKE ← ByteEncode12(𝐭)‖𝜌
140    fn encode(&self) -> [u8; PK_LEN] {
141        let mut pk = [0u8; PK_LEN];
142        self.encode_out(&mut pk);
143
144        pk
145    }
146    /// Encodes the public key as per FIPS 203 Algorithm 13
147    /// 19: ekPKE ← ByteEncode12(𝐭)‖𝜌
148    fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
149        debug_assert_eq!(PK_LEN, 12*k*32 + 32);
150        debug_assert_eq!(POLY_BYTES, 12*32);
151
152        let (pk_chunks, last_chunk) = out.as_chunks_mut::<POLY_BYTES>();
153
154        // that should divide evenly the remainder of the array, leaving space for rho at the end
155        debug_assert_eq!(pk_chunks.len(), k);
156        debug_assert_eq!(last_chunk.len(), 32);
157
158        for (pk_chunk, t_i) in pk_chunks.into_iter().zip(&self.t_hat.vec) {
159            pk_chunk.copy_from_slice(&byte_encode::<12, POLY_BYTES>(t_i));
160        }
161        last_chunk.copy_from_slice(&self.rho);
162
163        PK_LEN
164    }
165
166    fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
167        if bytes.len() != PK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
168        let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
169        Self::pk_decode(&bytes_sized)
170    }
171}
172
173impl<const k: usize, const PK_LEN: usize> Eq for MLKEMPublicKey<k, PK_LEN> { }
174
175impl<const k: usize, const PK_LEN: usize> PartialEq for MLKEMPublicKey<k, PK_LEN> {
176    fn eq(&self, other: &Self) -> bool {
177        bouncycastle_utils::ct::ct_eq_bytes(&self.encode(), &other.encode())
178    }
179}
180
181impl<const k: usize, const PK_LEN: usize> Debug for MLKEMPublicKey<k, PK_LEN> {
182    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
183        let alg = match k {
184            2 => ML_KEM_512_NAME,
185            3 => ML_KEM_768_NAME,
186            4 => ML_KEM_1024_NAME,
187            _ => panic!("Unsupported key length"),
188        };
189        let hash = SHA3_256::new().hash(&self.encode());
190        write!(f, "MLKEMPublicKey {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
191    }
192}
193
194impl<const k: usize, const PK_LEN: usize> Display for MLKEMPublicKey<k, PK_LEN> {
195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196        let alg = match k {
197            2 => ML_KEM_512_NAME,
198            3 => ML_KEM_768_NAME,
199            4 => ML_KEM_1024_NAME,
200            _ => panic!("Unsupported key length"),
201        };
202        let hash = SHA3_256::new().hash(&self.encode());
203        write!(f, "MLKEMPublicKey {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
204    }
205}
206
207/// A fully expanded ML-KEM public key that includes the intermediate values needed for performing multiple encaps operations
208/// against the same public key, which causes the MLKEMPublicKey struct to take up more memory, but results
209/// in more efficient repeated encaps() operations.
210#[derive(Clone)]
211pub struct MLKEMPublicKeyExpanded<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize> {
212    pub(crate) ek: PK,
213    pub(crate) A_hat: Matrix<k, k>,
214}
215
216impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
217MLKEMPublicKeyInternalTrait<k, PK_LEN> for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
218    fn new(t_hat: Vector<k>, rho: [u8; 32]) -> Self {
219        let ek = PK::new(t_hat, rho);
220        let A_hat = ek.A_hat();
221
222        Self {
223            ek,
224            A_hat,
225        }
226    }
227
228    fn t_hat(&self) -> &Vector<k> {
229        self.ek.t_hat()
230    }
231}
232
233impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
234KEMPublicKey<PK_LEN> for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
235    fn encode(&self) -> [u8; PK_LEN] {
236        let mut pk = [0u8; PK_LEN];
237        self.encode_out(&mut pk);
238
239        pk
240    }
241
242    fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
243        self.ek.encode_out(out)
244    }
245
246    fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
247        if bytes.len() != PK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
248        let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
249        Self::pk_decode(&bytes_sized)
250    }
251}
252
253impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
254PartialEq for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
255    fn eq(&self, other: &Self) -> bool {
256        self.encode() == other.encode()
257    }
258}
259
260impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
261Eq for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {}
262
263impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
264Debug for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
265    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
266        let alg = match k {
267            2 => ML_KEM_512_NAME,
268            3 => ML_KEM_768_NAME,
269            4 => ML_KEM_1024_NAME,
270            _ => panic!("Unsupported key length"),
271        };
272        let hash = SHA3_256::new().hash(&self.encode());
273        write!(f, "MLKEMPublicKeyExpanded {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
274    }
275}
276
277impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
278Display for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
279    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
280        let alg = match k {
281            2 => ML_KEM_512_NAME,
282            3 => ML_KEM_768_NAME,
283            4 => ML_KEM_1024_NAME,
284            _ => panic!("Unsupported key length"),
285        };
286        let hash = SHA3_256::new().hash(&self.encode());
287        write!(f, "MLKEMPublicKeyExpanded {{ alg: {}, pub_key_hash: {:x?} }}", alg, hash)
288    }
289}
290
291impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize>
292MLKEMPublicKeyTrait<k, PK_LEN> for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
293    fn pk_decode(pk: &[u8; PK_LEN]) -> Result<Self, KEMError> {
294        let ek = PK::pk_decode(pk)?;
295        let A_hat = ek.A_hat();
296        Ok(Self { ek, A_hat })
297    }
298
299    fn A_hat(&self) -> Matrix<k, k> {
300        self.A_hat.clone()
301    }
302
303    fn compute_hash(&self) -> [u8; 32] {
304        self.ek.compute_hash()
305    }
306}
307
308impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const PK_LEN: usize> From<&PK>
309for MLKEMPublicKeyExpanded<k, PK, PK_LEN> {
310    /// Fully expands the intermediate values needed for performing multiple encaps operations
311    /// against the same public key, which causes the MLKEMPublicKey struct to take up
312    fn from(ek: &PK) -> Self {
313        let A_hat = ek.A_hat();
314
315        Self {
316            ek: ek.clone(),
317            A_hat,
318        }
319    }
320}
321
322
323
324
325
326/// An ML-KEM private key.
327#[derive(Clone)]
328pub struct MLKEMPrivateKey<
329    const k: usize,
330    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
331    const SK_LEN: usize,
332    const PK_LEN: usize,
333> {
334    s_hat: Vector<k>,
335    ek: PK,
336    pk_hash: [u8; 32],
337    z: [u8; 32],
338    seed_d: Option<[u8; 32]>,
339}
340
341impl<
342    const k: usize,
343    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
344    const SK_LEN: usize,
345    const PK_LEN: usize,
346> MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
347    /// As described on Algorithm 16 line
348    ///   3: dk ← (dkPKE ‖ ek ‖ H(ek) ‖ 𝑧)
349    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
350        debug_assert_eq!(SK_LEN, /* dk_pke*/12*k*32 + /*ek*/PK_LEN + /*H(ek)*/32 + /*z*/32);
351
352        let mut pos = 0usize;
353
354        /* dk_pke */
355        // Alg 13; line 20: dkPKE ← ByteEncode12(𝐬)
356        for i in 0..k {
357            out[i*POLY_BYTES .. (i+1)*POLY_BYTES].copy_from_slice(&byte_encode::<12, POLY_BYTES>(
358                &self.s_hat[i]
359            ));
360        }
361        pos += k * POLY_BYTES;
362
363        /* ek */
364        // Alg 13; line 19: ekPKE ← ByteEncode12(𝐭)‖𝜌
365        debug_assert_eq!(self.ek.encode().len(), PK_LEN);
366        out[pos .. pos + PK_LEN].copy_from_slice(&self.ek.encode());
367        pos += PK_LEN;
368
369        /* H(ek) */
370        out[pos .. pos + 32].copy_from_slice(&self.pk_hash);
371        pos += 32;
372
373        /* z */
374        out[pos .. pos + 32].copy_from_slice(&self.z);
375
376        debug_assert_eq!(pos + 32, SK_LEN);
377        SK_LEN
378    }
379}
380
381/// General trait for all ML-KEM private keys types.
382pub trait MLKEMPrivateKeyTrait<
383    const k: usize,
384    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
385    const SK_LEN: usize,
386    const PK_LEN: usize> : KEMPrivateKey<SK_LEN> {
387    /// Get a ref to the seed, if there is one stored with this private key
388    fn seed(&self) -> Option<KeyMaterial<64>>;
389
390    /// This is a partial implementation of keygen_internal(), and probably not allowed in FIPS mode.
391    fn pk(&self) -> &PK;
392    /// Get a ref to the stored public key hash.
393    fn pk_hash(&self) -> &[u8; 32];
394    /// Decode the private key.
395    fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, KEMError>;
396}
397
398pub(crate) trait MLKEMPrivateKeyInternalTrait<const k: usize, PK: MLKEMPublicKeyTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize> {
399    /// Not exposing a constructor publicly because you should have to get an instance either by
400    /// running a keygen, or by decoding an existing key.
401    fn new(
402        s_hat: Vector<k>,
403        ek: PK,
404        h: [u8; 32],
405        z: [u8; 32],
406        seed_d: Option<[u8; 32]>,
407    ) -> Self;
408
409    /// Get a ref to s_hat
410    fn s_hat(&self) -> &Vector<k>;
411
412    fn z(&self) -> &[u8; 32];
413}
414
415
416impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
417    MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
418    fn seed(&self) -> Option<KeyMaterial<64>> {
419        if self.seed_d.is_none() {
420            None
421        } else {
422            let mut tmp = [0u8; 64];
423            tmp[..32].copy_from_slice(&self.seed_d.unwrap());
424            tmp[32..].copy_from_slice(&self.z);
425            let mut seed = KeyMaterial::<64>::from_bytes_as_type(&tmp, KeyType::Seed).unwrap();
426            seed.allow_hazardous_operations();
427            seed.set_security_strength( match k {
428                2 => SecurityStrength::_128bit,
429                3 => SecurityStrength::_192bit,
430                4 => SecurityStrength::_256bit,
431                _ => unreachable!("Invalid mlkem param set"),
432            }).unwrap();
433            seed.drop_hazardous_operations();
434            Some(seed)
435        }
436    }
437
438    fn pk(&self) -> &PK {
439        &self.ek
440    }
441
442    fn pk_hash(&self) -> &[u8; 32] {
443        &self.pk_hash
444    }
445
446    fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, KEMError> {
447        debug_assert_eq!(SK_LEN, /* dk_pke*/12*k*32 + /*ek*/PK_LEN + /*H(ek)*/32 + /*z*/32);
448
449        let mut pos = 0usize;
450
451        /* dk_pke */
452        let mut s_hat = Vector::<k>::new();
453        // for (s_i, sk_chunk) in s_hat.0.iter_mut().zip(sk_chunks) {
454        for i in 0..k {
455            s_hat[i] = byte_decode::<12, POLY_BYTES>(
456                sk[i*POLY_BYTES .. (i+1)*POLY_BYTES].try_into().unwrap()
457            );
458
459            // FIPS 203 says:
460            //      "Specifically, ByteDecode12 converts each 12-bit
461            //      segment of its input into an integer modulo 212 = 4096 and then reduces the result
462            //      modulo 𝑞. This is no longer a one-to-one operation. Indeed, some 12-bit segments could
463            //      correspond to an integer greater than 𝑞 − 1 = 3328 but less than 4096."
464            //  Since we are here in the d=12 case, we can and should check that all coeffs are less than q-1
465            for coeff in s_hat[i].coeffs.iter() {
466                if *coeff < -q || *coeff >= q {
467                    return Err(KEMError::DecodingError("Invalid or corrupted key"));
468                }
469            }
470        }
471        pos += k * POLY_BYTES;
472
473        /* ek */
474        let ek = PK::pk_decode(sk[pos .. pos + PK_LEN].try_into().unwrap())?;
475        pos += PK_LEN;
476
477        /* H(ek) */
478        let h_pk: [u8; 32] = sk[pos .. pos + 32].try_into().unwrap();
479        pos += 32;
480
481        // This satisfies the "Decapsulation input check #3) in FIPS 203 section 7.3.
482        // We're doing it here on key load rather than as part of the decapsulation for performance
483        // because if you're doing multiple decapsulations, you only need to perform this check once. 
484        if h_pk != ek.compute_hash() {
485            return Err(KEMError::ConsistencyCheckFailed("Corrupted private key: computed hash of ek != h_ek stored in private key"));
486        }
487
488        /* z */
489        let z: [u8; 32] = sk[pos .. pos + 32].try_into().unwrap();
490
491        Ok(Self::new(s_hat, ek, h_pk, z, None))
492    }
493}
494
495impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
496    MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN> for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
497    /// Note to future maintainers: FIPS 203 section 7.3 requires that ek be hashed and compared to pk_hash.
498    fn new(
499        s_hat: Vector<k>,
500        ek: PK,
501        pk_hash: [u8; 32],
502        z: [u8; 32],
503        seed_d: Option<[u8; 32]>,
504    ) -> Self {
505        Self {
506            s_hat,
507            ek,
508            pk_hash,
509            z,
510            seed_d: seed_d.clone(),
511        }
512    }
513
514    fn s_hat(&self) -> &Vector<k> { &self.s_hat }
515
516    fn z(&self) -> &[u8; 32] { &self.z }
517}
518
519impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize
520> KEMPrivateKey<SK_LEN> for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {
521    fn encode(&self) -> [u8; SK_LEN] {
522        let mut out = [0u8; SK_LEN];
523        self.encode_out(&mut out);
524
525        out
526    }
527
528    fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
529        self.sk_encode_out(out)
530    }
531
532    fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
533        if bytes.len() != SK_LEN { return Err(KEMError::DecodingError("Provided key bytes are the incorrect length")) }
534        let bytes_sized: [u8; SK_LEN] = bytes[..SK_LEN].try_into().unwrap();
535
536        Self::sk_decode(&bytes_sized)
537    }
538}
539
540impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
541    Eq for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {}
542
543impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
544    PartialEq for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
545{
546    fn eq(&self, other: &Self) -> bool {
547        let self_encoded = self.encode();
548        let other_encoded = other.encode();
549        bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
550    }
551}
552
553impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
554Secret for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN> {}
555
556/// Debug impl mainly to prevent the secret key from being printed in logs.
557impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
558    fmt::Debug for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
559{
560    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
561            let alg = match k {
562                2 => ML_KEM_512_NAME,
563                3 => ML_KEM_768_NAME,
564                4 => ML_KEM_1024_NAME,
565                _ => panic!("Unsupported key length"),
566            };
567        write!(
568            f,
569            "MLKEMPrivateKey {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
570            alg,
571            self.pk_hash,
572            self.seed_d.is_some(),
573        )
574    }
575}
576
577/// Display impl mainly to prevent the secret key from being printed in logs.
578impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
579    Display for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
580{
581    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
582        let alg = match k {
583            2 => ML_KEM_512_NAME,
584            3 => ML_KEM_768_NAME,
585            4 => ML_KEM_1024_NAME,
586            _ => panic!("Unsupported key length"),
587        };
588        write!(
589            f,
590            "MLKEMPrivateKey {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
591            alg,
592            self.pk_hash,
593            self.seed_d.is_some(),
594        )
595    }
596}
597
598/// Zeroizing drop
599impl<const k: usize, PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>, const SK_LEN: usize, const PK_LEN: usize>
600Drop for MLKEMPrivateKey<k, PK, SK_LEN, PK_LEN>
601{
602    fn drop(&mut self) {
603        // s_hat, has its own zeroizing drop
604        self.pk_hash.fill(0u8);
605        self.z.fill(0u8);
606        if self.seed_d.is_some() { self.seed_d.as_mut().unwrap().fill(0u8); }
607    }
608}
609
610
611
612/// A fully expanded ML-KEM private key that includes the intermediate values needed for performing
613/// multiple decaps operations with the same private key, which causes the private key struct to
614/// take up more memory, but results in more efficient repeated decaps() operations.
615#[derive(Clone)]
616pub struct MLKEMPrivateKeyExpanded<
617    const k: usize,
618    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
619    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
620    const SK_LEN: usize,
621    const PK_LEN: usize
622> {
623    _phantom: core::marker::PhantomData<PK>,
624    pub(crate) dk: SK,
625    pub(crate) A_hat: Matrix<k,k>,
626}
627
628impl<
629    const k: usize,
630    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
631    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
632    const SK_LEN: usize,
633    const PK_LEN: usize
634> From<&SK>
635for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
636    /// Fully expands the intermediate values needed for performing multiple encaps operations
637    /// against the same public key, which causes the MLKEMPublicKey struct to take up
638    fn from(dk: &SK) -> Self {
639        let A_hat = dk.pk().A_hat();
640
641        Self {
642            _phantom: core::marker::PhantomData,
643            dk: dk.clone(),
644            A_hat,
645        }
646    }
647}
648
649impl<
650    const k: usize,
651    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
652    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
653    const SK_LEN: usize,
654    const PK_LEN: usize
655> KEMPrivateKey<SK_LEN> for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
656    fn encode(&self) -> [u8; SK_LEN] {
657        self.dk.encode()
658    }
659
660    fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
661        self.dk.encode_out(out)
662    }
663
664    fn from_bytes(bytes: &[u8]) -> Result<Self, KEMError> {
665        Ok(Self::from(&SK::from_bytes(bytes)?))
666    }
667}
668
669impl<
670    const k: usize,
671    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
672    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
673    const SK_LEN: usize,
674    const PK_LEN: usize
675> PartialEq for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
676    fn eq(&self, other: &Self) -> bool {
677        self.dk.eq(&other.dk)
678    }
679}
680
681impl<
682    const k: usize,
683    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
684    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
685    const SK_LEN: usize,
686    const PK_LEN: usize
687> Eq for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {}
688
689impl<
690    const k: usize,
691    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
692    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
693    const SK_LEN: usize,
694    const PK_LEN: usize
695> Secret for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {}
696
697impl<
698    const k: usize,
699    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
700    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
701    const SK_LEN: usize,
702    const PK_LEN: usize
703> Drop for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
704    fn drop(&mut self) {
705        // Nothing to do since self.sk already impls zeroizing Drop
706    }
707}
708
709impl<
710    const k: usize,
711    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
712    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
713    const SK_LEN: usize,
714    const PK_LEN: usize
715> Debug for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
716    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
717        let alg = match k {
718            2 => ML_KEM_512_NAME,
719            3 => ML_KEM_768_NAME,
720            4 => ML_KEM_1024_NAME,
721            _ => panic!("Unsupported key length"),
722        };
723        write!(
724            f,
725            "MLKEMPrivateKeyExpanded {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
726            alg,
727            self.dk.pk().compute_hash(),
728            self.dk.seed().is_some(),
729        )
730    }
731}
732
733impl<
734    const k: usize,
735    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
736    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
737    const SK_LEN: usize,
738    const PK_LEN: usize
739> Display for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
740    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
741        let alg = match k {
742            2 => ML_KEM_512_NAME,
743            3 => ML_KEM_768_NAME,
744            4 => ML_KEM_1024_NAME,
745            _ => panic!("Unsupported key length"),
746        };
747        write!(
748            f,
749            "MLKEMPrivateKeyExpanded {{ alg: {}, pub_key_hash: {:x?}, has_seed: {} }}",
750            alg,
751            self.dk.pk().compute_hash(),
752            self.dk.seed().is_some(),
753        )
754    }
755}
756
757impl<
758    const k: usize,
759    PK: MLKEMPublicKeyInternalTrait<k, PK_LEN>,
760    SK: MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> + MLKEMPrivateKeyInternalTrait<k, PK, SK_LEN, PK_LEN>,
761    const SK_LEN: usize,
762    const PK_LEN: usize
763> MLKEMPrivateKeyTrait<k, PK, SK_LEN, PK_LEN> for MLKEMPrivateKeyExpanded<k, PK, SK, SK_LEN, PK_LEN> {
764    fn seed(&self) -> Option<KeyMaterial<64>> {
765        self.dk.seed()
766    }
767
768    fn pk(&self) -> &PK {
769        self.dk.pk()
770    }
771
772    fn pk_hash(&self) -> &[u8; 32] {
773        &self.dk.pk_hash()
774    }
775
776    fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, KEMError> {
777        let dk = SK::sk_decode(sk)?;
778        let A_hat = dk.pk().A_hat();
779
780        Ok(Self {
781            _phantom: core::marker::PhantomData,
782            dk: dk.clone(),
783            A_hat,
784        })
785    }
786}