1use crate::aux_functions::{
2 bit_pack_eta, bit_pack_t0, bit_unpack_eta, bit_unpack_t0, bitlen_eta, expandA, power_2_round_vec,
3 simple_bit_pack_t1, simple_bit_unpack_t1
4};
5use crate::matrix::Vector;
6use crate::mldsa::H;
7use crate::{ML_DSA_44_NAME, ML_DSA_65_NAME, ML_DSA_87_NAME};
8use crate::mldsa::{MLDSA44_ETA, MLDSA44_PK_LEN, MLDSA44_SK_LEN, MLDSA44_k, MLDSA44_l};
9use crate::mldsa::{MLDSA65_ETA, MLDSA65_PK_LEN, MLDSA65_SK_LEN, MLDSA65_k, MLDSA65_l};
10use crate::mldsa::{MLDSA87_ETA, MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_k, MLDSA87_l};
11use crate::mldsa::{POLY_T0PACKED_LEN, POLY_T1PACKED_LEN, SEED_LEN};
12use bouncycastle_core_interface::errors::SignatureError;
13use bouncycastle_core_interface::key_material::KeyMaterialSized;
14use bouncycastle_core_interface::traits::{Secret, SignaturePrivateKey, SignaturePublicKey, XOF};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17
18#[allow(unused_imports)]
20use crate::mldsa::MLDSATrait;
21
22
23
/// ML-DSA-44 public key: [`MLDSAPublicKey`] instantiated with the ML-DSA-44 `k` and encoded length.
pub type MLDSA44PublicKey = MLDSAPublicKey<MLDSA44_k, MLDSA44_PK_LEN>;
/// ML-DSA-44 private key: [`MLDSAPrivateKey`] instantiated with the ML-DSA-44 `k`, `l`, `eta`
/// and the encoded private/public key lengths.
pub type MLDSA44PrivateKey = MLDSAPrivateKey<MLDSA44_k, MLDSA44_l, MLDSA44_ETA, MLDSA44_SK_LEN, MLDSA44_PK_LEN>;
/// ML-DSA-65 public key: [`MLDSAPublicKey`] instantiated with the ML-DSA-65 `k` and encoded length.
pub type MLDSA65PublicKey = MLDSAPublicKey<MLDSA65_k, MLDSA65_PK_LEN>;
/// ML-DSA-65 private key: [`MLDSAPrivateKey`] instantiated with the ML-DSA-65 `k`, `l`, `eta`
/// and the encoded private/public key lengths.
pub type MLDSA65PrivateKey = MLDSAPrivateKey<MLDSA65_k, MLDSA65_l, MLDSA65_ETA, MLDSA65_SK_LEN, MLDSA65_PK_LEN>;
/// ML-DSA-87 public key: [`MLDSAPublicKey`] instantiated with the ML-DSA-87 `k` and encoded length.
pub type MLDSA87PublicKey = MLDSAPublicKey<MLDSA87_k, MLDSA87_PK_LEN>;
/// ML-DSA-87 private key: [`MLDSAPrivateKey`] instantiated with the ML-DSA-87 `k`, `l`, `eta`
/// and the encoded private/public key lengths.
pub type MLDSA87PrivateKey = MLDSAPrivateKey<MLDSA87_k, MLDSA87_l, MLDSA87_ETA, MLDSA87_SK_LEN, MLDSA87_PK_LEN>;
38
/// An ML-DSA public key, generic over the dimension `k` and the encoded length `PK_LEN` (bytes).
///
/// The parameter-set aliases (`MLDSA44PublicKey`, `MLDSA65PublicKey`, `MLDSA87PublicKey`) fix
/// both const parameters; prefer those over naming the generic type directly.
#[derive(Clone)]
pub struct MLDSAPublicKey<const k: usize, const PK_LEN: usize> {
    /// Seed bytes; they form the first `SEED_LEN` bytes of the encoded key.
    rho: [u8; SEED_LEN],
    /// Vector with `k` entries; each entry is packed with `simple_bit_pack_t1` when encoding.
    t1: Vector<k>,
}
45
/// Fixed-size encoding and decoding operations for an ML-DSA public key, parameterised by the
/// dimension `k` and the encoded length `PK_LEN`.
pub trait MLDSAPublicKeyTrait<const k: usize, const PK_LEN: usize> : SignaturePublicKey {
    /// Encodes the public key into an array of exactly `PK_LEN` bytes.
    fn pk_encode(&self) -> [u8; PK_LEN];

