1use crate::aux_functions::{
2 bit_pack_eta, bit_pack_t0, bit_unpack_eta, bit_unpack_t0, bitlen_eta, expandA,
3 power_2_round_vec, simple_bit_pack_t1, simple_bit_unpack_t1,
4};
5use crate::matrix::{Matrix, Vector};
6use crate::mldsa::H;
7use crate::mldsa::{MLDSA44_ETA, MLDSA44_PK_LEN, MLDSA44_SK_LEN, MLDSA44_k, MLDSA44_l};
8use crate::mldsa::{MLDSA65_ETA, MLDSA65_PK_LEN, MLDSA65_SK_LEN, MLDSA65_k, MLDSA65_l};
9use crate::mldsa::{MLDSA87_ETA, MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_k, MLDSA87_l};
10use crate::mldsa::{POLY_T0PACKED_LEN, POLY_T1PACKED_LEN};
11use crate::{ML_DSA_44_NAME, ML_DSA_65_NAME, ML_DSA_87_NAME};
12use bouncycastle_core::errors::SignatureError;
13use bouncycastle_core::key_material::KeyMaterial;
14use bouncycastle_core::traits::{Secret, SignaturePrivateKey, SignaturePublicKey, XOF};
15use core::fmt;
16use core::fmt::{Debug, Display, Formatter};
17
18#[allow(unused_imports)]
20use crate::mldsa::MLDSATrait;
21
/// Public key type instantiated with the ML-DSA-44 parameters.
pub type MLDSA44PublicKey = MLDSAPublicKey<MLDSA44_k, MLDSA44_l, MLDSA44_PK_LEN>;
/// Private key type instantiated with the ML-DSA-44 parameters.
pub type MLDSA44PrivateKey =
    MLDSAPrivateKey<MLDSA44_k, MLDSA44_l, MLDSA44_ETA, MLDSA44_SK_LEN, MLDSA44_PK_LEN>;
/// Public key type instantiated with the ML-DSA-65 parameters.
pub type MLDSA65PublicKey = MLDSAPublicKey<MLDSA65_k, MLDSA65_l, MLDSA65_PK_LEN>;
/// Private key type instantiated with the ML-DSA-65 parameters.
pub type MLDSA65PrivateKey =
    MLDSAPrivateKey<MLDSA65_k, MLDSA65_l, MLDSA65_ETA, MLDSA65_SK_LEN, MLDSA65_PK_LEN>;
/// Public key type instantiated with the ML-DSA-87 parameters.
pub type MLDSA87PublicKey = MLDSAPublicKey<MLDSA87_k, MLDSA87_l, MLDSA87_PK_LEN>;
/// Private key type instantiated with the ML-DSA-87 parameters.
pub type MLDSA87PrivateKey =
    MLDSAPrivateKey<MLDSA87_k, MLDSA87_l, MLDSA87_ETA, MLDSA87_SK_LEN, MLDSA87_PK_LEN>;
39
/// ML-DSA-44 public key bundled with its pre-computed matrix (`A_hat`).
pub type MLDSA44PublicKeyExpanded =
    MLDSAPublicKeyExpanded<MLDSA44_k, MLDSA44_l, MLDSA44PublicKey, MLDSA44_PK_LEN>;
/// ML-DSA-44 private key bundled with its pre-computed matrix (`A_hat`).
pub type MLDSA44PrivateKeyExpanded = MLDSAPrivateKeyExpanded<
    MLDSA44_k,
    MLDSA44_l,
    MLDSA44_ETA,
    MLDSA44PublicKey,
    MLDSA44PrivateKey,
    MLDSA44_SK_LEN,
    MLDSA44_PK_LEN,
>;
/// ML-DSA-65 public key bundled with its pre-computed matrix (`A_hat`).
pub type MLDSA65PublicKeyExpanded =
    MLDSAPublicKeyExpanded<MLDSA65_k, MLDSA65_l, MLDSA65PublicKey, MLDSA65_PK_LEN>;
/// ML-DSA-65 private key bundled with its pre-computed matrix (`A_hat`).
pub type MLDSA65PrivateKeyExpanded = MLDSAPrivateKeyExpanded<
    MLDSA65_k,
    MLDSA65_l,
    MLDSA65_ETA,
    MLDSA65PublicKey,
    MLDSA65PrivateKey,
    MLDSA65_SK_LEN,
    MLDSA65_PK_LEN,
