bouncycastle_mldsa/matrix.rs
//! Thin wrappers around plain arrays of polynomials. They are not strictly necessary, but they
//! keep the types and sizes of matrices and vectors explicit.
3
4use crate::aux_functions::multiply_ntt;
5use crate::mldsa::H;
6use crate::polynomial::Polynomial;
7use bouncycastle_core::traits::XOF;
8use core::ops::{Index, IndexMut};
9
10/// A matrix over the ML-DSA ring.
11#[derive(Clone)]
12pub struct Matrix<const k: usize, const l: usize>(/*pub(crate)*/ [[Polynomial; l]; k]);
13
14/// Convenience function to avoid ".0" all over the place.
15impl<const k: usize, const l: usize> Index<usize> for Matrix<k, l> {
16 type Output = [Polynomial; l];
17
18 fn index(&self, index: usize) -> &Self::Output {
19 &self.0[index]
20 }
21}
22/// Convenience function to avoid ".0" all over the place.
23impl<const k: usize, const l: usize> IndexMut<usize> for Matrix<k, l> {
24 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
25 &mut self.0[index]
26 }
27}
28
29impl<const k: usize, const l: usize> Matrix<k, l> {
30 pub(crate) fn new() -> Self {
31 Self { 0: [[(); l]; k].map(|_| [(); l].map(|_| Polynomial::new())) }
32 }
33
34 /// Algorithm 48 MatrixVectorNTT(π, π―)
35 /// Computes the product π βΜ π―_hat of a matrix π_hat and a vector π―_hat over ππ.
36 /// Input: π, β β β, π β ππ
37 /// πΓβ Μ π .
38 /// Performs dot product multiplication of this matrix by a vector
39 /// Input: vector of length l
40 /// Output: vector of length k
41 pub fn matrix_vector_ntt(&self, v: &Vector<l>) -> Vector<k> {
42 let mut w = Vector::<k>::new();
43 for i in 0..k {
44 // split out the 0 case to skip a no-op add_ntt()
45 w[i].coeffs.copy_from_slice(&multiply_ntt(&self[i][0], &v[0]).coeffs);
46
47 let mut w1: Polynomial;
48 for j in 1..l {
49 // dot product a vector into a matrix: multiply the input vector
50 // into each row of the matrix, then sum the results to produce a vector of
51 // length k.
52 w1 = multiply_ntt(&self[i][j], &v[j]);
53 w[i].add_ntt(&w1);
54 }
55 }
56
57 w
58 }
59}
60
// Matrix and Vector do not need to impl Secret: the actual data lives in the polynomials, which have their own zeroizing drop.
// Technically all matrices, and some vectors, are only part of the public key and might not need to be zeroized,
// but they are left zeroizing for now; skipping the zeroization is a potential future optimization.
64
65#[derive(Clone)]
66pub(crate) struct Vector<const k: usize> {
67 pub(crate) vec: [Polynomial; k],
68}
69
70/// Convenience function to avoid ".0" all over the place.
71impl<const k: usize> Index<usize> for Vector<k> {
72 type Output = Polynomial;
73
74 fn index(&self, index: usize) -> &Self::Output {
75 &self.vec[index]
76 }
77}
78/// Convenience function to avoid ".0" all over the place.
79impl<const k: usize> IndexMut<usize> for Vector<k> {
80 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81 &mut self.vec[index]
82 }
83}
84
85impl<const LEN: usize> Vector<LEN> {
86 pub(crate) fn new() -> Self {
87 Self { vec: [(); LEN].map(|_| Polynomial::new()) }
88 }
89
90 /// Algorithm 46 AddVectorNTT(π―, π°)Μ
91 /// Computes the sum π―_hat + π°_hat of two vectors π―_hat, π°_hat over ππ.
92 /// Input: β β β, v_hat β T^β, w_hat β π^β
93 /// Output: u_hat β T^β_π.
94 /// Add another vector to this vector
95 pub(crate) fn add_vector_ntt(&mut self, s: &Self) {
96 for i in 0..LEN {
97 // perform montgomery addition of each polynomial in the vector
98 self[i].add_ntt(&s[i]);
99 }
100 }
101
102 pub(crate) fn sub_vector(&self, s: &Self) -> Self {
103 let mut out = self.clone();
104 for i in 0..LEN {
105 out[i].sub(&s[i]);
106 }
107 out
108 }
109
110 /// Algorithm 47 ScalarVectorNTT(π,Μ π―)Μ
111 /// Computes the product π_hat * π―_hat of a scalar π_hat and a vector π―_hat over ππ.
112 /// Input: π_hat β ππ, β β β, π―_hat β π^β
113 /// Output: π .
114 pub(crate) fn scalar_vector_ntt(&self, w: &Polynomial) -> Self {
115 let mut s_hat = Self::new();
116 for i in 0..LEN {
117 s_hat[i] = multiply_ntt(&self[i], &w);
118 }
119
120 s_hat
121 }
122
123 pub(crate) fn conditional_add_q(&mut self) {
124 for i in 0..LEN {
125 self[i].conditional_add_q();
126 }
127 }
128
129 pub(crate) fn reduce(&mut self) {
130 for i in 0..LEN {
131 self[i].reduce();
132 }
133 }
134
135 pub(crate) fn ntt(&mut self) {
136 for i in 0..LEN {
137 self[i].ntt();
138 }
139 }
140
141 pub(crate) fn inv_ntt(&mut self) {
142 for i in 0..LEN {
143 self[i].inv_ntt();
144 }
145 }
146
147 pub(crate) fn high_bits<const GAMMA2: i32>(&self) -> Self {
148 let mut s = Self::new();
149
150 for i in 0..LEN {
151 s[i] = self[i].high_bits::<GAMMA2>();
152 }
153
154 s
155 }
156
157 pub(crate) fn low_bits<const GAMMA2: i32>(&self) -> Self {
158 let mut s = Self::new();
159
160 for i in 0..LEN {
161 s[i] = self[i].low_bits::<GAMMA2>();
162 }
163
164 s
165 }
166
167 pub(crate) fn shift_left<const d: i32>(&self) -> Self {
168 let mut out = self.clone();
169 for i in 0..LEN {
170 out[i].shift_left::<d>();
171 }
172
173 out
174 }
175
176 pub(crate) fn check_norm<const BOUND: i32>(&self) -> bool {
177 // Fine that this is not constant-time because it is used in a rejection loop -- the early quit leads to rejection.
178 for x in self.vec.iter() {
179 if x.check_norm::<BOUND>() {
180 return true;
181 }
182 }
183 false
184 }
185
186 /// Algorithm 28 w1Encode(π°1)
187 /// Encodes a polynomial vector π°1 into a byte string.
188 /// Input: π°1 β π
π whose polynomial coordinates have coefficients in \[0, (π β 1)/(2πΎ2) β 1].
189 /// Output: A byte string representation π°1_tilde β πΉ32πβ
bitlen ((πβ1)/(2πΎ2)β1)
190 /// Optimized from FIPS 204 to feed into the hash one row at a time to reduce overall memory footprint.
191 pub(crate) fn w1_encode_and_hash<const POLY_W1_PACKED_LEN: usize>(&self, h: &mut H) {
192 // 1: π°Μ1 β ()
193 // don't need to allocate anything since we're feeding it into the hash row-wise
194
195 // 2: for π from 0 to π β 1 do
196 // 3: π°Μ1 β π°Μ1 || SimpleBitPack (π°1[π], (π β 1)/(2πΎ2) β 1)
197 // 4: end for
198 for w in self.vec.iter() {
199 h.absorb(&w.w1_encode::<POLY_W1_PACKED_LEN>());
200 }
201 }
202}