Skip to main content

bouncycastle_mldsa_lowmemory/
mldsa_keys.rs

1use crate::aux_functions::{bit_pack_eta, bitlen_eta, power_2_round, rej_bounded_poly, rej_ntt_poly, simple_bit_pack_t1, simple_bit_unpack_t1};
2use crate::mldsa::{H, N};
3use crate::{ML_DSA_44_NAME, ML_DSA_65_NAME, ML_DSA_87_NAME};
4use crate::mldsa::{MLDSA44_LAMBDA, MLDSA44_GAMMA2, MLDSA44_ETA, MLDSA44_PK_LEN, MLDSA44_SK_LEN, MLDSA44_k, MLDSA44_l, MLDSA44_S1_PACKED_LEN, MLDSA44_S2_PACKED_LEN};
5use crate::mldsa::{MLDSA65_LAMBDA, MLDSA65_GAMMA2, MLDSA65_ETA, MLDSA65_PK_LEN, MLDSA65_SK_LEN, MLDSA65_k, MLDSA65_l, MLDSA65_S1_PACKED_LEN, MLDSA65_S2_PACKED_LEN};
6use crate::mldsa::{MLDSA87_LAMBDA, MLDSA87_GAMMA2, MLDSA87_ETA, MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_k, MLDSA87_l, MLDSA87_S1_PACKED_LEN, MLDSA87_S2_PACKED_LEN};
7use crate::mldsa::{POLY_T1PACKED_LEN, MLDSA44_T1_PACKED_LEN, MLDSA65_T1_PACKED_LEN, MLDSA87_T1_PACKED_LEN};
8use bouncycastle_core_interface::errors::SignatureError;
9use bouncycastle_core_interface::key_material::{KeyMaterialSized, KeyType};
10use bouncycastle_core_interface::traits::{KeyMaterial, Secret, SecurityStrength, SignaturePrivateKey, SignaturePublicKey, XOF};
11use core::fmt;
12use core::fmt::{Debug, Display, Formatter};
13use crate::low_memory_helpers::s_unpack;
14// imports just for docs
15#[allow(unused_imports)]
16use crate::mldsa::MLDSATrait;
17use crate::polynomial::Polynomial;
18
19
20/* Pub Types */
21
22/// ML-DSA-44 Public Key
23pub type MLDSA44PublicKey = MLDSAPublicKey<MLDSA44_k, MLDSA44_T1_PACKED_LEN, MLDSA44_PK_LEN>;
24/// ML-DSA-44 Private Key
25pub type MLDSA44PrivateKey = MLDSASeedPrivateKey<MLDSA44_LAMBDA, MLDSA44_GAMMA2, MLDSA44_k, MLDSA44_l, MLDSA44_ETA, MLDSA44_S1_PACKED_LEN, MLDSA44_S2_PACKED_LEN, MLDSA44_T1_PACKED_LEN, MLDSA44_PK_LEN, MLDSA44_SK_LEN>;
26/// ML-DSA-65 Public Key
27pub type MLDSA65PublicKey = MLDSAPublicKey<MLDSA65_k, MLDSA65_T1_PACKED_LEN, MLDSA65_PK_LEN>;
28/// ML-DSA-65 Private Key
29pub type MLDSA65PrivateKey = MLDSASeedPrivateKey<MLDSA65_LAMBDA, MLDSA65_GAMMA2, MLDSA65_k, MLDSA65_l, MLDSA65_ETA, MLDSA65_S1_PACKED_LEN, MLDSA65_S2_PACKED_LEN, MLDSA65_T1_PACKED_LEN, MLDSA65_PK_LEN, MLDSA65_SK_LEN>;
30/// ML-DSA-87 Public Key
31pub type MLDSA87PublicKey = MLDSAPublicKey<MLDSA87_k, MLDSA87_T1_PACKED_LEN, MLDSA87_PK_LEN>;
32/// ML-DSA-87 Private Key
33pub type MLDSA87PrivateKey = MLDSASeedPrivateKey<MLDSA87_LAMBDA, MLDSA87_GAMMA2, MLDSA87_k, MLDSA87_l, MLDSA87_ETA, MLDSA87_S1_PACKED_LEN, MLDSA87_S2_PACKED_LEN, MLDSA87_T1_PACKED_LEN, MLDSA87_PK_LEN, MLDSA87_SK_LEN>;
34
/// An ML-DSA public key, stored in its packed (encoded) form.
///
/// Only the public seed `rho` and the bit-packed vector `t1` are kept; accessors such as
/// `unpack_t1_row` unpack `t1` one row at a time when it is needed.
#[derive(Clone)]
pub struct MLDSAPublicKey<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> {
    /// The 32-byte public seed ρ.
    pub(crate) rho: [u8; 32],
    /// The vector t1, bit-packed row after row.
    pub(crate) t1_packed: [u8; T1_PACKED_LEN],
}
41
42/// General trait for all ML-DSA public keys types.
43pub trait MLDSAPublicKeyTrait<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> : SignaturePublicKey {
44    /// Algorithm 22 pkEncode(𝜌, 𝐭1)
45    /// Encodes a public key for ML-DSA into a byte string.
46    /// Input:𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
47    /// Output: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
48    fn pk_encode(&self) -> [u8; PK_LEN];
49
50    /// Algorithm 23 pkDecode(π‘π‘˜)
51    /// Reverses the procedure pkEncode.
52    /// Input: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
53    /// Output: 𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
54    fn pk_decode(pk: &[u8; PK_LEN]) -> Self;
55
56    /// Compute the public key hash (tr) from the public key.
57    ///
58    /// This is exposed as a public API for a few reasons:
59    /// 1. `tr` is required for some external-prehashing schemes such as the so-called "external mu" signing mode.
60    /// 2. `tr` is the canonical fingerprint of an ML-DSA public key, so would be an appropriate value
61    ///     to use, for example, to build a public key lookup or deny-listing table.
62    fn compute_tr(&self) -> [u8; 64];
63}
64
65pub(crate) trait MLDSAPublicKeyInternalTrait<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> {
66    /// Not exposing a constructor publicly because you should have to get an instance either by
67    /// running a keygen, or by decoding an existing key.
68    fn new(rho: &[u8; 32], t1_packed: &[u8; T1_PACKED_LEN]) -> Self;
69
70    /// Get a ref to rho
71    fn rho(&self) -> &[u8; 32];
72
73    /// Get a ref to t1
74    fn unpack_t1_row(&self, row: usize) -> Polynomial;
75}
76
77impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> MLDSAPublicKeyTrait<k, T1_PACKED_LEN, PK_LEN> for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
78    // todo -- I think this becomes trivial
79    // fn pk_encode(&self) -> [u8; PK_LEN] {
80    //     let mut pk = [0u8; PK_LEN];
81    //
82    //     pk[0..32].copy_from_slice(&self.rho);
83    //
84    //     let (pk_chunks, last_chunk) = pk[32..].as_chunks_mut::<POLY_T1PACKED_LEN>();
85    //
86    //     // that should divide evenly the remainder of the array
87    //     debug_assert_eq!(pk_chunks.len(), k);
88    //     debug_assert_eq!(last_chunk.len(), 0);
89    //
90    //     for (pk_chunk, t1_i) in pk_chunks.into_iter().zip(&self.t1.vec) {
91    //         pk_chunk.copy_from_slice(&simple_bit_pack_t1(&t1_i));
92    //     }
93    //
94    //     pk
95    // }
96    fn pk_encode(&self) -> [u8; PK_LEN] {
97        let mut pk = [0u8; PK_LEN];
98        pk[..32].copy_from_slice(&self.rho);
99        pk[32..].copy_from_slice(&self.t1_packed);
100        pk
101    }
102
103    // fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
104    //     let rho = pk[0..32].try_into().unwrap();
105    //     let mut t1 = Vector::<k>::new();
106    //
107    //     let (pk_chunks, last_chunk) = pk[32..].as_chunks::<POLY_T1PACKED_LEN>();
108    //
109    //     // that should divide evenly the remainder of the array
110    //     debug_assert_eq!(pk_chunks.len(), k);
111    //     debug_assert_eq!(last_chunk.len(), 0);
112    //
113    //     for (t1_i, pk_chunk) in t1.vec.iter_mut().zip(pk_chunks) {
114    //         t1_i.0.copy_from_slice(&simple_bit_unpack_t1(pk_chunk).0);
115    //     }
116    //
117    //     Self::new(&rho, &t1)
118    // }
119    fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
120        Self {
121            rho: pk[..32].try_into().unwrap(),
122            t1_packed: pk[32..].try_into().unwrap()
123        }
124    }
125
126    fn compute_tr(&self) -> [u8; 64] {
127        let mut tr = [0u8; 64];
128        H::new().hash_xof_out(&self.pk_encode(), &mut tr);
129
130        tr
131    }
132}
133
134impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> MLDSAPublicKeyInternalTrait<k, T1_PACKED_LEN, PK_LEN> for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
135    fn new(rho: &[u8; 32], t1_packed: &[u8; T1_PACKED_LEN]) -> Self {
136        Self { rho: rho.clone(), t1_packed: t1_packed.clone() }
137    }
138
139    fn rho(&self) -> &[u8; 32] { &self.rho }
140
141    fn unpack_t1_row(&self, row: usize) -> Polynomial {
142        simple_bit_unpack_t1(&self.t1_packed[row * POLY_T1PACKED_LEN .. (row + 1) * POLY_T1PACKED_LEN].try_into().unwrap())
143    }
144}
145
146impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize>  SignaturePublicKey for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
147    fn encode(&self) -> Vec<u8> {
148        Vec::from(self.pk_encode())
149    }
150
151    fn encode_out(&self, out: &mut [u8]) -> Result<usize, SignatureError> {
152        if out.len() < PK_LEN {
153            Err(SignatureError::EncodingError("Output buffer too small"))
154        } else {
155            let tmp = self.pk_encode();
156            debug_assert_eq!(tmp.len(), PK_LEN);
157            out[..PK_LEN].copy_from_slice(&tmp);
158            Ok(PK_LEN)
159        }
160    }
161
162    fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
163        if bytes.len() != PK_LEN { return Err(SignatureError::DecodingError("Provided key bytes are the incorrect length")) }
164        let sized_bytes: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
165        Ok(Self::pk_decode(&sized_bytes))
166    }
167}
168
169impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> Eq for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> { }
170
171impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> PartialEq for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
172    fn eq(&self, other: &Self) -> bool {
173        let self_encoded = self.pk_encode();
174        let other_encoded = other.pk_encode();
175        bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
176    }
177}
178
179impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> fmt::Debug for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        let alg = match k {
182            4 => ML_DSA_44_NAME,
183            6 => ML_DSA_65_NAME,
184            8 => ML_DSA_87_NAME,
185            _ => panic!("Unsupported key length"),
186        };
187        write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
188    }
189}
190
191impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> Display for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
192    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
193        let alg = match k {
194            4 => ML_DSA_44_NAME,
195            6 => ML_DSA_65_NAME,
196            8 => ML_DSA_87_NAME,
197            _ => panic!("Unsupported key length"),
198        };
199        write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
200    }
201}
202
203
204
205/// General trait for all ML-DSA private keys types.
206pub trait MLDSAPrivateKeyTrait<
207    const k: usize,
208    const l: usize,
209    const S1_PACKED_LEN: usize,
210    const S2_PACKED_LEN: usize,
211    const T1_PACKED_LEN: usize,
212    const PK_LEN: usize,
213    const SK_LEN: usize
214> : SignaturePrivateKey {
215    /// New from KeyMaterial. Can throw a SignatureError if the KeyMaterial does not contain sufficient entropy.
216    fn from_keymaterial(seed: &KeyMaterialSized<32>) -> Result<Self, SignatureError>;
217
218    /// Get a ref to the seed, if there is one stored with this private key
219    fn seed(&self) -> &KeyMaterialSized<32>;
220
221    /// Get a copy of the key hash `tr`.
222    /// This is computationally intensive as it requires fully re-computing the public key (and then discarding it).
223    /// It is highly recommended that if you already have a copy of the public key, get `tr` from that,
224    /// or else compute tr once and store it.
225    fn tr(&self) -> [u8; 64];
226
227    /// Returns the full public key, and has the side-effect of setting the public key hash tr in this MLDSASeedSK object.
228    fn derive_pk(&self) -> MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN>;
229    /// Algorithm 24 skEncode(𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0)
230    /// Encodes a secret key for ML-DSA into a byte string.
231    /// Input: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 , 𝐬1 ∈ 𝑅ℓ with coefficients in [βˆ’πœ‚, πœ‚], 𝐬2 ∈ π‘…π‘˜ with
232    /// coefficients in [βˆ’πœ‚, πœ‚], 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
233    /// Output: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((π‘˜+β„“)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
234    fn sk_encode(&self) -> [u8; SK_LEN];
235    /// Algorithm 24 skEncode(𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0)
236    /// Encodes a secret key for ML-DSA into a byte string.
237    /// Input: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 , 𝐬1 ∈ 𝑅ℓ with coefficients in [βˆ’πœ‚, πœ‚], 𝐬2 ∈ π‘…π‘˜ with
238    /// coefficients in [βˆ’πœ‚, πœ‚], 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
239    /// Output: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((π‘˜+β„“)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
240    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize;
241    /// Algorithm 25 skDecode(π‘ π‘˜)
242    /// Reverses the procedure skEncode.
243    /// Input: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((β„“+π‘˜)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
244    /// Output: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 ,
245    /// 𝐬1 ∈ 𝑅ℓ , 𝐬2 ∈ π‘…π‘˜ , 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
246    ///
247    /// Note: this object contains only the simple decoding routine to unpack a semi-expanded key.
248    /// See [MLDSATrait] for key generation functions, including derive-from-seed and consistency-check functions.
249    fn sk_decode(sk: &[u8; SK_LEN]) -> Self;
250}
251
252/// Internal structure for holding a seed-based private key for ML-DSA.
253#[derive(Clone, PartialEq, Eq)]
254pub struct MLDSASeedPrivateKey<
255    const LAMBDA: i32,
256    const GAMMA2: i32,
257    const k: usize,
258    const l: usize,
259    const eta: usize,
260    const S1_PACKED_LEN: usize,
261    const S2_PACKED_LEN: usize,
262    const T1_PACKED_LEN: usize,
263    const PK_LEN: usize,
264    const SK_LEN: usize,
265> {
266    seed: KeyMaterialSized<32>,
267    rho: [u8; 32],
268    rho_prime: [u8; 64],
269    K: [u8; 32],
270    tr: Option<[u8; 64]>,
271}
272
273
274impl<
275    const LAMBDA: i32,
276    const GAMMA2: i32,
277    const k: usize,
278    const l: usize,
279    const eta: usize,
280    const S1_PACKED_LEN: usize,
281    const S2_PACKED_LEN: usize,
282    const T1_PACKED_LEN: usize,
283    const SK_LEN: usize,
284    const PK_LEN: usize,
285>  Drop for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN,> {
286    fn drop(&mut self) {
287        // seed is a KeyMaterialSized which will zeroize itself
288        self.rho.fill(0u8);
289        self.rho_prime.fill(0u8);
290        self.K.fill(0u8);
291        if self.tr.is_some() {
292            self.tr.unwrap().as_mut().fill(0u8);
293            debug_assert_eq!(&self.tr.unwrap(), &[0u8; 64]);
294        }
295    }
296}
297
298impl<
299    const LAMBDA: i32,
300    const GAMMA2: i32,
301    const k: usize,
302    const l: usize,
303    const eta: usize,
304    const S1_PACKED_LEN: usize,
305    const S2_PACKED_LEN: usize,
306    const T1_PACKED_LEN: usize,
307    const PK_LEN: usize,
308    const SK_LEN: usize,
309> Secret for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {}
310
311impl<
312    const LAMBDA: i32,
313    const GAMMA2: i32,
314    const k: usize,
315    const l: usize,
316    const eta: usize,
317    const S1_PACKED_LEN: usize,
318    const S2_PACKED_LEN: usize,
319    const T1_PACKED_LEN: usize,
320    const PK_LEN: usize,
321    const SK_LEN: usize,
322> Debug for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {
323    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
324        let alg = match k {
325            4 => ML_DSA_44_NAME,
326            6 => ML_DSA_65_NAME,
327            8 => ML_DSA_87_NAME,
328            _ => panic!("Unsupported key length"),
329        };
330        write!(
331            f,
332            "MLDSASeedPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?} }}",
333            alg,
334            self.tr(),
335        )
336    }
337}
338
339impl<
340    const LAMBDA: i32,
341    const GAMMA2: i32,
342    const k: usize,
343    const l: usize,
344    const eta: usize,
345    const S1_PACKED_LEN: usize,
346    const S2_PACKED_LEN: usize,
347    const T1_PACKED_LEN: usize,
348    const PK_LEN: usize,
349    const SK_LEN: usize,
350> Display for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {
351    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
352        let alg = match k {
353            4 => ML_DSA_44_NAME,
354            6 => ML_DSA_65_NAME,
355            8 => ML_DSA_87_NAME,
356            _ => panic!("Unsupported key length"),
357        };
358        write!(
359            f,
360            "MLDSASeedPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?} }}",
361            alg,
362            self.tr(),
363        )
364    }
365}
366
367impl<
368    const LAMBDA: i32,
369    const GAMMA2: i32,
370    const k: usize,
371    const l: usize,
372    const eta: usize,
373    const S1_PACKED_LEN: usize,
374    const S2_PACKED_LEN: usize,
375    const T1_PACKED_LEN: usize,
376    const PK_LEN: usize,
377    const SK_LEN: usize,
378> MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {
379    /// Create a new MLDSASeedPrivateKey from a 32-byte KeyMaterial.
380    pub fn new(seed: &KeyMaterialSized<32>) -> Result<Self, SignatureError> {
381        if !(seed.key_type() == KeyType::Seed || seed.key_type() == KeyType::BytesFullEntropy)
382                || seed.key_len() != 32
383        {
384            return Err(SignatureError::KeyGenError(
385                "Seed must be 32 bytes and KeyType::Seed or KeyType::BytesFullEntropy.",
386            ));
387        }
388
389        if seed.security_strength() < SecurityStrength::from_bits(LAMBDA as usize) {
390            return Err(SignatureError::KeyGenError("Seed SecurityStrength must match algorithm security strength: 128-bit (ML-DSA-44), 192-bit (ML-DSA-65), or 256-bit (ML-DSA-87)."));
391        }
392
393        let (rho, rho_prime, K) = Self::compute_rhos_and_K(&seed);
394        Ok(Self { seed: seed.clone(), rho, rho_prime, K, tr: None, })
395    }
396
397    fn compute_rhos_and_K(seed: &KeyMaterialSized<32>) -> ([u8; 32], [u8; 64], [u8; 32]) {
398        // derive sk.K
399        // Alg 6; 1: (rho, rho_prime, K) <- H(πœ‰||IntegerToBytes(π‘˜, 1)||IntegerToBytes(β„“, 1), 128)
400        //   β–· expand seed
401        let mut rho: [u8; 32] = [0u8; 32];
402        let mut rho_prime: [u8; 64] = [0u8; 64];
403        let mut K: [u8; 32] = [0u8; 32];
404
405        let mut h = H::default();
406        h.absorb(seed.ref_to_bytes());
407        h.absorb(&(k as u8).to_le_bytes());
408        h.absorb(&(l as u8).to_le_bytes());
409        let bytes_written = h.squeeze_out(&mut rho);
410        debug_assert_eq!(bytes_written, 32);
411        let bytes_written = h.squeeze_out(&mut rho_prime);
412        debug_assert_eq!(bytes_written, 64);
413        let bytes_written = h.squeeze_out(&mut K);
414        debug_assert_eq!(bytes_written, 32);
415
416        (rho, rho_prime, K)
417    }
418
419    fn compute_t_row(
420        &self,
421        idx: usize,
422        s1_packed: &[u8],
423        s2_packed: &[u8],
424    ) -> Polynomial {
425        debug_assert!(idx < k);
426
427        // [Optimization Note]:
428        // This is one of the places that a row of s1 can be re-computed instead of expanded from the compressed form.
429        // let mut s1 = self.compute_s1_row(0);
430        let mut s1_hat = s_unpack::<eta>(s1_packed, 0);
431        s1_hat.ntt();
432
433        let mut t_hat = rej_ntt_poly(&self.rho, &[0u8, idx as u8]);
434            // polynomial::multiply_ntt(&rej_ntt_poly(&self.rho, &[0u8, idx as u8]), &s1_hat);
435        t_hat.multiply_ntt(&s1_hat);
436
437        for col in 1..l {
438            // [Optimization Note]:
439            // This is one of the places that a row of s1 can be re-computed instead of expanded from the compressed form.
440            // s1 = self.compute_s1_row(col);
441            let mut s1_hat = s_unpack::<eta>(s1_packed, col);
442            s1_hat.ntt();
443            // let tmp = polynomial::multiply_ntt(
444            //     // [Optimization Note]:
445            //     // this is reconstructing a row of the public matrix A_hat,
446            //     // which nobody is proposing to keep in memory.
447            //     &rej_ntt_poly(&self.rho, &[col as u8, idx as u8]),
448            //     &s1_hat,
449            // );
450            let mut tmp = rej_ntt_poly(&self.rho, &[col as u8, idx as u8]);
451            tmp.multiply_ntt(&s1_hat);
452            t_hat.add_ntt(&tmp);
453        }
454
455        t_hat.inv_ntt();
456        let mut t = t_hat;
457        // [Optimization Note]:
458        // This is one of the places that a row of s2 can be re-computed instead of unpacked from the compressed form.
459        // let s2 = self.compute_s2_row(idx);
460        let s2 = s_unpack::<eta>(s2_packed, idx);
461        t.add_ntt(&s2);
462        t.conditional_add_q();
463
464        t
465    }
466}
467
468impl<
469    const LAMBDA: i32,
470    const GAMMA2: i32,
471    const k: usize,
472    const l: usize,
473    const eta: usize,
474    const S1_PACKED_LEN: usize,
475    const S2_PACKED_LEN: usize,
476    const T1_PACKED_LEN: usize,
477    const PK_LEN: usize,
478    const SK_LEN: usize,
479> SignaturePrivateKey for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {
480    fn encode(&self) -> Vec<u8> {
481        let mut out = [0u8; 32];
482        out.copy_from_slice(self.seed.ref_to_bytes());
483        out.to_vec()
484    }
485
486    fn encode_out(&self, out: &mut [u8]) -> Result<usize, SignatureError> {
487        if out.len() < 32 {
488            return Err(SignatureError::EncodingError("Output buffer too small"));
489        }
490        out[..32].copy_from_slice(self.seed.ref_to_bytes());
491        Ok(32)
492
493    }
494
495    fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
496        if bytes.len() != 32 {
497            return Err(SignatureError::DecodingError("Invalid seed length"));
498        }
499        let mut keymat = KeyMaterialSized::<32>::from_bytes(bytes)?;
500        keymat.allow_hazardous_operations();
501        keymat.set_key_type(KeyType::Seed)?;
502        keymat.set_security_strength(SecurityStrength::_256bit)?;
503        keymat.drop_hazardous_operations();
504
505        Self::new(&keymat)
506    }
507}
508
509impl<
510    const LAMBDA: i32,
511    const GAMMA2: i32,
512    const k: usize,
513    const l: usize,
514    const eta: usize,
515    const S1_PACKED_LEN: usize,
516    const S2_PACKED_LEN: usize,
517    const T1_PACKED_LEN: usize,
518    const PK_LEN: usize,
519    const SK_LEN: usize,
520> MLDSAPrivateKeyTrait<k, l, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN>
521for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {
522    fn from_keymaterial(seed: &KeyMaterialSized<32>) -> Result<Self, SignatureError> {
523        Self::new(seed)
524    }
525
526    fn seed(&self) -> &KeyMaterialSized<32> { &self.seed }
527
528    fn tr(&self) -> [u8; 64] {
529        let pk: MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> = self.derive_pk();
530        pk.compute_tr()
531    }
532
533    fn derive_pk(&self) -> MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
534        // The goal here is to get t1, which we will build and compress one row at a time.
535
536        let s1_packed: [u8; S1_PACKED_LEN] = self.compute_s1_packed();
537        let s2_packed: [u8; S2_PACKED_LEN] = self.compute_s2_packed();
538
539        let mut t1_packed = [0u8; T1_PACKED_LEN];
540        debug_assert_eq!(T1_PACKED_LEN, POLY_T1PACKED_LEN * k);
541
542        for i in 0..k {
543            t1_packed[i * POLY_T1PACKED_LEN .. (i+1) * POLY_T1PACKED_LEN]
544                .copy_from_slice(
545                    &simple_bit_pack_t1(&self.compute_t1_row(i, &s1_packed, &s2_packed))
546                );
547        }
548
549        MLDSAPublicKey::<k, T1_PACKED_LEN, PK_LEN>::new(&self.rho, &t1_packed)
550    }
551
552    fn sk_encode(&self) -> [u8; SK_LEN] {
553       self.seed.ref_to_bytes().try_into().unwrap()
554    }
555
556    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
557        out.copy_from_slice(self.seed.ref_to_bytes());
558
559        SK_LEN
560    }
561    fn sk_decode(sk: &[u8; SK_LEN]) -> Self {
562        Self::from_bytes(sk).unwrap()
563    }
564}
565
566pub(crate) trait MLDSAPrivateKeyInternalTrait<
567    const LAMBDA: i32,
568    const GAMMA2: i32,
569    const k: usize,
570    const l: usize,
571    const eta: usize,
572    const S1_PACKED_LEN: usize,
573    const S2_PACKED_LEN: usize,
574    const PK_LEN: usize,
575    const SK_LEN: usize,
576> : Sized
577{
578    fn rho(&self) -> &[u8; 32];
579    fn K(&self) -> &[u8; 32];
580
581    fn compute_s1_row(
582        &self,
583        idx: usize,
584    ) -> Polynomial;
585
586    fn compute_s1_packed(&self) -> [u8; S1_PACKED_LEN];
587
588    fn compute_s2_row(
589        &self,
590        idx: usize,
591    ) -> Polynomial;
592
593    fn compute_s2_packed(&self) -> [u8; S2_PACKED_LEN];
594
595    fn compute_t0_row(
596        &self,
597        idx: usize,
598        s1_packed: &[u8],
599        s2_packed: &[u8],
600    ) -> Polynomial;
601
602    fn compute_t1_row(
603        &self,
604        idx: usize,
605        s1_packed: &[u8],
606        s2_packed: &[u8],
607    ) -> Polynomial;
608}
609
610impl<
611    const LAMBDA: i32,
612    const GAMMA2: i32,
613    const k: usize,
614    const l: usize,
615    const eta: usize,
616    const S1_PACKED_LEN: usize,
617    const S2_PACKED_LEN: usize,
618    const T1_PACKED_LEN: usize,
619    const PK_LEN: usize,
620    const SK_LEN: usize,
621> MLDSAPrivateKeyInternalTrait<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, PK_LEN, SK_LEN>
622for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN> {
623    fn rho(&self) -> &[u8; 32] {
624        &self.rho
625    }
626
627    fn K(&self) -> &[u8; 32] {
628        &self.K
629    }
630
631    fn compute_s1_row(
632        &self,
633        idx: usize,
634    ) -> Polynomial {
635        debug_assert!(idx < l);
636        rej_bounded_poly::<eta>(&self.rho_prime, &(idx as u16).to_le_bytes())
637    }
638
639    fn compute_s1_packed(&self) -> [u8; S1_PACKED_LEN] {
640        let mut s1_packed = [0u8; S1_PACKED_LEN];
641        for idx in 0..l {
642            let s1_i = self.compute_s1_row(idx);
643            bit_pack_eta::<eta>(&s1_i, &mut s1_packed[idx * bitlen_eta(eta)..(idx + 1) * bitlen_eta(eta)]);
644        }
645        s1_packed
646    }
647
648    fn compute_s2_row(
649        &self,
650        idx: usize,
651    ) -> Polynomial {
652        debug_assert!(idx < k);
653        rej_bounded_poly::<eta>(&self.rho_prime, &((idx + l) as u16).to_le_bytes())
654    }
655
656    fn compute_s2_packed(&self) -> [u8; S2_PACKED_LEN] {
657        let mut s2_packed = [0u8; S2_PACKED_LEN];
658        for idx in 0..k {
659            let s2_i = self.compute_s2_row(idx);
660            bit_pack_eta::<eta>(&s2_i, &mut s2_packed[idx * bitlen_eta(eta)..(idx + 1) * bitlen_eta(eta)]);
661        }
662        s2_packed
663    }
664
665    fn compute_t0_row(
666        &self,
667        idx: usize,
668        s1_packed: &[u8],
669        s2_packed: &[u8],
670    ) -> Polynomial {
671        let mut t0 = self.compute_t_row(idx, s1_packed, s2_packed);
672        for j in 0..N {
673            (_, t0.0[j]) = power_2_round(t0.0[j]);
674        }
675
676        t0
677    }
678
679    fn compute_t1_row(
680        &self,
681        idx: usize,
682        s1_packed: &[u8],
683        s2_packed: &[u8],
684    ) -> Polynomial {
685        let mut t1 = self.compute_t_row(idx, s1_packed, s2_packed);
686        for j in 0..N {
687            (t1.0[j], _) = power_2_round(t1.0[j]);
688        }
689
690        t1
691    }
692}
693