>;
/// ML-DSA-87 public key bundled with its pre-computed matrix (`A_hat`).
pub type MLDSA87PublicKeyExpanded =
    MLDSAPublicKeyExpanded<MLDSA87_k, MLDSA87_l, MLDSA87PublicKey, MLDSA87_PK_LEN>;
/// ML-DSA-87 private key bundled with its pre-computed matrix (`A_hat`).
pub type MLDSA87PrivateKeyExpanded = MLDSAPrivateKeyExpanded<
    MLDSA87_k,
    MLDSA87_l,
    MLDSA87_ETA,
    MLDSA87PublicKey,
    MLDSA87PrivateKey,
    MLDSA87_SK_LEN,
    MLDSA87_PK_LEN,
>;
81
/// ML-DSA public key, generic over the dimensions `k` and `l` and the
/// length in bytes of its encoding, `PK_LEN`.
#[derive(Clone)]
pub struct MLDSAPublicKey<const k: usize, const l: usize, const PK_LEN: usize> {
    /// 32-byte seed; the matrix returned by `A_hat` is expanded from it.
    rho: [u8; 32],
    /// Vector of `k` polynomials, serialized with `simple_bit_pack_t1`.
    t1: Vector<k>,
}
88
89impl<const k: usize, const l: usize, const PK_LEN: usize> MLDSAPublicKey<k, l, PK_LEN> {
90 fn pk_encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
95 out[0..32].copy_from_slice(&self.rho);
96
97 let (pk_chunks, last_chunk) = out[32..].as_chunks_mut::<POLY_T1PACKED_LEN>();
98
99 debug_assert_eq!(pk_chunks.len(), k);
101 debug_assert_eq!(last_chunk.len(), 0);
102
103 for (pk_chunk, t1_i) in pk_chunks.into_iter().zip(&self.t1.vec) {
104 pk_chunk.copy_from_slice(&simple_bit_pack_t1(&t1_i));
105 }
106
107 PK_LEN
108 }
109}
110
/// Public operations offered by ML-DSA public keys, on top of the generic
/// [`SignaturePublicKey`] encoding interface.
pub trait MLDSAPublicKeyTrait<const k: usize, const l: usize, const PK_LEN: usize>:
    SignaturePublicKey<PK_LEN>
{
    /// Builds a key from its fixed-length `PK_LEN`-byte encoding.
    fn pk_decode(pk: &[u8; PK_LEN]) -> Self;

    /// Returns the `k`-by-`l` matrix associated with this key.
    fn A_hat(&self) -> Matrix<k, l>;

    /// Returns the 64-byte value `tr` computed from this key.
    fn compute_tr(&self) -> [u8; 64];
}
132
/// Crate-internal constructor and accessor for ML-DSA public keys.
pub(crate) trait MLDSAPublicKeyInternalTrait<const k: usize, const PK_LEN: usize>:
    SignaturePublicKey<PK_LEN>
{
    /// Assembles a key from its two components.
    fn new(rho: [u8; 32], t1: Vector<k>) -> Self;

    /// Borrows the `t1` vector of the key.
    fn t1(&self) -> &Vector<k>;
}
143
144impl<const k: usize, const l: usize, const PK_LEN: usize> MLDSAPublicKeyTrait<k, l, PK_LEN>
145 for MLDSAPublicKey<k, l, PK_LEN>
146{
147 fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
149 let rho = pk[0..32].try_into().unwrap();
150 let mut t1 = Vector::<k>::new();
151
152 let (pk_chunks, last_chunk) = pk[32..].as_chunks::<POLY_T1PACKED_LEN>();
153
154 debug_assert_eq!(pk_chunks.len(), k);
156 debug_assert_eq!(last_chunk.len(), 0);
157
158 for (t1_i, pk_chunk) in t1.vec.iter_mut().zip(pk_chunks) {
159 t1_i.coeffs.copy_from_slice(&simple_bit_unpack_t1(pk_chunk).coeffs);
163 }
164
165 Self::new(rho, t1)
166 }
167
168 fn A_hat(&self) -> Matrix<k, l> {
169 expandA::<k, l>(&self.rho)
170 }
171
172 fn compute_tr(&self) -> [u8; 64] {
173 let mut tr = [0u8; 64];
174 H::new().hash_xof_out(&self.encode(), &mut tr);
175
176 tr
177 }
178}
179
180impl<const k: usize, const l: usize, const PK_LEN: usize> MLDSAPublicKeyInternalTrait<k, PK_LEN>
181 for MLDSAPublicKey<k, l, PK_LEN>
182{
183 fn new(rho: [u8; 32], t1: Vector<k>) -> Self {
184 Self { rho, t1 }
185 }
186
187 fn t1(&self) -> &Vector<k> {
188 &self.t1
189 }
190}
191
192impl<const k: usize, const l: usize, const PK_LEN: usize> SignaturePublicKey<PK_LEN>
193 for MLDSAPublicKey<k, l, PK_LEN>
194{
195 fn encode(&self) -> [u8; PK_LEN] {
196 let mut pk = [0u8; PK_LEN];
197 let bytes_written = self.encode_out(&mut pk);
198 debug_assert_eq!(bytes_written, PK_LEN);
199
200 pk
201 }
202
203 fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
204 self.pk_encode_out(out)
205 }
206
207 fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
208 if bytes.len() != PK_LEN {
209 return Err(SignatureError::DecodingError(
210 "Provided key bytes are the incorrect length",
211 ));
212 }
213 let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
214 Ok(Self::pk_decode(&bytes_sized))
215 }
216}
217
218impl<const k: usize, const l: usize, const PK_LEN: usize> Eq for MLDSAPublicKey<k, l, PK_LEN> {}
219
220impl<const k: usize, const l: usize, const PK_LEN: usize> PartialEq
221 for MLDSAPublicKey<k, l, PK_LEN>
222{
223 fn eq(&self, other: &Self) -> bool {
224 let self_encoded = self.encode();
225 let other_encoded = other.encode();
226 bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
227 }
228}
229
230impl<const k: usize, const l: usize, const PK_LEN: usize> Debug for MLDSAPublicKey<k, l, PK_LEN> {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 let alg = match k {
233 4 => ML_DSA_44_NAME,
234 6 => ML_DSA_65_NAME,
235 8 => ML_DSA_87_NAME,
236 _ => panic!("Unsupported key length"),
237 };
238 write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
239 }
240}
241
242impl<const k: usize, const l: usize, const PK_LEN: usize> Display for MLDSAPublicKey<k, l, PK_LEN> {
243 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244 let alg = match k {
245 4 => ML_DSA_44_NAME,
246 6 => ML_DSA_65_NAME,
247 8 => ML_DSA_87_NAME,
248 _ => panic!("Unsupported key length"),
249 };
250 write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
251 }
252}
253
/// A public key stored together with its matrix, so the matrix does not have
/// to be expanded again each time it is needed.
#[derive(Clone)]
pub struct MLDSAPublicKeyExpanded<
    const k: usize,
    const l: usize,
    PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
    const PK_LEN: usize,