    /// Decodes a public key from exactly `PK_LEN` bytes.
    ///
    /// The signature has no error channel: the length is enforced by the array type.
    fn pk_decode(pk: &[u8; PK_LEN]) -> Self;

    /// Computes the 64-byte value `tr` for this public key.
    ///
    /// The implementation in this file derives it by hashing the encoded key with `H`.
    fn compute_tr(&self) -> [u8; 64];
}
68
/// Crate-internal constructor and component accessors for an ML-DSA public key.
pub(crate) trait MLDSAPublicKeyInternalTrait<const k: usize, const PK_LEN: usize> {
    /// Builds a public key from its two components, `rho` and `t1`.
    fn new(rho: &[u8; SEED_LEN], t1: &Vector<k>) -> Self;

    /// Borrows the 32-byte `rho` component.
    fn rho(&self) -> &[u8; 32];

    /// Borrows the `t1` component.
    fn t1(&self) -> &Vector<k>;
}
80
81impl<const k: usize, const PK_LEN: usize> MLDSAPublicKeyTrait<k, PK_LEN> for MLDSAPublicKey<k, PK_LEN> {
82 fn pk_encode(&self) -> [u8; PK_LEN] {
83 let mut pk = [0u8; PK_LEN];
84
85 pk[0..SEED_LEN].copy_from_slice(&self.rho);
86
87 let (pk_chunks, last_chunk) = pk[SEED_LEN..].as_chunks_mut::<POLY_T1PACKED_LEN>();
88
89 debug_assert_eq!(pk_chunks.len(), k);
91 debug_assert_eq!(last_chunk.len(), 0);
92
93 for (pk_chunk, t1_i) in pk_chunks.into_iter().zip(&self.t1.vec) {
100 pk_chunk.copy_from_slice(&simple_bit_pack_t1(&t1_i));
101 }
102
103 pk
104 }
105
106 fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
107 let rho = pk[0..32].try_into().unwrap();
108 let mut t1 = Vector::<k>::new();
109
110 let (pk_chunks, last_chunk) = pk[32..].as_chunks::<POLY_T1PACKED_LEN>();
111
112 debug_assert_eq!(pk_chunks.len(), k);
114 debug_assert_eq!(last_chunk.len(), 0);
115
116 for (t1_i, pk_chunk) in t1.vec.iter_mut().zip(pk_chunks) {
117 t1_i.0.copy_from_slice(&simple_bit_unpack_t1(pk_chunk).0);
118 }
119
120 Self::new(&rho, &t1)
121 }
122
123 fn compute_tr(&self) -> [u8; 64] {
124 let mut tr = [0u8; 64];
125 H::new().hash_xof_out(&self.pk_encode(), &mut tr);
126
127 tr
128 }
129}
130
131impl<const k: usize, const PK_LEN: usize> MLDSAPublicKeyInternalTrait<k, PK_LEN> for MLDSAPublicKey<k, PK_LEN> {
132 fn new(rho: &[u8; SEED_LEN], t1: &Vector<k>) -> Self {
133 Self { rho: rho.clone(), t1: t1.clone() }
134 }
135
136 fn rho(&self) -> &[u8; 32] { &self.rho }
137
138 fn t1(&self) -> &Vector<k> { &self.t1 }
139}
140
141impl<const k: usize, const PK_LEN: usize> SignaturePublicKey for MLDSAPublicKey<k, PK_LEN> {
142 fn encode(&self) -> Vec<u8> {
143 Vec::from(self.pk_encode())
144 }
145
146 fn encode_out(&self, out: &mut [u8]) -> Result<usize, SignatureError> {
147 if out.len() < PK_LEN {
148 Err(SignatureError::EncodingError("Output buffer too small"))
149 } else {
150 let tmp = self.pk_encode();
151 debug_assert_eq!(tmp.len(), PK_LEN);
152 out[..PK_LEN].copy_from_slice(&tmp);
153 Ok(PK_LEN)
154 }
155 }
156
157 fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
158 if bytes.len() != PK_LEN { return Err(SignatureError::DecodingError("Provided key bytes are the incorrect length")) }
159 let sized_bytes: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
160 Ok(Self::pk_decode(&sized_bytes))
161 }
162}
163
/// Marks public-key equality as a full equivalence relation; the comparison itself is defined
/// by the `PartialEq` impl for this type.
impl<const k: usize, const PK_LEN: usize> Eq for MLDSAPublicKey<k, PK_LEN> { }
165
166impl<const k: usize, const PK_LEN: usize> PartialEq for MLDSAPublicKey<k, PK_LEN> {
167 fn eq(&self, other: &Self) -> bool {
168 let self_encoded = self.pk_encode();
169 let other_encoded = other.pk_encode();
170 bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
171 }
172}
173
174impl<const k: usize, const PK_LEN: usize> fmt::Debug for MLDSAPublicKey<k, PK_LEN> {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
176 let alg = match k {
177 4 => ML_DSA_44_NAME,
178 6 => ML_DSA_65_NAME,
179 8 => ML_DSA_87_NAME,
180 _ => panic!("Unsupported key length"),
181 };
182 write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
183 }
184}
185
186impl<const k: usize, const PK_LEN: usize> Display for MLDSAPublicKey<k, PK_LEN> {
187 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
188 let alg = match k {
189 4 => ML_DSA_44_NAME,
190 6 => ML_DSA_65_NAME,
191 8 => ML_DSA_87_NAME,
192 _ => panic!("Unsupported key length"),
193 };
194 write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
195 }
196}
197
/// An ML-DSA private key, generic over the dimensions `k` and `l`, the bound `eta`, and the
/// encoded private/public key lengths `SK_LEN` / `PK_LEN` (bytes).
///
/// The parameter-set aliases (`MLDSA44PrivateKey`, `MLDSA65PrivateKey`, `MLDSA87PrivateKey`)
/// fix all const parameters; prefer those over naming the generic type directly.
#[derive(Clone)]
pub struct MLDSAPrivateKey<
    const k: usize,
    const l: usize,
    const eta: usize,
    const SK_LEN: usize,
    const PK_LEN: usize,
