Skip to main content

bouncycastle_mldsa/
matrix.rs

//! These are somewhat unnecessary wrappers around plain arrays of polynomials, but they help
//! keep the types and dimensions obvious.
3
4use crate::aux_functions::multiply_ntt;
5use crate::mldsa::H;
6use crate::polynomial::Polynomial;
7use bouncycastle_core::traits::XOF;
8use core::ops::{Index, IndexMut};
9
10/// A matrix over the ML-DSA ring.
11#[derive(Clone)]
12pub struct Matrix<const k: usize, const l: usize>(/*pub(crate)*/ [[Polynomial; l]; k]);
13
14/// Convenience function to avoid ".0" all over the place.
15impl<const k: usize, const l: usize> Index<usize> for Matrix<k, l> {
16    type Output = [Polynomial; l];
17
18    fn index(&self, index: usize) -> &Self::Output {
19        &self.0[index]
20    }
21}
22/// Convenience function to avoid ".0" all over the place.
23impl<const k: usize, const l: usize> IndexMut<usize> for Matrix<k, l> {
24    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
25        &mut self.0[index]
26    }
27}
28
29impl<const k: usize, const l: usize> Matrix<k, l> {
30    pub(crate) fn new() -> Self {
31        Self { 0: [[(); l]; k].map(|_| [(); l].map(|_| Polynomial::new())) }
32    }
33
34    /// Algorithm 48 MatrixVectorNTT(𝐌, 𝐯)
35    /// Computes the product 𝐌 βˆ˜Μ‚ 𝐯_hat of a matrix 𝐌_hat and a vector 𝐯_hat over π‘‡π‘ž.
36    /// Input: π‘˜, β„“ ∈ β„•, 𝐌 ∈ π‘‡π‘ž
37    /// π‘˜Γ—β„“ Μ‚ π‘ž .
38    /// Performs dot product multiplication of this matrix by a vector
39    /// Input: vector of length l
40    /// Output: vector of length k
41    pub fn matrix_vector_ntt(&self, v: &Vector<l>) -> Vector<k> {
42        let mut w = Vector::<k>::new();
43        for i in 0..k {
44            // split out the 0 case to skip a no-op add_ntt()
45            w[i].coeffs.copy_from_slice(&multiply_ntt(&self[i][0], &v[0]).coeffs);
46
47            let mut w1: Polynomial;
48            for j in 1..l {
49                // dot product a vector into a matrix: multiply the input vector
50                // into each row of the matrix, then sum the results to produce a vector of
51                // length k.
52                w1 = multiply_ntt(&self[i][j], &v[j]);
53                w[i].add_ntt(&w1);
54            }
55        }
56
57        w
58    }
59}
60
61// Matrix and Vector do not need to impl Secret because the actual data is in the polynomials, which have their own zeroizing drop.
62// Technically all matrices and some vectors are only part of the public key and might not need to be zeroized,
63// but I'll leave it zeroizing for now and leave this as a potential future optimization.
64
65#[derive(Clone)]
66pub(crate) struct Vector<const k: usize> {
67    pub(crate) vec: [Polynomial; k],
68}
69
70/// Convenience function to avoid ".0" all over the place.
71impl<const k: usize> Index<usize> for Vector<k> {
72    type Output = Polynomial;
73
74    fn index(&self, index: usize) -> &Self::Output {
75        &self.vec[index]
76    }
77}
78/// Convenience function to avoid ".0" all over the place.
79impl<const k: usize> IndexMut<usize> for Vector<k> {
80    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81        &mut self.vec[index]
82    }
83}
84
85impl<const LEN: usize> Vector<LEN> {
86    pub(crate) fn new() -> Self {
87        Self { vec: [(); LEN].map(|_| Polynomial::new()) }
88    }
89
90    /// Algorithm 46 AddVectorNTT(𝐯, 𝐰)Μ‚
91    /// Computes the sum 𝐯_hat + 𝐰_hat of two vectors 𝐯_hat, 𝐰_hat over π‘‡π‘ž.
92    /// Input: β„“ ∈ β„•, v_hat ∈ T^β„“, w_hat ∈ 𝑇^β„“
93    /// Output: u_hat ∈ T^β„“_π‘ž.
94    /// Add another vector to this vector
95    pub(crate) fn add_vector_ntt(&mut self, s: &Self) {
96        for i in 0..LEN {
97            // perform montgomery addition of each polynomial in the vector
98            self[i].add_ntt(&s[i]);
99        }
100    }
101
102    pub(crate) fn sub_vector(&self, s: &Self) -> Self {
103        let mut out = self.clone();
104        for i in 0..LEN {
105            out[i].sub(&s[i]);
106        }
107        out
108    }
109
110    /// Algorithm 47 ScalarVectorNTT(𝑐,Μ‚ 𝐯)Μ‚
111    /// Computes the product 𝑐_hat * 𝐯_hat of a scalar 𝑐_hat and a vector 𝐯_hat over π‘‡π‘ž.
112    /// Input: 𝑐_hat ∈ π‘‡π‘ž, β„“ ∈ β„•, 𝐯_hat ∈ 𝑇^β„“
113    /// Output: π‘ž .
114    pub(crate) fn scalar_vector_ntt(&self, w: &Polynomial) -> Self {
115        let mut s_hat = Self::new();
116        for i in 0..LEN {
117            s_hat[i] = multiply_ntt(&self[i], &w);
118        }
119
120        s_hat
121    }
122
123    pub(crate) fn conditional_add_q(&mut self) {
124        for i in 0..LEN {
125            self[i].conditional_add_q();
126        }
127    }
128
129    pub(crate) fn reduce(&mut self) {
130        for i in 0..LEN {
131            self[i].reduce();
132        }
133    }
134
135    pub(crate) fn ntt(&mut self) {
136        for i in 0..LEN {
137            self[i].ntt();
138        }
139    }
140
141    pub(crate) fn inv_ntt(&mut self) {
142        for i in 0..LEN {
143            self[i].inv_ntt();
144        }
145    }
146
147    pub(crate) fn high_bits<const GAMMA2: i32>(&self) -> Self {
148        let mut s = Self::new();
149
150        for i in 0..LEN {
151            s[i] = self[i].high_bits::<GAMMA2>();
152        }
153
154        s
155    }
156
157    pub(crate) fn low_bits<const GAMMA2: i32>(&self) -> Self {
158        let mut s = Self::new();
159
160        for i in 0..LEN {
161            s[i] = self[i].low_bits::<GAMMA2>();
162        }
163
164        s
165    }
166
167    pub(crate) fn shift_left<const d: i32>(&self) -> Self {
168        let mut out = self.clone();
169        for i in 0..LEN {
170            out[i].shift_left::<d>();
171        }
172
173        out
174    }
175
176    pub(crate) fn check_norm<const BOUND: i32>(&self) -> bool {
177        // Fine that this is not constant-time because it is used in a rejection loop -- the early quit leads to rejection.
178        for x in self.vec.iter() {
179            if x.check_norm::<BOUND>() {
180                return true;
181            }
182        }
183        false
184    }
185
186    /// Algorithm 28 w1Encode(𝐰1)
187    /// Encodes a polynomial vector 𝐰1 into a byte string.
188    /// Input: 𝐰1 ∈ π‘…π‘˜ whose polynomial coordinates have coefficients in \[0, (π‘ž βˆ’ 1)/(2𝛾2) βˆ’ 1].
189    /// Output: A byte string representation 𝐰1_tilde ∈ 𝔹32π‘˜β‹…bitlen ((π‘žβˆ’1)/(2𝛾2)βˆ’1)
190    /// Optimized from FIPS 204 to feed into the hash one row at a time to reduce overall memory footprint.
191    pub(crate) fn w1_encode_and_hash<const POLY_W1_PACKED_LEN: usize>(&self, h: &mut H) {
192        // 1: 𝐰̃1 ← ()
193        // don't need to allocate anything since we're feeding it into the hash row-wise
194
195        // 2: for 𝑖 from 0 to π‘˜ βˆ’ 1 do
196        // 3:   𝐰̃1 ← 𝐰̃1 || SimpleBitPack (𝐰1[𝑖], (π‘ž βˆ’ 1)/(2𝛾2) βˆ’ 1)
197        // 4: end for
198        for w in self.vec.iter() {
199            h.absorb(&w.w1_encode::<POLY_W1_PACKED_LEN>());
200        }
201    }
202}