> {
    /// The wrapped (compact) public key.
    pub(crate) pk: PK,
    /// The matrix obtained from `pk.A_hat()` when this value was built.
    pub(crate) A_hat: Matrix<k, l>,
}
267
268impl<
269 const k: usize,
270 const l: usize,
271 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
272 const PK_LEN: usize,
273> SignaturePublicKey<PK_LEN> for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
274{
275 fn encode(&self) -> [u8; PK_LEN] {
276 self.pk.encode()
277 }
278
279 fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
280 self.pk.encode_out(out)
281 }
282
283 fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
284 if bytes.len() != PK_LEN {
285 return Err(SignatureError::DecodingError(
286 "Provided key bytes are the incorrect length",
287 ));
288 }
289 let bytes_sized: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
290 Ok(Self::pk_decode(&bytes_sized))
291 }
292}
293
294impl<
295 const k: usize,
296 const l: usize,
297 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
298 const PK_LEN: usize,
299> PartialEq for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
300{
301 fn eq(&self, other: &Self) -> bool {
302 self.pk.eq(&other.pk)
303 }
304}
305
306impl<
307 const k: usize,
308 const l: usize,
309 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
310 const PK_LEN: usize,
311> Eq for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
312{
313}
314
315impl<
316 const k: usize,
317 const l: usize,
318 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
319 const PK_LEN: usize,
320> Debug for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
321{
322 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
323 let alg = match k {
324 4 => ML_DSA_44_NAME,
325 6 => ML_DSA_65_NAME,
326 8 => ML_DSA_87_NAME,
327 _ => panic!("Unsupported key length"),
328 };
329 write!(
330 f,
331 "MLDSAPublicKeyExpanded {{ alg: {}, pub_key_hash (tr): {:x?} }}",
332 alg,
333 self.compute_tr(),
334 )
335 }
336}
337
338impl<
339 const k: usize,
340 const l: usize,
341 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
342 const PK_LEN: usize,
343> Display for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
344{
345 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
346 let alg = match k {
347 4 => ML_DSA_44_NAME,
348 6 => ML_DSA_65_NAME,
349 8 => ML_DSA_87_NAME,
350 _ => panic!("Unsupported key length"),
351 };
352 write!(
353 f,
354 "MLDSAPublicKeyExpanded {{ alg: {}, pub_key_hash (tr): {:x?} }}",
355 alg,
356 self.compute_tr(),
357 )
358 }
359}
360
361impl<
362 const k: usize,
363 const l: usize,
364 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
365 const PK_LEN: usize,
366> From<&PK> for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
367{
368 fn from(pk: &PK) -> Self {
371 let A_hat = pk.A_hat();
372
373 Self { pk: pk.clone(), A_hat }
374 }
375}
376
377impl<
378 const k: usize,
379 const l: usize,
380 PK: MLDSAPublicKeyTrait<k, l, PK_LEN> + MLDSAPublicKeyInternalTrait<k, PK_LEN>,
381 const PK_LEN: usize,
382> MLDSAPublicKeyTrait<k, l, PK_LEN> for MLDSAPublicKeyExpanded<k, l, PK, PK_LEN>
383{
384 fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
385 let pk1 = PK::pk_decode(pk);
386 let A_hat = pk1.A_hat();
387 Self { pk: pk1, A_hat }
388 }
389
390 fn A_hat(&self) -> Matrix<k, l> {
391 self.A_hat.clone()
392 }
393
394 fn compute_tr(&self) -> [u8; 64] {
395 self.pk.compute_tr()
396 }
397}
398
/// ML-DSA private key, generic over the dimensions `k` and `l`, the bound
/// `eta`, and the encoded lengths of the private (`SK_LEN`) and public
/// (`PK_LEN`) keys.
///
/// The three polynomial vectors are held in NTT form: decoding applies `ntt()`
/// before storing them and encoding applies `inv_ntt()` before packing.
#[derive(Clone)]
pub struct MLDSAPrivateKey<
    const k: usize,
    const l: usize,
    const eta: usize,
    const SK_LEN: usize,
    const PK_LEN: usize,