> {
    /// Seed bytes passed to `expandA` when deriving the public key; bytes 0..32 of the encoding.
    rho: [u8; 32],
    /// 32-byte secret value; bytes 32..64 of the encoding. Zeroed by the `Drop` impl.
    K: [u8; 32],
    /// 64-byte value printed as `pub_key_hash (tr)`; bytes 64..128 of the encoding.
    tr: [u8; 64],
    /// Vector with `l` entries, packed with `bit_pack_eta` when encoding.
    s1: Vector<l>,
    /// Vector with `k` entries, packed with `bit_pack_eta` when encoding.
    s2: Vector<k>,
    /// Vector with `k` entries, packed with `bit_pack_t0` when encoding.
    t0: Vector<k>,
    /// Optional 32-byte seed key material. It is not part of the encoded key, so keys produced
    /// by `sk_decode` carry `None`.
    seed: Option<KeyMaterialSized<32>>,
}
215
/// Public operations on an ML-DSA private key: access to the seed and `tr`, public-key
/// derivation, and fixed-size encoding/decoding.
pub trait MLDSAPrivateKeyTrait<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize> : SignaturePrivateKey {
    /// Borrows the optional 32-byte seed; `None` when the key does not carry one.
    fn seed(&self) -> &Option<KeyMaterialSized<32>>;

    /// Borrows the 64-byte `tr` value stored in the key.
    fn tr(&self) -> &[u8; 64];

    /// Derives the public key that corresponds to this private key.
    fn derive_pk(&self) -> MLDSAPublicKey<k, PK_LEN>;
    /// Encodes the private key into an array of exactly `SK_LEN` bytes.
    fn sk_encode(&self) -> [u8; SK_LEN];
    /// Encodes the private key into `out` and returns the number of bytes written.
    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize;
    /// Decodes a private key from exactly `SK_LEN` bytes.
    ///
    /// The signature has no error channel: the length is enforced by the array type.
    fn sk_decode(sk: &[u8; SK_LEN]) -> Self;
}
248
/// Crate-internal constructor and component accessors for an ML-DSA private key.
pub(crate) trait MLDSAPrivateKeyInternalTrait<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize> {
    /// Builds a private key from its individual components and an optional seed.
    fn new(
        rho: &[u8; 32],
        K: &[u8; 32],
        tr: &[u8; 64],
        s1: &Vector<l>,
        s2: &Vector<k>,
        t0: &Vector<k>,
        seed: Option<KeyMaterialSized<32>>,
    ) -> Self;
    /// Borrows the 32-byte `rho` component.
    fn rho(&self) -> &[u8; 32];

    /// Borrows the 32-byte `K` component.
    fn K(&self) -> &[u8; 32];

    /// Borrows the `s1` component.
    fn s1(&self) -> &Vector<l>;

    /// Borrows the `s2` component.
    fn s2(&self) -> &Vector<k>;

