Skip to main content

bouncycastle_mldsa_lowmemory/
mldsa_keys.rs

1use crate::aux_functions::{bit_pack_eta, bit_pack_t0, bitlen_eta, power_2_round, rej_bounded_poly, simple_bit_pack_t1, simple_bit_unpack_t1};
2use crate::mldsa::{H, N, POLY_T0PACKED_LEN};
3use crate::{ML_DSA_44_NAME, ML_DSA_65_NAME, ML_DSA_87_NAME};
4use crate::mldsa::{MLDSA44_LAMBDA, MLDSA44_GAMMA2, MLDSA44_ETA, MLDSA44_PK_LEN, MLDSA44_SK_LEN, MLDSA44_FULL_SK_LEN, MLDSA44_k, MLDSA44_l, MLDSA44_S1_PACKED_LEN, MLDSA44_S2_PACKED_LEN};
5use crate::mldsa::{MLDSA65_LAMBDA, MLDSA65_GAMMA2, MLDSA65_ETA, MLDSA65_PK_LEN, MLDSA65_SK_LEN, MLDSA65_FULL_SK_LEN, MLDSA65_k, MLDSA65_l, MLDSA65_S1_PACKED_LEN, MLDSA65_S2_PACKED_LEN};
6use crate::mldsa::{MLDSA87_LAMBDA, MLDSA87_GAMMA2, MLDSA87_ETA, MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_FULL_SK_LEN, MLDSA87_k, MLDSA87_l, MLDSA87_S1_PACKED_LEN, MLDSA87_S2_PACKED_LEN};
7use crate::mldsa::{POLY_T1PACKED_LEN, MLDSA44_T1_PACKED_LEN, MLDSA65_T1_PACKED_LEN, MLDSA87_T1_PACKED_LEN};
8use crate::low_memory_helpers::{expandA_elem, s_unpack};
9use bouncycastle_core::errors::SignatureError;
10use bouncycastle_core::key_material::{KeyMaterial, KeyMaterialTrait, KeyType};
11use bouncycastle_core::traits::{Secret, SecurityStrength, SignaturePrivateKey, SignaturePublicKey, XOF};
12use core::fmt;
13use core::fmt::{Debug, Display, Formatter};
14
15// imports just for docs
16#[allow(unused_imports)]
17use crate::mldsa::MLDSATrait;
18use crate::polynomial::Polynomial;
19
20
21/* Pub Types */
22
23/// ML-DSA-44 Public Key
24pub type MLDSA44PublicKey = MLDSAPublicKey<MLDSA44_k, MLDSA44_T1_PACKED_LEN, MLDSA44_PK_LEN>;
25/// ML-DSA-44 Private Key
26pub type MLDSA44PrivateKey = MLDSASeedPrivateKey<MLDSA44_LAMBDA, MLDSA44_GAMMA2, MLDSA44_k, MLDSA44_l, MLDSA44_ETA, MLDSA44_S1_PACKED_LEN, MLDSA44_S2_PACKED_LEN, MLDSA44_T1_PACKED_LEN, MLDSA44_PK_LEN, MLDSA44_SK_LEN, MLDSA44_FULL_SK_LEN>;
27/// ML-DSA-65 Public Key
28pub type MLDSA65PublicKey = MLDSAPublicKey<MLDSA65_k, MLDSA65_T1_PACKED_LEN, MLDSA65_PK_LEN>;
29/// ML-DSA-65 Private Key
30pub type MLDSA65PrivateKey = MLDSASeedPrivateKey<MLDSA65_LAMBDA, MLDSA65_GAMMA2, MLDSA65_k, MLDSA65_l, MLDSA65_ETA, MLDSA65_S1_PACKED_LEN, MLDSA65_S2_PACKED_LEN, MLDSA65_T1_PACKED_LEN, MLDSA65_PK_LEN, MLDSA65_SK_LEN, MLDSA65_FULL_SK_LEN>;
31/// ML-DSA-87 Public Key
32pub type MLDSA87PublicKey = MLDSAPublicKey<MLDSA87_k, MLDSA87_T1_PACKED_LEN, MLDSA87_PK_LEN>;
33/// ML-DSA-87 Private Key
34pub type MLDSA87PrivateKey = MLDSASeedPrivateKey<MLDSA87_LAMBDA, MLDSA87_GAMMA2, MLDSA87_k, MLDSA87_l, MLDSA87_ETA, MLDSA87_S1_PACKED_LEN, MLDSA87_S2_PACKED_LEN, MLDSA87_T1_PACKED_LEN, MLDSA87_PK_LEN, MLDSA87_SK_LEN, MLDSA87_FULL_SK_LEN>;
35
/// An ML-DSA public key.
///
/// Stored as the public seed `rho` plus the vector 𝐭1 already in its bit-packed form, which is
/// exactly the pkEncode layout `rho || t1_packed`. The `k` and `PK_LEN` parameters are not used
/// by any field; they tie the type to its parameter set.
#[derive(Clone)]
pub struct MLDSAPublicKey<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> {
    /// Public seed 𝜌 (from which the matrix Â is expanded).
    pub(crate) rho: [u8; 32],
    /// The vector 𝐭1, bit-packed one row (polynomial) after another.
    pub(crate) t1_packed: [u8; T1_PACKED_LEN],
}
42
43/// General trait for all ML-DSA public keys types.
44pub trait MLDSAPublicKeyTrait<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> : SignaturePublicKey<PK_LEN> {
45    /// Algorithm 23 pkDecode(π‘π‘˜)
46    /// Reverses the procedure pkEncode.
47    /// Input: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
48    /// Output: 𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
49    fn pk_decode(pk: &[u8; PK_LEN]) -> Self;
50
51    /// Compute the public key hash (tr) from the public key.
52    ///
53    /// This is exposed as a public API for a few reasons:
54    /// 1. `tr` is required for some external-prehashing schemes such as the so-called "external mu" signing mode.
55    /// 2. `tr` is the canonical fingerprint of an ML-DSA public key, so would be an appropriate value
56    ///     to use, for example, to build a public key lookup or deny-listing table.
57    fn compute_tr(&self) -> [u8; 64];
58}
59
60pub(crate) trait MLDSAPublicKeyInternalTrait<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> {
61    /// Not exposing a constructor publicly because you should have to get an instance either by
62    /// running a keygen, or by decoding an existing key.
63    fn new(rho: [u8; 32], t1_packed: [u8; T1_PACKED_LEN]) -> Self;
64
65    /// Get a ref to rho
66    fn rho(&self) -> &[u8; 32];
67
68    /// Get a ref to t1
69    fn unpack_t1_row(&self, row: usize) -> Polynomial;
70}
71
72impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> MLDSAPublicKeyTrait<k, T1_PACKED_LEN, PK_LEN> for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
73    fn pk_decode(pk: &[u8; PK_LEN]) -> Self {
74        Self {
75            rho: pk[..32].try_into().unwrap(),
76            t1_packed: pk[32..].try_into().unwrap()
77        }
78    }
79
80    fn compute_tr(&self) -> [u8; 64] {
81        let mut tr = [0u8; 64];
82        H::new().hash_xof_out(&self.encode(), &mut tr);
83
84        tr
85    }
86}
87
88impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> MLDSAPublicKeyInternalTrait<k, T1_PACKED_LEN, PK_LEN> for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
89    fn new(rho: [u8; 32], t1_packed: [u8; T1_PACKED_LEN]) -> Self {
90        Self { rho, t1_packed }
91    }
92
93    fn rho(&self) -> &[u8; 32] { &self.rho }
94
95    fn unpack_t1_row(&self, row: usize) -> Polynomial {
96        simple_bit_unpack_t1(&self.t1_packed[row * POLY_T1PACKED_LEN .. (row + 1) * POLY_T1PACKED_LEN].try_into().unwrap())
97    }
98}
99
100impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize>  SignaturePublicKey<PK_LEN> for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
101    /// Algorithm 22 pkEncode(𝜌, 𝐭1)
102    /// Encodes a public key for ML-DSA into a byte string.
103    /// Input:𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
104    /// Output: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
105    fn encode(&self) -> [u8; PK_LEN] {
106        let mut pk = [0u8; PK_LEN];
107        self.encode_out(&mut pk);
108
109        pk
110    }
111    /// Algorithm 22 pkEncode(𝜌, 𝐭1)
112    /// Encodes a public key for ML-DSA into a byte string.
113    /// Input:𝜌 ∈ 𝔹32, 𝐭1 ∈ π‘…π‘˜ with coefficients in [0, 2bitlen (π‘žβˆ’1)βˆ’π‘‘ βˆ’ 1].
114    /// Output: Public key π‘π‘˜ ∈ 𝔹32+32π‘˜(bitlen (π‘žβˆ’1)βˆ’π‘‘).
115    fn encode_out(&self, out: &mut [u8; PK_LEN]) -> usize {
116        debug_assert_eq!(out.len(), PK_LEN);
117
118        out[..32].copy_from_slice(&self.rho);
119        out[32..].copy_from_slice(&self.t1_packed);
120
121        PK_LEN
122    }
123
124    fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
125        if bytes.len() != PK_LEN { return Err(SignatureError::DecodingError("Provided key bytes are the incorrect length")) }
126        let sized_bytes: [u8; PK_LEN] = bytes[..PK_LEN].try_into().unwrap();
127        Ok(Self::pk_decode(&sized_bytes))
128    }
129}
130
131impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> Eq for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> { }
132
133impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> PartialEq for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
134    fn eq(&self, other: &Self) -> bool {
135        let self_encoded = self.encode();
136        let other_encoded = other.encode();
137        bouncycastle_utils::ct::ct_eq_bytes(self_encoded.as_ref(), other_encoded.as_ref())
138    }
139}
140
141impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> fmt::Debug for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        let alg = match k {
144            4 => ML_DSA_44_NAME,
145            6 => ML_DSA_65_NAME,
146            8 => ML_DSA_87_NAME,
147            _ => panic!("Unsupported key length"),
148        };
149        write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
150    }
151}
152
153impl<const k: usize, const T1_PACKED_LEN: usize, const PK_LEN: usize> Display for MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
154    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
155        let alg = match k {
156            4 => ML_DSA_44_NAME,
157            6 => ML_DSA_65_NAME,
158            8 => ML_DSA_87_NAME,
159            _ => panic!("Unsupported key length"),
160        };
161        write!(f, "MLDSAPublicKey {{ alg: {}, pub_key_hash (tr): {:x?} }}", alg, self.compute_tr(),)
162    }
163}
164
165
166
167/// General trait for all ML-DSA private keys types.
168pub trait MLDSAPrivateKeyTrait<
169    const k: usize,
170    const l: usize,
171    const S1_PACKED_LEN: usize,
172    const S2_PACKED_LEN: usize,
173    const T1_PACKED_LEN: usize,
174    const PK_LEN: usize,
175    const SK_LEN: usize,
176    const FULL_SK_LEN: usize,
177> : SignaturePrivateKey<SK_LEN> {
178    /// New from KeyMaterial. Can throw a SignatureError if the KeyMaterial does not contain sufficient entropy.
179    fn from_keymaterial(seed: &KeyMaterial<32>) -> Result<Self, SignatureError>;
180
181    /// Get a ref to the seed, if there is one stored with this private key
182    fn seed(&self) -> &KeyMaterial<32>;
183
184    /// Get a copy of the key hash `tr`.
185    /// This is computationally intensive as it requires fully re-computing the public key (and then discarding it).
186    /// It is highly recommended that if you already have a copy of the public key, get `tr` from that,
187    /// or else compute tr once and store it.
188    fn tr(&self) -> [u8; 64];
189
190    /// Returns the full public key, and has the side-effect of setting the public key hash tr in this MLDSASeedSK object.
191    fn derive_pk(&self) -> MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN>;
192    /// Algorithm 24 skEncode(𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0)
193    /// Encodes a secret key for ML-DSA into a byte string.
194    /// Input: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 , 𝐬1 ∈ 𝑅ℓ with coefficients in [βˆ’πœ‚, πœ‚], 𝐬2 ∈ π‘…π‘˜ with
195    /// coefficients in [βˆ’πœ‚, πœ‚], 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
196    /// Output: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((π‘˜+β„“)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
197    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize;
198    /// This produces the full private key in the encoding specified in FIPS 204 Algorithm 24 skEncode()
199    /// so that it is compatible with other implementations.
200    ///
201    /// Note that since this encoding does not include the seed, this is a one-way operation;
202    /// after exporting in this encoding, it will be impossible to re-import it into a [MLDSASeedPrivateKey].
203    fn encode_full_sk(&self) -> [u8; FULL_SK_LEN] {
204        let mut out = [0; FULL_SK_LEN];
205        self.encode_full_sk_out(&mut out);
206
207        out
208    }
209    /// This produces the full private key in the encoding specified in FIPS 204 Algorithm 24 skEncode()
210    /// so that it is compatible with other implementations.
211    ///
212    /// Note that since this encoding does not include the seed, this is a one-way operation;
213    /// after exporting in this encoding, it will be impossible to re-import it into a [MLDSASeedPrivateKey].
214    fn encode_full_sk_out(&self, out: &mut [u8; FULL_SK_LEN]);
215    /// Algorithm 25 skDecode(π‘ π‘˜)
216    /// Reverses the procedure skEncode.
217    /// Input: Private key π‘ π‘˜ ∈ 𝔹32+32+64+32β‹…((β„“+π‘˜)β‹…bitlen (2πœ‚)+π‘‘π‘˜).
218    /// Output: 𝜌 ∈ 𝔹32, 𝐾 ∈ 𝔹32, π‘‘π‘Ÿ ∈ 𝔹64 ,
219    /// 𝐬1 ∈ 𝑅ℓ , 𝐬2 ∈ π‘…π‘˜ , 𝐭0 ∈ π‘…π‘˜ with coefficients in [βˆ’2π‘‘βˆ’1 + 1, 2π‘‘βˆ’1].
220    ///
221    /// Note: this object contains only the simple decoding routine to unpack a semi-expanded key.
222    /// See [MLDSATrait] for key generation functions, including derive-from-seed and consistency-check functions.
223    fn sk_decode(sk: &[u8; SK_LEN]) -> Self;
224}
225
226/// Internal structure for holding a seed-based private key for ML-DSA.
227#[derive(Clone, PartialEq, Eq)]
228pub struct MLDSASeedPrivateKey<
229    const LAMBDA: i32,
230    const GAMMA2: i32,
231    const k: usize,
232    const l: usize,
233    const eta: usize,
234    const S1_PACKED_LEN: usize,
235    const S2_PACKED_LEN: usize,
236    const T1_PACKED_LEN: usize,
237    const PK_LEN: usize,
238    const SK_LEN: usize,
239    const FULL_SK_LEN: usize,
240> {
241    seed: KeyMaterial<32>,
242    rho: [u8; 32],
243    rho_prime: [u8; 64],
244    K: [u8; 32],
245}
246
247
248impl<
249    const LAMBDA: i32,
250    const GAMMA2: i32,
251    const k: usize,
252    const l: usize,
253    const eta: usize,
254    const S1_PACKED_LEN: usize,
255    const S2_PACKED_LEN: usize,
256    const T1_PACKED_LEN: usize,
257    const SK_LEN: usize,
258    const PK_LEN: usize,
259    const FULL_SK_LEN: usize,
260>  Drop for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN, > {
261    fn drop(&mut self) {
262        // seed is a KeyMaterialSized which will zeroize itself
263        self.rho.fill(0u8);
264        self.rho_prime.fill(0u8);
265        self.K.fill(0u8);
266    }
267}
268
269impl<
270    const LAMBDA: i32,
271    const GAMMA2: i32,
272    const k: usize,
273    const l: usize,
274    const eta: usize,
275    const S1_PACKED_LEN: usize,
276    const S2_PACKED_LEN: usize,
277    const T1_PACKED_LEN: usize,
278    const PK_LEN: usize,
279    const SK_LEN: usize,
280    const FULL_SK_LEN: usize,
281> Secret for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {}
282
283impl<
284    const LAMBDA: i32,
285    const GAMMA2: i32,
286    const k: usize,
287    const l: usize,
288    const eta: usize,
289    const S1_PACKED_LEN: usize,
290    const S2_PACKED_LEN: usize,
291    const T1_PACKED_LEN: usize,
292    const PK_LEN: usize,
293    const SK_LEN: usize,
294    const FULL_SK_LEN: usize,
295> Debug for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {
296    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
297        let alg = match k {
298            4 => ML_DSA_44_NAME,
299            6 => ML_DSA_65_NAME,
300            8 => ML_DSA_87_NAME,
301            _ => panic!("Unsupported key length"),
302        };
303        write!(
304            f,
305            "MLDSASeedPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?} }}",
306            alg,
307            self.tr(),
308        )
309    }
310}
311
312impl<
313    const LAMBDA: i32,
314    const GAMMA2: i32,
315    const k: usize,
316    const l: usize,
317    const eta: usize,
318    const S1_PACKED_LEN: usize,
319    const S2_PACKED_LEN: usize,
320    const T1_PACKED_LEN: usize,
321    const PK_LEN: usize,
322    const SK_LEN: usize,
323    const FULL_SK_LEN: usize,
324> Display for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {
325    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
326        let alg = match k {
327            4 => ML_DSA_44_NAME,
328            6 => ML_DSA_65_NAME,
329            8 => ML_DSA_87_NAME,
330            _ => panic!("Unsupported key length"),
331        };
332        write!(
333            f,
334            "MLDSASeedPrivateKey {{ alg: {}, pub_key_hash (tr): {:x?} }}",
335            alg,
336            self.tr(),
337        )
338    }
339}
340
341impl<
342    const LAMBDA: i32,
343    const GAMMA2: i32,
344    const k: usize,
345    const l: usize,
346    const eta: usize,
347    const S1_PACKED_LEN: usize,
348    const S2_PACKED_LEN: usize,
349    const T1_PACKED_LEN: usize,
350    const PK_LEN: usize,
351    const SK_LEN: usize,
352    const FULL_SK_LEN: usize,
353> MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {
354    /// Create a new MLDSASeedPrivateKey from a 32-byte KeyMaterial.
355    /// Seed SecurityStrength must match algorithm security strength: 128-bit (ML-DSA-44), 192-bit (ML-DSA-65), or 256-bit (ML-DSA-87),
356    /// otherwise it throws a SignatureError::KeyGenError("SecurityStrength".
357    pub fn new(seed: &KeyMaterial<32>) -> Result<Self, SignatureError> {
358        if !(seed.key_type() == KeyType::Seed || seed.key_type() == KeyType::BytesFullEntropy)
359                || seed.key_len() != 32
360        {
361            return Err(SignatureError::KeyGenError(
362                "Seed must be 32 bytes and KeyType::Seed or KeyType::BytesFullEntropy.",
363            ));
364        }
365
366        if seed.security_strength() < SecurityStrength::from_bits(LAMBDA as usize) {
367            return Err(SignatureError::KeyGenError("SecurityStrength"));
368        }
369
370        let (rho, rho_prime, K) = Self::compute_rhos_and_K(&seed);
371        Ok(Self { seed: seed.clone(), rho, rho_prime, K})
372    }
373
374    fn compute_rhos_and_K(seed: &KeyMaterial<32>) -> ([u8; 32], [u8; 64], [u8; 32]) {
375        // derive sk.K
376        // Alg 6; 1: (rho, rho_prime, K) <- H(πœ‰||IntegerToBytes(π‘˜, 1)||IntegerToBytes(β„“, 1), 128)
377        //   β–· expand seed
378        let mut rho: [u8; 32] = [0u8; 32];
379        let mut rho_prime: [u8; 64] = [0u8; 64];
380        let mut K: [u8; 32] = [0u8; 32];
381
382        let mut h = H::default();
383        h.absorb(seed.ref_to_bytes());
384        h.absorb(&(k as u8).to_le_bytes());
385        h.absorb(&(l as u8).to_le_bytes());
386        let bytes_written = h.squeeze_out(&mut rho);
387        debug_assert_eq!(bytes_written, 32);
388        let bytes_written = h.squeeze_out(&mut rho_prime);
389        debug_assert_eq!(bytes_written, 64);
390        let bytes_written = h.squeeze_out(&mut K);
391        debug_assert_eq!(bytes_written, 32);
392
393        (rho, rho_prime, K)
394    }
395
396    fn compute_t_row(
397        &self,
398        idx: usize,
399        s1_packed: &[u8],
400        s2_packed: &[u8],
401    ) -> Polynomial {
402        debug_assert!(idx < k);
403
404        // [Optimization Note]:
405        // This is one of the places that a row of s1 can be re-computed instead of expanded from the compressed form.
406        // let mut s1 = self.compute_s1_row(0);
407        let mut s1_hat_i = s_unpack::<eta>(s1_packed, 0);
408        s1_hat_i.ntt();
409
410        let mut t_i = {
411            let mut t_hat_i = expandA_elem(&self.rho, idx, 0);
412            t_hat_i.multiply_ntt(&s1_hat_i);
413
414            for col in 1..l {
415                // [Optimization Note]:
416                // This is one of the places that a row of s1 can be re-computed instead of expanded from the compressed form.
417                // s1 = self.compute_s1_row(col);
418                let mut s1_hat = s_unpack::<eta>(s1_packed, col);
419                s1_hat.ntt();
420                // let tmp = polynomial::multiply_ntt(
421                //     // [Optimization Note]:
422                //     // this is reconstructing a row of the public matrix A_hat,
423                //     // which nobody is proposing to keep in memory.
424                //     &rej_ntt_poly(&self.rho, &[col as u8, idx as u8]),
425                //     &s1_hat,
426                // );
427                let mut A_elem = expandA_elem(&self.rho, idx, col);
428                A_elem.multiply_ntt(&s1_hat);
429                t_hat_i.add_ntt(&A_elem);
430            }
431            t_hat_i.inv_ntt();
432
433            t_hat_i
434        };
435
436        // [Optimization Note]:
437        // This is one of the places that a row of s2 can be re-computed instead of unpacked from the compressed form.
438        // let s2 = self.compute_s2_row(idx);
439        let s2 = s_unpack::<eta>(s2_packed, idx);
440        t_i.add_ntt(&s2);
441        t_i.conditional_add_q();
442
443        t_i
444    }
445}
446
447impl<
448    const LAMBDA: i32,
449    const GAMMA2: i32,
450    const k: usize,
451    const l: usize,
452    const eta: usize,
453    const S1_PACKED_LEN: usize,
454    const S2_PACKED_LEN: usize,
455    const T1_PACKED_LEN: usize,
456    const PK_LEN: usize,
457    const SK_LEN: usize,
458    const FULL_SK_LEN: usize,
459> SignaturePrivateKey<SK_LEN> for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {
460    /// Encodes the private key seed.
461    fn encode(&self) -> [u8; SK_LEN] {
462        debug_assert_eq!(SK_LEN, /* seed */ 32);
463
464        self.seed.ref_to_bytes().try_into().unwrap()
465    }
466
467    fn encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
468        out.copy_from_slice(self.seed.ref_to_bytes());
469
470        debug_assert_eq!(self.seed.ref_to_bytes().len(), SK_LEN);
471        SK_LEN
472    }
473
474    fn from_bytes(bytes: &[u8]) -> Result<Self, SignatureError> {
475        if bytes.len() != 32 {
476            return Err(SignatureError::DecodingError("Invalid seed length"));
477        }
478        let mut keymat = KeyMaterial::<32>::from_bytes(bytes)?;
479        keymat.allow_hazardous_operations();
480        keymat.set_key_type(KeyType::Seed)?;
481        keymat.set_security_strength(SecurityStrength::_256bit)?;
482        keymat.drop_hazardous_operations();
483
484        Self::new(&keymat)
485    }
486}
487
488impl<
489    const LAMBDA: i32,
490    const GAMMA2: i32,
491    const k: usize,
492    const l: usize,
493    const eta: usize,
494    const S1_PACKED_LEN: usize,
495    const S2_PACKED_LEN: usize,
496    const T1_PACKED_LEN: usize,
497    const PK_LEN: usize,
498    const SK_LEN: usize,
499    const FULL_SK_LEN: usize,
500> MLDSAPrivateKeyTrait<k, l, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN>
501for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {
502    fn from_keymaterial(seed: &KeyMaterial<32>) -> Result<Self, SignatureError> {
503        Self::new(seed)
504    }
505
506    fn seed(&self) -> &KeyMaterial<32> { &self.seed }
507
508    fn tr(&self) -> [u8; 64] {
509        let pk: MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> = self.derive_pk();
510        pk.compute_tr()
511    }
512
513    fn derive_pk(&self) -> MLDSAPublicKey<k, T1_PACKED_LEN, PK_LEN> {
514        // The goal here is to get t1, which we will build and compress one row at a time.
515
516        let s1_packed: [u8; S1_PACKED_LEN] = self.compute_s1_packed();
517        let s2_packed: [u8; S2_PACKED_LEN] = self.compute_s2_packed();
518
519        let mut t1_packed = [0u8; T1_PACKED_LEN];
520        debug_assert_eq!(T1_PACKED_LEN, POLY_T1PACKED_LEN * k);
521
522        for i in 0..k {
523            t1_packed[i * POLY_T1PACKED_LEN .. (i+1) * POLY_T1PACKED_LEN]
524                .copy_from_slice(
525                    &simple_bit_pack_t1(&self.compute_t1_row(i, &s1_packed, &s2_packed))
526                );
527        }
528
529        MLDSAPublicKey::<k, T1_PACKED_LEN, PK_LEN>::new(self.rho.clone(), t1_packed)
530    }
531    fn sk_encode_out(&self, out: &mut [u8; SK_LEN]) -> usize {
532        out.copy_from_slice(self.seed.ref_to_bytes());
533
534        SK_LEN
535    }
536    fn encode_full_sk(&self) -> [u8; FULL_SK_LEN] {
537        let mut out = [0; FULL_SK_LEN];
538        self.encode_full_sk_out(&mut out);
539
540        out
541    }
542    fn encode_full_sk_out(&self, out: &mut [u8; FULL_SK_LEN]) {
543        out.fill(0);
544
545        // Algorithm 24 skEncode(𝜌, 𝐾, π‘‘π‘Ÿ, 𝐬1, 𝐬2, 𝐭0)
546
547        let mut off: usize = 0;
548
549        // 1: π‘ π‘˜ ← 𝜌||𝐾||π‘‘π‘Ÿ
550        out[0..32].copy_from_slice(&self.rho);
551        out[32..64].copy_from_slice(&self.K);
552        out[64..128].copy_from_slice(&self.tr());
553        off += 128;
554
555        // 2: for 𝑖 from 0 to β„“ βˆ’ 1 do
556        // 3:   π‘ π‘˜ ← π‘ π‘˜ || BitPack (𝐬1[𝑖], πœ‚, πœ‚)
557        // 4: end for
558        let s1_packed = self.compute_s1_packed();
559        out[off .. off + S1_PACKED_LEN].copy_from_slice(&s1_packed);
560        off += S1_PACKED_LEN;
561
562
563        // 5: for 𝑖 from 0 to π‘˜ βˆ’ 1 do
564        // 6:   π‘ π‘˜ ← π‘ π‘˜ || BitPack (𝐬2[𝑖], πœ‚, πœ‚)
565        // 7: end for
566        let s2_packed = self.compute_s2_packed();
567        out[off .. off + S2_PACKED_LEN].copy_from_slice(&s2_packed);
568        off += S2_PACKED_LEN;
569
570
571        // 8: for 𝑖 from 0 to π‘˜ βˆ’ 1 do
572        // 9:   π‘ π‘˜ ← π‘ π‘˜ || BitPack (𝐭0[𝑖], 2π‘‘βˆ’1 βˆ’ 1, 2π‘‘βˆ’1)
573        // 10: end for
574        debug_assert_eq!(off + k * POLY_T0PACKED_LEN, FULL_SK_LEN);
575        for row in 0..k {
576            let t0_i = self.compute_t0_row(row, &s1_packed, &s2_packed);
577            out[off .. off + POLY_T0PACKED_LEN].copy_from_slice(&bit_pack_t0(&t0_i));
578            off += POLY_T0PACKED_LEN;
579        }
580        debug_assert_eq!(off, FULL_SK_LEN);
581    }
582    fn sk_decode(sk: &[u8; SK_LEN]) -> Self {
583        Self::from_bytes(sk).unwrap()
584    }
585}
586
587pub(crate) trait MLDSAPrivateKeyInternalTrait<
588    const LAMBDA: i32,
589    const GAMMA2: i32,
590    const k: usize,
591    const l: usize,
592    const eta: usize,
593    const S1_PACKED_LEN: usize,
594    const S2_PACKED_LEN: usize,
595    const PK_LEN: usize,
596    const SK_LEN: usize,
597> : Sized
598{
599    fn rho(&self) -> &[u8; 32];
600    fn K(&self) -> &[u8; 32];
601
602    fn compute_s1_row(&self, idx: usize) -> Polynomial;
603
604    fn compute_s1_packed(&self) -> [u8; S1_PACKED_LEN];
605
606    fn compute_s2_row(&self, idx: usize) -> Polynomial;
607
608    fn compute_s2_packed(&self) -> [u8; S2_PACKED_LEN];
609
610    fn compute_t0_row(&self, idx: usize, s1_packed: &[u8], s2_packed: &[u8]) -> Polynomial;
611
612    fn compute_t1_row(&self, idx: usize, s1_packed: &[u8], s2_packed: &[u8]) -> Polynomial;
613}
614
615impl<
616    const LAMBDA: i32,
617    const GAMMA2: i32,
618    const k: usize,
619    const l: usize,
620    const eta: usize,
621    const S1_PACKED_LEN: usize,
622    const S2_PACKED_LEN: usize,
623    const T1_PACKED_LEN: usize,
624    const PK_LEN: usize,
625    const SK_LEN: usize,
626    const FULL_SK_LEN: usize,
627> MLDSAPrivateKeyInternalTrait<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, PK_LEN, SK_LEN>
628for MLDSASeedPrivateKey<LAMBDA, GAMMA2, k, l, eta, S1_PACKED_LEN, S2_PACKED_LEN, T1_PACKED_LEN, PK_LEN, SK_LEN, FULL_SK_LEN> {
629    fn rho(&self) -> &[u8; 32] {
630        &self.rho
631    }
632
633    fn K(&self) -> &[u8; 32] {
634        &self.K
635    }
636
637    fn compute_s1_row(
638        &self,
639        idx: usize,
640    ) -> Polynomial {
641        debug_assert!(idx < l);
642        rej_bounded_poly::<eta>(&self.rho_prime, &(idx as u16).to_le_bytes())
643    }
644
645    fn compute_s1_packed(&self) -> [u8; S1_PACKED_LEN] {
646        let mut s1_packed = [0u8; S1_PACKED_LEN];
647        for idx in 0..l {
648            let s1_i = self.compute_s1_row(idx);
649            bit_pack_eta::<eta>(&s1_i, &mut s1_packed[idx * bitlen_eta(eta)..(idx + 1) * bitlen_eta(eta)]);
650        }
651        s1_packed
652    }
653
654    fn compute_s2_row(
655        &self,
656        idx: usize,
657    ) -> Polynomial {
658        debug_assert!(idx < k);
659        rej_bounded_poly::<eta>(&self.rho_prime, &((idx + l) as u16).to_le_bytes())
660    }
661
662    fn compute_s2_packed(&self) -> [u8; S2_PACKED_LEN] {
663        let mut s2_packed = [0u8; S2_PACKED_LEN];
664        for idx in 0..k {
665            let s2_i = self.compute_s2_row(idx);
666            bit_pack_eta::<eta>(&s2_i, &mut s2_packed[idx * bitlen_eta(eta)..(idx + 1) * bitlen_eta(eta)]);
667        }
668        s2_packed
669    }
670
671    fn compute_t0_row(
672        &self,
673        idx: usize,
674        s1_packed: &[u8],
675        s2_packed: &[u8],
676    ) -> Polynomial {
677        let mut t0 = self.compute_t_row(idx, s1_packed, s2_packed);
678        for j in 0..N {
679            (_, t0[j]) = power_2_round(t0[j]);
680        }
681
682        t0
683    }
684
685    fn compute_t1_row(
686        &self,
687        idx: usize,
688        s1_packed: &[u8],
689        s2_packed: &[u8],
690    ) -> Polynomial {
691        let mut t1 = self.compute_t_row(idx, s1_packed, s2_packed);
692        for j in 0..N {
693            (t1[j], _) = power_2_round(t1[j]);
694        }
695
696        t1
697    }
698}
699