> {
    /// 32-byte seed from which the matrix is expanded.
    rho: [u8; 32],
    /// 32-byte secret value; overwritten with zeros on drop.
    K: [u8; 32],
    /// 64-byte value reported as "pub_key_hash (tr)" by the formatters.
    tr: [u8; 64],
    /// Vector of `l` polynomials (NTT form).
    s1_hat: Vector<l>,
    /// Vector of `k` polynomials (NTT form).
    s2_hat: Vector<k>,
    /// Vector of `k` polynomials (NTT form).
    t0_hat: Vector<k>,
    /// Optional 32-byte seed; `None` for keys produced by `sk_decode`.
    seed: Option<KeyMaterial<32>>,
}
422
/// Public operations offered by ML-DSA private keys, on top of the generic
/// [`SignaturePrivateKey`] encoding interface.
pub trait MLDSAPrivateKeyTrait<
    const k: usize,
    const l: usize,
    const eta: usize,
    const SK_LEN: usize,
    const PK_LEN: usize,
>: SignaturePrivateKey<SK_LEN>
{
    /// Returns the seed stored with the key, if any.
    fn seed(&self) -> &Option<KeyMaterial<32>>;

    /// Returns the 64-byte `tr` value stored in the key.
    fn tr(&self) -> &[u8; 64];

    /// Returns the `k`-by-`l` matrix associated with this key.
    fn A_hat(&self) -> Matrix<k, l>;

    /// Computes the public key that corresponds to this private key.
    fn derive_pk(&self) -> MLDSAPublicKey<k, l, PK_LEN>;
    /// Returns the `SK_LEN`-byte encoding of the key.
    fn sk_encode(&self) -> [u8; SK_LEN];
    /// Writes the encoding into `out` and returns the number of bytes written.
    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize;
    /// Decodes a key from its fixed-length encoding, or reports a
    /// `SignatureError` when the contents are rejected.
    fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, SignatureError>;
}
465
/// Crate-internal constructor and accessors for the secret components of an
/// ML-DSA private key.
pub(crate) trait MLDSAPrivateKeyInternalTrait<
    const k: usize,
    const l: usize,
    const eta: usize,
    const SK_LEN: usize,
    const PK_LEN: usize,