    /// Borrows the `t0` component.
    fn t0(&self) -> &Vector<k>;
}
280
281
282impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
283 MLDSAPrivateKeyTrait<k, l, eta, SK_LEN, PK_LEN> for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {
284 fn seed(&self) -> &Option<KeyMaterialSized<32>> { &self.seed }
285
286 fn tr(&self) -> &[u8; 64] {
287 &self.tr
288 }
289
290 fn derive_pk(&self) -> MLDSAPublicKey<k, PK_LEN> {
291
292 let mut s1_hat = self.s1.clone();
295 s1_hat.ntt();
296
297 let mut t = { let A_hat = expandA::<k, l>(&self.rho);
299
300 let mut t_ntt = A_hat.matrix_vector_ntt(&s1_hat);
302 t_ntt.inv_ntt();
303 t_ntt
304 };
305 t.add_vector_ntt(&self.s2);
306 t.conditional_add_q();
307
308 let (t1, _) = power_2_round_vec::<k>(&t);
312
313 MLDSAPublicKey::<k, PK_LEN>::new(&self.rho, &t1)
314 }
315
316 fn sk_encode(&self) -> [u8; SK_LEN] {
317 let mut out = [0u8; SK_LEN];
318 let bytes_written = self.sk_encode_out(&mut out);
319 debug_assert_eq!(bytes_written, SK_LEN);
320 out
321 }
322
323 fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
324 let mut off: usize = 0;
326
327 out[0..32].copy_from_slice(&self.rho);
328 out[32..64].copy_from_slice(&self.K);
329 out[64..128].copy_from_slice(&self.tr);
330 off += 128;
331
332 let mut buf = [0u8; 32 * 4]; let eta_pack_len = bitlen_eta(eta);
334
335 let sk_chunks = out[off..off + l * bitlen_eta(eta)].chunks_mut(bitlen_eta(eta));
336 debug_assert_eq!(sk_chunks.len(), l);
337 for (sk_chunk, s1_i) in sk_chunks.into_iter().zip(&self.s1.vec) {
338 bit_pack_eta::<eta>(s1_i, &mut buf);
339 sk_chunk.copy_from_slice(&buf[..eta_pack_len]);
340 }
341 off += l * bitlen_eta(eta);
342
343 let sk_chunks = out[off..off + k * bitlen_eta(eta)].chunks_mut(bitlen_eta(eta));
344 debug_assert_eq!(sk_chunks.len(), k);
345 for (sk_chunk, s2_i) in sk_chunks.into_iter().zip(&self.s2.vec) {
346 bit_pack_eta::<eta>(s2_i, &mut buf);
347 sk_chunk.copy_from_slice(&buf[..eta_pack_len]);
348 }
349 off += k * bitlen_eta(eta);
350
351 let sk_chunks = out[off..off + k * POLY_T0PACKED_LEN].chunks_mut(POLY_T0PACKED_LEN);
352 debug_assert_eq!(sk_chunks.len(), k);
353 for (sk_chunk, t0_i) in sk_chunks.into_iter().zip(&self.t0.vec) {
354 sk_chunk.copy_from_slice(&bit_pack_t0(t0_i));
355 }
356
357 SK_LEN
358 }
359 fn sk_decode(sk: &[u8; SK_LEN]) -> Self {
360 let rho = sk[0..32].try_into().unwrap();
361 let K = sk[32..64].try_into().unwrap();
362 let tr = sk[64..128].try_into().unwrap();
363 let mut s1 = Vector::<l>::new();
364 let mut s2 = Vector::<k>::new();
365 let mut t0 = Vector::<k>::new();
366 let mut off = 128;
367
368 let sk_chunks = sk[128..128 + (l * bitlen_eta(eta))].chunks(bitlen_eta(eta));
371 debug_assert_eq!(sk_chunks.len(), l);
372 for (s1_i, sk_chunk) in s1.vec.iter_mut().zip(sk_chunks) {
373 s1_i.0.copy_from_slice(&bit_unpack_eta::<eta>(&sk_chunk).0);
374 }
375 off += l * bitlen_eta(eta);
376
377 let sk_chunks = sk[off..off + (k * bitlen_eta(eta))].chunks(bitlen_eta(eta));
379 debug_assert_eq!(sk_chunks.len(), k);
380 for (s2_i, sk_chunk) in s2.vec.iter_mut().zip(sk_chunks) {
381 s2_i.0.copy_from_slice(&bit_unpack_eta::<eta>(&sk_chunk).0);
382 }
383 off += k * bitlen_eta(eta);
384
385 let (sk_chunks, last_chunk) =
387 sk[off..off + (k * POLY_T0PACKED_LEN)].as_chunks::<POLY_T0PACKED_LEN>();
388
389 debug_assert_eq!(sk_chunks.len(), k);
391 debug_assert_eq!(last_chunk.len(), 0);
392
393 for (t0_i, sk_chunk) in t0.vec.iter_mut().zip(sk_chunks) {
394 t0_i.0.copy_from_slice(&bit_unpack_t0(sk_chunk).0);
395 }
396
397 Self::new(&rho, &K, &tr, &s1, &s2, &t0, None)
398 }
399}
400
401impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
402 MLDSAPrivateKeyInternalTrait<k, l, eta, SK_LEN, PK_LEN> for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {
403 fn new(
404 rho: &[u8; 32],
405 K: &[u8; 32],
406 tr: &[u8; 64],
407 s1: &Vector<l>,
408 s2: &Vector<k>,
409 t0: &Vector<k>,
410 seed: Option<KeyMaterialSized<32>>,
411 ) -> Self {
412 Self {
413 rho: rho.clone(),
414 K: K.clone(),
415 tr: tr.clone(),
416 s1: s1.clone(),
417 s2: s2.clone(),
418 t0: t0.clone(),
419 seed: seed.clone(),
420 }
421 }
422
423 fn rho(&self) -> &[u8; 32] { &self.rho }
424
425 fn K(&self) -> &[u8; 32] { &self.K }
426
427 fn s1(&self) -> &Vector<l> { &self.s1 }
431
432 fn s2(&self) -> &Vector<k> { &self.s2 }
433
434 fn t0(&self) -> &Vector<k> { &self.t0 }
435}
436
437impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
438 SignaturePrivateKey for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {
439 fn encode(&self) -> Vec<u8> {
440 self.sk_encode().to_vec()
441 }
442
443 fn encode_out(&self, out: &mut [u8]) -> Result<usize, SignatureError> {
444 if out.len() < SK_LEN {
445 Err(SignatureError::EncodingError("Output buffer too small"))
446 } else {
447 let out_sized: &mut [u8; SK_LEN] = out[..SK_LEN].as_mut().try_into().unwrap();
448 Ok(self.sk_encode_out(out_sized))
449 }
450 }
451
452 fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
453 if bytes.len() != SK_LEN { return Err(SignatureError::DecodingError("Provided key bytes are the incorrect length")) }
454 let sized_bytes: [u8; SK_LEN] = bytes[..SK_LEN].try_into().unwrap();
455 Ok(Self::sk_decode(&sized_bytes))
456 }
457}
458
/// Marks private-key equality as a full equivalence relation; the comparison itself is defined
/// by the `PartialEq` impl for this type.
impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
    Eq for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {}
