Skip to main content

bouncycastle_mldsa/
mldsa_keys.rs

1use crate::aux_functions::{
2    bit_pack_eta, bit_pack_t0, bit_unpack_eta, bit_unpack_t0, bitlen_eta, expandA, power_2_round_vec,
3    simple_bit_pack_t1, simple_bit_unpack_t1
4};
5use crate::matrix::Vector;
6use crate::mldsa::H;
7use crate::{ML_DSA_44_NAME, ML_DSA_65_NAME, ML_DSA_87_NAME};
8use crate::mldsa::{MLDSA44_ETA, MLDSA44_PK_LEN, MLDSA44_SK_LEN, MLDSA44_k, MLDSA44_l};
9use crate::mldsa::{MLDSA65_ETA, MLDSA65_PK_LEN, MLDSA65_SK_LEN, MLDSA65_k, MLDSA65_l};
10use crate::mldsa::{MLDSA87_ETA, MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_k, MLDSA87_l};
11use crate::mldsa::{POLY_T0PACKED_LEN, POLY_T1PACKED_LEN, SEED_LEN};
12use bouncycastle_core_interface::errors::SignatureError;
13use bouncycastle_core_interface::key_material::KeyMaterialSized;
14use bouncycastle_core_interface::traits::{Secret, SignaturePrivateKey, SignaturePublicKey, XOF};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17
18// imports just for docs
19#[allow(unused_imports)]
20use crate::mldsa::MLDSATrait;
21
22
23
24/* Pub Types */
25
26/// ML-DSA-44 Public Key
27pub type MLDSA44PublicKey = MLDSAPublicKey<MLDSA44_k, MLDSA44_PK_LEN>;
28/// ML-DSA-44 Private Key
29pub type MLDSA44PrivateKey = MLDSAPrivateKey<MLDSA44_k, MLDSA44_l, MLDSA44_ETA, MLDSA44_SK_LEN, MLDSA44_PK_LEN>;
30/// ML-DSA-65 Public Key
31pub type MLDSA65PublicKey = MLDSAPublicKey<MLDSA65_k, MLDSA65_PK_LEN>;
32/// ML-DSA-65 Private Key
33pub type MLDSA65PrivateKey = MLDSAPrivateKey<MLDSA65_k, MLDSA65_l, MLDSA65_ETA, MLDSA65_SK_LEN, MLDSA65_PK_LEN>;
34/// ML-DSA-87 Public Key
35pub type MLDSA87PublicKey = MLDSAPublicKey<MLDSA87_k, MLDSA87_PK_LEN>;
36/// ML-DSA-87 Private Key
37pub type MLDSA87PrivateKey = MLDSAPrivateKey<MLDSA87_k, MLDSA87_l, MLDSA87_ETA, MLDSA87_SK_LEN, MLDSA87_PK_LEN>;
38
39/// An ML-DSA public key.
40#[derive(Clone)]
41pub struct MLDSAPublicKey<const k: usize, const PK_LEN: usize> {
42    rho: [u8; SEED_LEN],
43    t1: Vector<k>,
44}
45
46/// General trait for all ML-DSA public keys types.
47pub trait MLDSAPublicKeyTrait<const k: usize, const PK_LEN: usize> : SignaturePublicKey {
48    /// Algorithm 22 pkEncode(𝜌, 𝐭1)
49    /// Encodes a public key for ML-DSA into a byte string.
50    /// Input:𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
51    /// Output: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
52    fn pk_encode(&self) -> [u8; PK_LEN];
53
54    /// Algorithm 23 pkDecode(π‘π‘˜)
55    /// Reverses the procedure pkEncode.
56    /// Input: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
57    /// Output: 𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
58    fn pk_decode(pk: &[u8; PK_LEN]) -> Self;
59
60    /// Compute the public key hash (tr) from the public key.
61    ///
62    /// This is exposed as a public API for a few reasons:
63    /// 1. `tr` is required for some external-prehashing schemes such as the so-called "external mu" signing mode.
64    /// 2. `tr` is the canonical fingerprint of an ML-DSA public key, so would be an appropriate value
65    ///     to use, for example, to build a public key lookup or deny-listing table.
66    fn compute_tr(&self) -> [u8; 64];
67}
68
69pub(crate) trait MLDSAPublicKeyInternalTrait<const k: usize, const PK_LEN: usize> {
70    /// Not exposing a constructor publicly because you should have to get an instance either by
71    /// running a keygen, or by decoding an existing key.
72    fn new(rho: &[u8; SEED_LEN], t1: &Vector<k>) -> Self;
73
74    /// Get a ref to rho
75    fn rho(&self) -> &[u8; 32];
76
77    /// Get a ref to t1
78    fn t1(&self) -> &Vector<k>;
79}
80
81impl<const k: usize, const PK_LEN: usize> MLDSAPublicKeyTrait<k, PK_LEN> for MLDSAPublicKey<k, PK_LEN> {
82    fn pk_encode(&self) -> [u8; PK_LEN] {
83        let mut pk = [0u8; PK_LEN];
84
85        pk[0..SEED_LEN].copy_from_slice(&self.rho);
86
87        let (pk_chunks, last_chunk) = pk[SEED_LEN..].as_chunks_mut::<POLY_T1PACKED_LEN>();
88
89        // that should divide evenly the remainder of the array
90        debug_assert_eq!(pk_chunks.len(), k);
91        debug_assert_eq!(last_chunk.len(), 0);
92
93        // Potential optimization point:
94        // these loops have no interaction between sequential iterations,
95        // so could be replaced with some kind of threaded for construct.
96        // This should be done carefully against benchmarks to make sure we're actually making a
97        // performance improvement, and making sure that whatever multi-threading construct is used
98        // falls back to sequential execution when not available (such as a no_std build).
99        for (pk_chunk, t1_i) in pk_chunks.into_iter().zip(&self.t1.vec) {
100            pk_chunk.copy_from_slice(&simple_bit_pack_t1(&t1_i));
101        }
102
103        pk
104    }
105
106    fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
107        let rho = pk[0..32].try_into().unwrap();
108        let mut t1 = Vector::<k>::new();
109
110        let (pk_chunks, last_chunk) = pk[32..].as_chunks::<POLY_T1PACKED_LEN>();
111
112        // that should divide evenly the remainder of the array
113        debug_assert_eq!(pk_chunks.len(), k);
114        debug_assert_eq!(last_chunk.len(), 0);
115
116        for (t1_i, pk_chunk) in t1.vec.iter_mut().zip(pk_chunks) {
117            t1_i.0.copy_from_slice(&simple_bit_unpack_t1(pk_chunk).0);
118        }
119
120        Self::new(&rho, &t1)
121    }
122
123    fn compute_tr(&self) -> [u8; 64] {
124        let mut tr = [0u8; 64];
125        H::new().hash_xof_out(&self.pk_encode(), &mut tr);
126
127        tr
128    }
129}
130
131impl<const k: usize, const PK_LEN: usize> MLDSAPublicKeyInternalTrait<k, PK_LEN> for MLDSAPublicKey<k, PK_LEN> {
132    fn new(rho: &[u8; SEED_LEN], t1: &Vector<k>) -> Self {
133        Self { rho: rho.clone(), t1: t1.clone() }
134    }
135
136    fn rho(&self) -> &[u8; 32] { &self.rho }
137
138    fn t1(&self) -> &Vector<k> { &self.t1 }
139}
140
141impl<const k: usize, const PK_LEN: usize>  SignaturePublicKey for MLDSAPublicKey<k, PK_LEN> {
142    fn encode(&self) -> Vec<u8> {
143        Vec::from(self.pk_encode())
144    }
145
146    fn encode_out(&self, out: &mut [u8]) -> Result<usize, SignatureError> {
147        if out.len() < PK_LEN {
148            Err(SignatureError::EncodingError("Output buffer too small"))
149        } else {
150            let tmp = self.pk_encode();
151            debug_assert_eq!(tmp.len(), PK_LEN);
152            out[..PK_LEN].copy_from_slice(&tmp);
153            Ok(PK_LEN)
154        }
155    }
156
157    fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
158        if bytes.len() != PK_LEN { return Err(SignatureError::DecodingError("Provided key bytes are the incorrect length")) }
159        let sized_bytes: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
160        Ok(Self::pk_decode(&sized_bytes))
161    }
162}
163
164impl<const k: usize, const PK_LEN: usize> Eq for MLDSAPublicKey<k, PK_LEN> { }
165
166impl<const k: usize, const PK_LEN: usize> PartialEq for MLDSAPublicKey<k, PK_LEN> {
167    fn eq(&self, other: &Self) -> bool {
168        let self_encoded = self.pk_encode();
169        let other_encoded = other.pk_encode();
170        bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
171    }
172}
173
174impl<const k: usize, const PK_LEN: usize> fmt::Debug for MLDSAPublicKey<k, PK_LEN> {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
176        let alg = match k {
177            4 => ML_DSA_44_NAME,
178            6 => ML_DSA_65_NAME,
179            8 => ML_DSA_87_NAME,
180            _ => panic!("Unsupported key length"),
181        };
182        write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
183    }
184}
185
186impl<const k: usize, const PK_LEN: usize> Display for MLDSAPublicKey<k, PK_LEN> {
187    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
188        let alg = match k {
189            4 => ML_DSA_44_NAME,
190            6 => ML_DSA_65_NAME,
191            8 => ML_DSA_87_NAME,
192            _ => panic!("Unsupported key length"),
193        };
194        write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
195    }
196}
197
198/// An ML-DSA private key.
199#[derive(Clone)]
200pub struct MLDSAPrivateKey<
201    const k: usize,
202    const l: usize,
203    const eta: usize,
204    const SK_LEN: usize,
205    const PK_LEN: usize,
206> {
207    rho: [u8; 32],
208    K: [u8; 32],
209    tr: [u8; 64],
210    s1: Vector<l>,
211    s2: Vector<k>,
212    t0: Vector<k>,
213    seed: Option<KeyMaterialSized<32>>,
214}
215
216/// General trait for all ML-DSA private keys types.
217pub trait MLDSAPrivateKeyTrait<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize> : SignaturePrivateKey {
218    /// Get a ref to the seed, if there is one stored with this private key
219    fn seed(&self) -> &Option<KeyMaterialSized<32>>;
220
221    /// Get a ref to the key hash `tr`.
222    fn tr(&self) -> &[u8; 64];
223
224    /// This is a partial implementation of keygen_internal(), and probably not allowed in FIPS mode.
225    fn derive_pk(&self) -> MLDSAPublicKey<k, PK_LEN>;
226    /// Algorithm 24 skEncode(𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0)
227    /// Encodes a secret key for ML-DSA into a byte string.
228    /// Input: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 , 𝐬1 ∈ 𝑅ℓ with coefficients in [βˆ’πœ‚, πœ‚], 𝐬2 ∈ π‘…π‘˜ with
229    /// coefficients in [βˆ’πœ‚, πœ‚], 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
230    /// Output: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((π‘˜+β„“)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
231    fn sk_encode(&self) -> [u8; SK_LEN];
232    /// Algorithm 24 skEncode(𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0)
233    /// Encodes a secret key for ML-DSA into a byte string.
234    /// Input: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 , 𝐬1 ∈ 𝑅ℓ with coefficients in [βˆ’πœ‚, πœ‚], 𝐬2 ∈ π‘…π‘˜ with
235    /// coefficients in [βˆ’πœ‚, πœ‚], 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
236    /// Output: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((π‘˜+β„“)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
237    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize;
238    /// Algorithm 25 skDecode(π‘ π‘˜)
239    /// Reverses the procedure skEncode.
240    /// Input: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((β„“+π‘˜)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
241    /// Output: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 ,
242    /// 𝐬1 ∈ 𝑅ℓ , 𝐬2 ∈ π‘…π‘˜ , 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
243    ///
244    /// Note: this object contains only the simple decoding routine to unpack a semi-expanded key.
245    /// See [MLDSATrait] for key generation functions, including derive-from-seed and consistency-check functions.
246    fn sk_decode(sk: &[u8; SK_LEN]) -> Self;
247}
248
249pub(crate) trait MLDSAPrivateKeyInternalTrait<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize> {
250    /// Not exposing a constructor publicly because you should have to get an instance either by
251    /// running a keygen, or by decoding an existing key.
252    fn new(
253        rho: &[u8; 32],
254        K: &[u8; 32],
255        tr: &[u8; 64],
256        s1: &Vector<l>,
257        s2: &Vector<k>,
258        t0: &Vector<k>,
259        seed: Option<KeyMaterialSized<32>>,
260    ) -> Self;
261    /// Get a ref to rho
262    fn rho(&self) -> &[u8; 32];
263
264    /// Get a ref to K
265    fn K(&self) -> &[u8; 32];
266
267    /// Get a ref to tr
268    // don't need here because there's one in the public trait
269    // fn tr(&self) -> &[u8; 64];
270
271    /// Get a ref to s1
272    fn s1(&self) -> &Vector<l>;
273
274    /// Get a ref to s2
275    fn s2(&self) -> &Vector<k>;
276
277    /// Get a ref to t0
278    fn t0(&self) -> &Vector<k>;
279}
280
281
282impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
283    MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN> for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {
284    fn seed(&self) -> &Option<KeyMaterialSized<32>> { &self.seed }
285
286    fn tr(&self) -> &[u8; 64] {
287        &self.tr
288    }
289
290    fn derive_pk(&self) -> MLDSAPublicKey<k, PK_LEN> {
291
292        // 5: 𝐭 ← NTTβˆ’1(𝐀 ∘ NTT(𝐬1)) + 𝐬2
293        //   β–· compute 𝐭 = 𝐀𝐬1 + 𝐬2
294        let mut s1_hat = self.s1.clone();
295        s1_hat.ntt();
296
297        let mut t = { // scope for A_hat
298            let A_hat = expandA::<k, l>(&self.rho);
299
300            // 3: 𝐀 ← ExpandA(𝜌) β–· 𝐀 is generated and stored in NTT representation as 𝐀
301            let mut t_ntt = A_hat.matrix_vector_ntt(&s1_hat);
302            t_ntt.inv_ntt();
303            t_ntt
304        };
305        t.add_vector_ntt(&self.s2);
306        t.conditional_add_q();
307
308        // 6: (𝐭1, 𝐭0) ← Power2Round(𝐭)
309        //   β–· compress 𝐭
310        //   β–· PowerTwoRound is applied componentwise (see explanatory text in Section 7.4)
311        let (t1, _) = power_2_round_vec::<k>(&t);
312
313        MLDSAPublicKey::<k, PK_LEN>::new(&self.rho, &t1)
314    }
315
316    fn sk_encode(&self) -> [u8; SK_LEN] {
317        let mut out = [0u8; SK_LEN];
318        let bytes_written = self.sk_encode_out(&mut out);
319        debug_assert_eq!(bytes_written, SK_LEN);
320        out
321    }
322
323    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
324        // bytes written counter
325        let mut off: usize = 0;
326
327        out[0..32].copy_from_slice(&self.rho);
328        out[32..64].copy_from_slice(&self.K);
329        out[64..128].copy_from_slice(&self.tr);
330        off += 128;
331
332        let mut buf = [0u8; 32 * 4]; // largest possible buffer
333        let eta_pack_len = bitlen_eta(eta);
334
335        let sk_chunks = out[off..off + l * bitlen_eta(eta)].chunks_mut(bitlen_eta(eta));
336        debug_assert_eq!(sk_chunks.len(), l);
337        for (sk_chunk, s1_i) in sk_chunks.into_iter().zip(&self.s1.vec) {
338            bit_pack_eta::<eta>(s1_i, &mut buf);
339            sk_chunk.copy_from_slice(&buf[..eta_pack_len]);
340        }
341        off += l * bitlen_eta(eta);
342
343        let sk_chunks = out[off..off + k * bitlen_eta(eta)].chunks_mut(bitlen_eta(eta));
344        debug_assert_eq!(sk_chunks.len(), k);
345        for (sk_chunk, s2_i) in sk_chunks.into_iter().zip(&self.s2.vec) {
346            bit_pack_eta::<eta>(s2_i, &mut buf);
347            sk_chunk.copy_from_slice(&buf[..eta_pack_len]);
348        }
349        off += k * bitlen_eta(eta);
350
351        let sk_chunks = out[off..off + k * POLY_T0PACKED_LEN].chunks_mut(POLY_T0PACKED_LEN);
352        debug_assert_eq!(sk_chunks.len(), k);
353        for (sk_chunk, t0_i) in sk_chunks.into_iter().zip(&self.t0.vec) {
354            sk_chunk.copy_from_slice(&bit_pack_t0(t0_i));
355        }
356
357        SK_LEN
358    }
359    fn sk_decode(sk: &[u8; SK_LEN]) -> Self {
360        let rho = sk[0..32].try_into().unwrap();
361        let K = sk[32..64].try_into().unwrap();
362        let tr = sk[64..128].try_into().unwrap();
363        let mut s1 = Vector::<l>::new();
364        let mut s2 = Vector::<k>::new();
365        let mut t0 = Vector::<k>::new();
366        let mut off = 128;
367
368        // unpack s1
369        // let mut i: usize = 0;
370        let sk_chunks = sk[128..128 + (l * bitlen_eta(eta))].chunks(bitlen_eta(eta));
371        debug_assert_eq!(sk_chunks.len(), l);
372        for (s1_i, sk_chunk) in s1.vec.iter_mut().zip(sk_chunks) {
373            s1_i.0.copy_from_slice(&bit_unpack_eta::<eta>(&sk_chunk).0);
374        }
375        off += l * bitlen_eta(eta);
376
377        // unpack s2
378        let sk_chunks = sk[off..off + (k * bitlen_eta(eta))].chunks(bitlen_eta(eta));
379        debug_assert_eq!(sk_chunks.len(), k);
380        for (s2_i, sk_chunk) in s2.vec.iter_mut().zip(sk_chunks) {
381            s2_i.0.copy_from_slice(&bit_unpack_eta::<eta>(&sk_chunk).0);
382        }
383        off += k * bitlen_eta(eta);
384
385        // unpack t0
386        let (sk_chunks, last_chunk) =
387            sk[off..off + (k * POLY_T0PACKED_LEN)].as_chunks::<POLY_T0PACKED_LEN>();
388
389        // that should divide evenly the remainder of the array
390        debug_assert_eq!(sk_chunks.len(), k);
391        debug_assert_eq!(last_chunk.len(), 0);
392
393        for (t0_i, sk_chunk) in t0.vec.iter_mut().zip(sk_chunks) {
394            t0_i.0.copy_from_slice(&bit_unpack_t0(sk_chunk).0);
395        }
396
397        Self::new(&rho, &K, &tr, &s1, &s2, &t0, None)
398    }
399}
400
401impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
402    MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN> for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {
403    fn new(
404        rho: &[u8; 32],
405        K: &[u8; 32],
406        tr: &[u8; 64],
407        s1: &Vector<l>,
408        s2: &Vector<k>,
409        t0: &Vector<k>,
410        seed: Option<KeyMaterialSized<32>>,
411    ) -> Self {
412        Self {
413            rho: rho.clone(),
414            K: K.clone(),
415            tr: tr.clone(),
416            s1: s1.clone(),
417            s2: s2.clone(),
418            t0: t0.clone(),
419            seed: seed.clone(),
420        }
421    }
422
423    fn rho(&self) -> &[u8; 32] { &self.rho }
424
425    fn K(&self) -> &[u8; 32] { &self.K }
426
427    // don't need here because there's one in the public trait
428    // fn tr(&self) -> &[u8; 64] { &self.tr }
429
430    fn s1(&self) -> &Vector<l> { &self.s1 }
431
432    fn s2(&self) -> &Vector<k> { &self.s2 }
433
434    fn t0(&self) -> &Vector<k> { &self.t0 }
435}
436
437impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
438    SignaturePrivateKey for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {
439    fn encode(&self) -> Vec<u8> {
440        self.sk_encode().to_vec()
441    }
442
443    fn encode_out(&self, out: &mut [u8]) -> Result<usize, SignatureError> {
444        if out.len() < SK_LEN {
445            Err(SignatureError::EncodingError("Output buffer too small"))
446        } else {
447            let out_sized: &mut [u8; SK_LEN] = out[..SK_LEN].as_mut().try_into().unwrap();
448            Ok(self.sk_encode_out(out_sized))
449        }
450    }
451
452    fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
453        if bytes.len() != SK_LEN { return Err(SignatureError::DecodingError("Provided key bytes are the incorrect length")) }
454        let sized_bytes: [u8; SK_LEN] = bytes[..SK_LEN].try_into().unwrap();
455        Ok(Self::sk_decode(&sized_bytes))
456    }
457}
458
459impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
460    Eq for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {}
461
462impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
463    PartialEq for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
464{
465    fn eq(&self, other: &Self) -> bool {
466        let self_encoded = self.sk_encode();
467        let other_encoded = other.sk_encode();
468        bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
469    }
470}
471
472impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
473Secret for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {}
474
475/// Debug impl mainly to prevent the secret key from being printed in logs.
476impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
477    fmt::Debug for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
478{
479    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
480        let alg = match k {
481            4 => ML_DSA_44_NAME,
482            6 => ML_DSA_65_NAME,
483            8 => ML_DSA_87_NAME,
484            _ => panic!("Unsupported key length"),
485        };
486        write!(
487            f,
488            "MLDSAPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
489            alg,
490            self.tr,
491            self.seed.is_some(),
492        )
493    }
494}
495
496/// Display impl mainly to prevent the secret key from being printed in logs.
497impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
498    Display for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
499{
500    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
501        let alg = match k {
502            4 => ML_DSA_44_NAME,
503            6 => ML_DSA_65_NAME,
504            8 => ML_DSA_87_NAME,
505            _ => panic!("Unsupported key length"),
506        };
507        write!(
508            f,
509            "MLDSAPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
510            alg,
511            self.tr,
512            self.seed.is_some(),
513        )
514    }
515}
516
517/// Zeroizing drop
518impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
519Drop for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
520{
521    fn drop(&mut self) {
522        self.K.fill(0u8);
523        // s1, s2, t0, seed have their own zeroizing drop
524    }
525}