>
{
    /// Assembles a private key from its components.
    fn new(
        rho: [u8; 32],
        K: [u8; 32],
        tr: [u8; 64],
        s1_hat: Vector<l>,
        s2_hat: Vector<k>,
        t0_hat: Vector<k>,
        seed: Option<KeyMaterial<32>>,
    ) -> Self;
    /// Borrows the 32-byte `K` value.
    fn K(&self) -> &[u8; 32];
    /// Borrows the `s1_hat` vector.
    fn s1_hat(&self) -> &Vector<l>;
    /// Borrows the `s2_hat` vector.
    fn s2_hat(&self) -> &Vector<k>;
    /// Borrows the `t0_hat` vector.
    fn t0_hat(&self) -> &Vector<k>;
}
494
495impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
496 MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN> for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
497{
498 fn seed(&self) -> &Option<KeyMaterial<32>> {
499 &self.seed
500 }
501
502 fn tr(&self) -> &[u8; 64] {
503 &self.tr
504 }
505
506 fn A_hat(&self) -> Matrix<k, l> {
507 expandA::<k, l>(&self.rho)
508 }
509
510 fn derive_pk(&self) -> MLDSAPublicKey<k, l, PK_LEN> {
511 let mut t = {
514 let A_hat = expandA::<k, l>(&self.rho);
518
519 let mut t_ntt = A_hat.matrix_vector_ntt(&self.s1_hat);
520 t_ntt.inv_ntt();
521 t_ntt
522 };
523
524 {
525 let mut s2 = self.s2_hat.clone();
528 s2.reduce();
529 s2.inv_ntt();
530
531 t.add_vector_ntt(&s2);
532 t.conditional_add_q();
533 }
534 let (t1, _) = power_2_round_vec::<k>(&t);
538
539 MLDSAPublicKey::<k, l, PK_LEN>::new(self.rho.clone(), t1)
540 }
541 fn sk_encode(&self) -> [u8; SK_LEN] {
547 let mut out = [0u8; SK_LEN];
548 let bytes_written = self.sk_encode_out(&mut out);
549 debug_assert_eq!(bytes_written, SK_LEN);
550 out
551 }
552 fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
558 let mut off: usize = 0;
560
561 out[0..32].copy_from_slice(&self.rho);
562 out[32..64].copy_from_slice(&self.K);
563 out[64..128].copy_from_slice(&self.tr);
564 off += 128;
565
566 let mut buf = [0u8; 32 * 4]; let eta_pack_len = bitlen_eta(eta);
568
569 let sk_chunks = out[off..off + l * bitlen_eta(eta)].chunks_mut(bitlen_eta(eta));
570 debug_assert_eq!(sk_chunks.len(), l);
571 for (sk_chunk, s1_hat_i) in sk_chunks.into_iter().zip(&self.s1_hat.vec) {
572 let mut s1_hat_i = s1_hat_i.clone();
575 s1_hat_i.reduce();
576 s1_hat_i.inv_ntt();
577 let s1_i = s1_hat_i;
578
579 bit_pack_eta::<eta>(&s1_i, &mut buf);
580 sk_chunk.copy_from_slice(&buf[..eta_pack_len]);
581 }
582 off += l * bitlen_eta(eta);
583
584 let sk_chunks = out[off..off + k * bitlen_eta(eta)].chunks_mut(bitlen_eta(eta));
585 debug_assert_eq!(sk_chunks.len(), k);
586 for (sk_chunk, s2_hat_i) in sk_chunks.into_iter().zip(&self.s2_hat.vec) {
587 let mut s2_hat_i = s2_hat_i.clone();
590 s2_hat_i.reduce();
591 s2_hat_i.inv_ntt();
592 let s2_i = s2_hat_i;
593
594 bit_pack_eta::<eta>(&s2_i, &mut buf);
595 sk_chunk.copy_from_slice(&buf[..eta_pack_len]);
596 }
597 off += k * bitlen_eta(eta);
598
599 let sk_chunks = out[off..off + k * POLY_T0PACKED_LEN].chunks_mut(POLY_T0PACKED_LEN);
600 debug_assert_eq!(sk_chunks.len(), k);
601 for (sk_chunk, t0_hat_i) in sk_chunks.into_iter().zip(&self.t0_hat.vec) {
602 let mut t0_hat_i = t0_hat_i.clone();
605 t0_hat_i.reduce();
606 t0_hat_i.inv_ntt();
607 let t0_i = t0_hat_i;
608
609 sk_chunk.copy_from_slice(&bit_pack_t0(&t0_i));
610 }
611
612 SK_LEN
613 }
614 fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, SignatureError> {
615 let rho = sk[0..32].try_into().unwrap();
616 let K = sk[32..64].try_into().unwrap();
617 let tr = sk[64..128].try_into().unwrap();
618 let mut s1 = Vector::<l>::new();
619 let mut s2 = Vector::<k>::new();
620 let mut t0 = Vector::<k>::new();
621 let mut off = 128;
622
623 let sk_chunks = sk[128..128 + (l * bitlen_eta(eta))].chunks(bitlen_eta(eta));
626 debug_assert_eq!(sk_chunks.len(), l);
627 for (s1_i, sk_chunk) in s1.vec.iter_mut().zip(sk_chunks) {
628 s1_i.coeffs.copy_from_slice(&bit_unpack_eta::<eta>(&sk_chunk).coeffs);
631
632 for coeff in s1_i.coeffs.iter() {
634 if *coeff < -(eta as i32) || *coeff > (eta as i32) {
635 return Err(SignatureError::DecodingError("Invalid or corrupted key"));
636 }
637 }
638 }
639 s1.ntt();
642 off += l * bitlen_eta(eta);
643
644 let sk_chunks = sk[off..off + (k * bitlen_eta(eta))].chunks(bitlen_eta(eta));
646 debug_assert_eq!(sk_chunks.len(), k);
647 for (s2_i, sk_chunk) in s2.vec.iter_mut().zip(sk_chunks) {
648 s2_i.coeffs.copy_from_slice(&bit_unpack_eta::<eta>(&sk_chunk).coeffs);
651
652 for coeff in s2_i.coeffs.iter() {
654 if *coeff < -(eta as i32) || *coeff > (eta as i32) {
655 return Err(SignatureError::DecodingError("Invalid or corrupted key"));
656 }
657 }
658 }
659 s2.ntt();
662 off += k * bitlen_eta(eta);
663
664 let (sk_chunks, last_chunk) =
666 sk[off..off + (k * POLY_T0PACKED_LEN)].as_chunks::<POLY_T0PACKED_LEN>();
667
668 debug_assert_eq!(sk_chunks.len(), k);
670 debug_assert_eq!(last_chunk.len(), 0);
671
672 for (t0_i, sk_chunk) in t0.vec.iter_mut().zip(sk_chunks) {
673 t0_i.coeffs.copy_from_slice(&bit_unpack_t0(sk_chunk).coeffs);
674 }
675 t0.ntt();
678
679 Ok(Self { rho, K, tr, s1_hat: s1, s2_hat: s2, t0_hat: t0, seed: None })
680 }
681}
682
683impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
684 MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>
685 for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
686{
687 fn new(
688 rho: [u8; 32],
689 K: [u8; 32],
690 tr: [u8; 64],
691 s1_hat: Vector<l>,
692 s2_hat: Vector<k>,
693 t0_hat: Vector<k>,
694 seed: Option<KeyMaterial<32>>,
695 ) -> Self {
696 Self {
697 rho: rho.clone(),
698 K: K.clone(),
699 tr: tr.clone(),
700 s1_hat: s1_hat.clone(),
701 s2_hat: s2_hat.clone(),
702 t0_hat: t0_hat.clone(),
703 seed: seed.clone(),
704 }
705 }
706
707 fn K(&self) -> &[u8; 32] {
708 &self.K
709 }
710
711 fn s1_hat(&self) -> &Vector<l> {
712 &self.s1_hat
713 }
714
715 fn s2_hat(&self) -> &Vector<k> {
716 &self.s2_hat
717 }
718
719 fn t0_hat(&self) -> &Vector<k> {
720 &self.t0_hat
721 }
722}
723
724impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
725 SignaturePrivateKey<SK_LEN> for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
726{
727 fn encode(&self) -> [u8; SK_LEN] {
728 self.sk_encode()
729 }
730
731 fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
732 self.sk_encode_out(out)
733 }
734
735 fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
736 if bytes.len() != SK_LEN {
737 return Err(SignatureError::DecodingError(
738 "Provided key bytes are the incorrect length",
739 ));
740 }
741 let bytes_sized: [u8; SK_LEN] = bytes[..SK_LEN].try_into().unwrap();
742
743 Ok(Self::sk_decode(&bytes_sized)?)
744 }
745}
746
747impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize> Eq
748 for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
749{
750}
751
752impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
753 PartialEq for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
754{
755 fn eq(&self, other: &Self) -> bool {
756 let self_encoded = self.sk_encode();
757 let other_encoded = other.sk_encode();
758 bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
759 }
760}
761
/// Marks the private key as a `Secret`; the impl has no items of its own.
impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
    Secret for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
{
}
766
767impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
769 fmt::Debug for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
770{
771 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
772 let alg = match k {
773 4 => ML_DSA_44_NAME,
774 6 => ML_DSA_65_NAME,
775 8 => ML_DSA_87_NAME,
776 _ => panic!("Unsupported key length"),
777 };
778 write!(
779 f,
780 "MLDSAPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
781 alg,
782 self.tr,
783 self.seed.is_some(),
784 )
785 }
786}
787
788impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
790 Display for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
791{
792 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
793 let alg = match k {
794 4 => ML_DSA_44_NAME,
795 6 => ML_DSA_65_NAME,
796 8 => ML_DSA_87_NAME,
797 _ => panic!("Unsupported key length"),
798 };
799 write!(
800 f,
801 "MLDSAPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
802 alg,
803 self.tr,
804 self.seed.is_some(),
805 )
806 }
807}
808
/// Overwrites the `K` field with zeros when the key is dropped.
///
/// NOTE(review): `s1_hat`, `s2_hat`, `t0_hat` and `seed` are not cleared
/// here; presumably their own types take care of that on drop — confirm.
impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
    Drop for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
{
    fn drop(&mut self) {
        self.K.fill(0u8);
    }
}
818
/// A private key stored together with its matrix, so the matrix does not have
/// to be expanded again each time it is needed.
#[derive(Clone)]
pub struct MLDSAPrivateKeyExpanded<
    const k: usize,
    const l: usize,
    const eta: usize,
    PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
    SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
        + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
    const SK_LEN: usize,
    const PK_LEN: usize,