461
462impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
463 PartialEq for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
464{
465 fn eq(&self, other: &Self) -> bool {
466 let self_encoded = self.sk_encode();
467 let other_encoded = other.sk_encode();
468 bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
469 }
470}
471
/// Tags the private key with the `Secret` trait from `bouncycastle_core_interface`; this impl
/// overrides nothing and adds no methods.
impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
Secret for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN> {}
474
475impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
477 fmt::Debug for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
478{
479 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
480 let alg = match k {
481 4 => ML_DSA_44_NAME,
482 6 => ML_DSA_65_NAME,
483 8 => ML_DSA_87_NAME,
484 _ => panic!("Unsupported key length"),
485 };
486 write!(
487 f,
488 "MLDSAPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
489 alg,
490 self.tr,
491 self.seed.is_some(),
492 )
493 }
494}
495
496impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
498 Display for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
499{
500 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
501 let alg = match k {
502 4 => ML_DSA_44_NAME,
503 6 => ML_DSA_65_NAME,
504 8 => ML_DSA_87_NAME,
505 _ => panic!("Unsupported key length"),
506 };
507 write!(
508 f,
509 "MLDSAPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?}, has_seed: {} }}",
510 alg,
511 self.tr,
512 self.seed.is_some(),
513 )
514 }
515}
516
517impl<const k: usize, const l: usize, const eta: usize, const SK_LEN: usize, const PK_LEN: usize>
519Drop for MLDSAPrivateKey<k, l, eta, SK_LEN, PK_LEN>
520{
521 fn drop(&mut self) {
522 self.K.fill(0u8);
523 }
525}