> {
    /// Ties the otherwise unused `PK` type parameter to the struct.
    _phantom: core::marker::PhantomData<PK>,
    /// The wrapped (compact) private key.
    pub(crate) sk: SK,
    /// The matrix computed for `sk` when this value was built.
    pub(crate) A_hat: Matrix<k, l>,
}
837
838impl<
839 const k: usize,
840 const l: usize,
841 const eta: usize,
842 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
843 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
844 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
845 const SK_LEN: usize,
846 const PK_LEN: usize,
847> PartialEq for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
848{
849 fn eq(&self, other: &Self) -> bool {
850 self.sk.eq(&other.sk)
851 }
852}
853
854impl<
855 const k: usize,
856 const l: usize,
857 const eta: usize,
858 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
859 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
860 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
861 const SK_LEN: usize,
862 const PK_LEN: usize,
863> Eq for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
864{
865}
866
/// Marks the expanded private key as a `Secret`; the impl has no items of its
/// own.
impl<
    const k: usize,
    const l: usize,
    const eta: usize,
    PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
    SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
        + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
    const SK_LEN: usize,
    const PK_LEN: usize,
> Secret for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
{
}
879
/// `Drop` implementation with an empty body: nothing is cleared here.
///
/// NOTE(review): presumably the wrapped `sk` is expected to clean up after
/// itself when it is dropped — confirm.
impl<
    const k: usize,
    const l: usize,
    const eta: usize,
    PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
    SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
        + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
    const SK_LEN: usize,
    const PK_LEN: usize,
> Drop for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
{
    fn drop(&mut self) {
        // Intentionally empty.
    }
}
895
896impl<
897 const k: usize,
898 const l: usize,
899 const eta: usize,
900 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
901 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
902 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
903 const SK_LEN: usize,
904 const PK_LEN: usize,
905> Debug for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
906{
907 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
908 let alg = match k {
909 4 => ML_DSA_44_NAME,
910 6 => ML_DSA_65_NAME,
911 8 => ML_DSA_87_NAME,
912 _ => panic!("Unsupported key length"),
913 };
914 write!(
915 f,
916 "MLDSAPrivateKeyExpanded {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
917 alg,
918 self.sk.tr(),
919 self.sk.seed().is_some(),
920 )
921 }
922}
923
924impl<
925 const k: usize,
926 const l: usize,
927 const eta: usize,
928 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
929 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
930 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
931 const SK_LEN: usize,
932 const PK_LEN: usize,
933> Display for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
934{
935 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
936 let alg = match k {
937 4 => ML_DSA_44_NAME,
938 6 => ML_DSA_65_NAME,
939 8 => ML_DSA_87_NAME,
940 _ => panic!("Unsupported key length"),
941 };
942 write!(
943 f,
944 "MLDSAPrivateKeyExpanded {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
945 alg,
946 self.sk.tr(),
947 self.sk.seed().is_some(),
948 )
949 }
950}
951
952impl<
953 const k: usize,
954 const l: usize,
955 const eta: usize,
956 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
957 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
958 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
959 const SK_LEN: usize,
960 const PK_LEN: usize,
961> From<&SK> for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
962{
963 fn from(sk: &SK) -> Self {
966 let A_hat = sk.derive_pk().A_hat();
967
968 Self { _phantom: core::marker::PhantomData, sk: sk.clone(), A_hat }
969 }
970}
971
972impl<
973 const k: usize,
974 const l: usize,
975 const eta: usize,
976 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
977 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
978 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
979 const SK_LEN: usize,
980 const PK_LEN: usize,
981> SignaturePrivateKey<SK_LEN> for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
982{
983 fn encode(&self) -> [u8; SK_LEN] {
984 self.sk.encode()
985 }
986
987 fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
988 self.sk.encode_out(out)
989 }
990
991 fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
992 let sk = SK::from_bytes(bytes)?;
993 Ok(Self::from(&sk))
994 }
995}
996
997impl<
998 const k: usize,
999 const l: usize,
1000 const eta: usize,
1001 PK: MLDSAPublicKeyInternalTrait<k, PK_LEN>,
1002 SK: MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
1003 + MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN>,
1004 const SK_LEN: usize,
1005 const PK_LEN: usize,
1006> MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN>
1007 for MLDSAPrivateKeyExpanded<k, l, eta, PK, SK, SK_LEN, PK_LEN>
1008{
1009 fn seed(&self) -> &Option<KeyMaterial<32>> {
1010 self.sk.seed()
1011 }
1012
1013 fn tr(&self) -> &[u8; 64] {
1014 self.sk.tr()
1015 }
1016
1017 fn A_hat(&self) -> Matrix<k, l> {
1018 self.sk.A_hat()
1019 }
1020
1021 fn derive_pk(&self) -> MLDSAPublicKey<k, l, PK_LEN> {
1022 self.sk.derive_pk()
1023 }
1024
1025 fn sk_encode(&self) -> [u8; SK_LEN] {
1026 self.sk.sk_encode()
1027 }
1028
1029 fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
1030 self.sk.sk_encode_out(out)
1031 }
1032
1033 fn sk_decode(sk: &[u8; SK_LEN]) -> Result<Self, SignatureError> {
1034 let sk1 = SK::sk_decode(sk)?;
1035 let A_hat = sk1.derive_pk().A_hat();
1036
1037 Ok(Self { _phantom: core::marker::PhantomData, sk: sk1, A_hat })
1038 }
1039}