diff --git a/src/spartan/batched.rs b/src/spartan/batched.rs index 3604c1282..b68ac33c6 100644 --- a/src/spartan/batched.rs +++ b/src/spartan/batched.rs @@ -490,17 +490,7 @@ impl> BatchedRelaxedR1CSSNARKTrait let evals_Z = zip_with!(iter, (self.evals_W, U, r_y), |eval_W, U, r_y| { let eval_X = { // constant term - let poly_X = iter::once((0, U.u)) - .chain( - //remaining inputs - U.X - .iter() - .enumerate() - // filter_map uses the sparsity of the polynomial, if irrelevant - // we should replace by UniPoly - .filter_map(|(i, x_i)| (!x_i.is_zero_vartime()).then_some((i + 1, *x_i))), - ) - .collect(); + let poly_X = iter::once(U.u).chain(U.X.iter().cloned()).collect(); SparsePolynomial::new(r_y.len() - 1, poly_X).evaluate(&r_y[1..]) }; (E::Scalar::ONE - r_y[0]) * eval_W + r_y[0] * eval_X diff --git a/src/spartan/batched_ppsnark.rs b/src/spartan/batched_ppsnark.rs index f60e562e1..caf604804 100644 --- a/src/spartan/batched_ppsnark.rs +++ b/src/spartan/batched_ppsnark.rs @@ -927,15 +927,7 @@ impl> BatchedRelaxedR1CSSNARKTrait let X = { // constant term - let poly_X = std::iter::once((0, U.u)) - .chain( - //remaining inputs - (0..U.X.len()) - // filter_map uses the sparsity of the polynomial, if irrelevant - // we should replace by UniPoly - .filter_map(|i| (!U.X[i].is_zero_vartime()).then_some((i + 1, U.X[i]))), - ) - .collect(); + let poly_X = std::iter::once(U.u).chain(U.X.iter().cloned()).collect(); SparsePolynomial::new(num_vars_log, poly_X).evaluate(&rand_sc_unpad[1..]) }; diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 663addb98..ff17677c8 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -23,7 +23,6 @@ use crate::{ }; use ff::Field; use itertools::Itertools as _; -use polys::multilinear::SparsePolynomial; use rayon::{iter::IntoParallelRefIterator, prelude::*}; use rayon_scan::ScanParallelIterator as _; use ref_cast::RefCast; diff --git a/src/spartan/polys/multilinear.rs b/src/spartan/polys/multilinear.rs index bdb18e06f..4217342a6 100644 --- a/src/spartan/polys/multilinear.rs +++ b/src/spartan/polys/multilinear.rs @@ -8,8 +8,7 @@ use ff::PrimeField; use itertools::Itertools as _; use rand_core::{CryptoRng, RngCore}; use rayon::prelude::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, - IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; use serde::{Deserialize, Serialize}; @@ -130,47 +129,36 @@ impl Index for MultilinearPolynomial { } /// Sparse multilinear polynomial, which means the $Z(\cdot)$ is zero at most points. -/// So we do not have to store every evaluations of $Z(\cdot)$, only store the non-zero points. -/// -/// For example, the evaluations are [0, 0, 0, 1, 0, 1, 0, 2]. -/// The sparse polynomial only store the non-zero values, [(3, 1), (5, 1), (7, 2)]. -/// In the tuple, the first is index, the second is value. +/// In our context, sparse polynomials are non-zeros over the hypercube at locations that map to "small" integers +/// We exploit this property to implement a time-optimal algorithm pub(crate) struct SparsePolynomial { num_vars: usize, - Z: Vec<(usize, Scalar)>, + Z: Vec, } impl SparsePolynomial { - pub fn new(num_vars: usize, Z: Vec<(usize, Scalar)>) -> Self { - Self { num_vars, Z } + pub fn new(num_vars: usize, Z: Vec) -> Self { + SparsePolynomial { num_vars, Z } } - /// Computes the $\tilde{eq}$ extension polynomial. - /// return 1 when a == r, otherwise return 0. - fn compute_chi(a: &[bool], r: &[Scalar]) -> Scalar { - assert_eq!(a.len(), r.len()); - let mut chi_i = Scalar::ONE; - for j in 0..r.len() { - if a[j] { - chi_i *= r[j]; - } else { - chi_i *= Scalar::ONE - r[j]; - } - } - chi_i - } - - // Takes O(m log n) where m is the number of non-zero evaluations and n is the number of variables. + // a time-optimal algorithm to evaluate sparse polynomials pub fn evaluate(&self, r: &[Scalar]) -> Scalar { assert_eq!(self.num_vars, r.len()); - (0..self.Z.len()) - .into_par_iter() - .map(|i| { - let bits = (self.Z[i].0).get_bits(r.len()); - Self::compute_chi(&bits, r) * self.Z[i].1 - }) - .sum() + let num_vars_z = self.Z.len().next_power_of_two().log_2(); + let chis = EqPolynomial::evals_from_points(&r[self.num_vars - 1 - num_vars_z..]); + let eval_partial: Scalar = self + .Z + .iter() + .zip(chis.iter()) + .map(|(z, chi)| *z * *chi) + .sum(); + + let common = (0..self.num_vars - 1 - num_vars_z) + .map(|i| (Scalar::ONE - r[i])) + .product::(); + + common * eval_partial } } @@ -232,18 +220,21 @@ mod tests { } fn test_sparse_polynomial_with() { - // Let the polynomial have 3 variables, p(x_1, x_2, x_3) = (x_1 + x_2) * x_3 - // Evaluations of the polynomial at boolean cube are [0, 0, 0, 1, 0, 1, 0, 2]. + // Let the polynomial have 4 variables, but is non-zero at only 3 locations (out of 2^4 = 16) over the hypercube + let mut Z = vec![F::ONE, F::ONE, F::from(2)]; + let m_poly = SparsePolynomial::::new(4, Z.clone()); - let TWO = F::from(2); - let Z = vec![(3, F::ONE), (5, F::ONE), (7, TWO)]; - let m_poly = SparsePolynomial::::new(3, Z); + Z.resize(16, F::ZERO); // append with zeros to make it a dense polynomial + let m_poly_dense = MultilinearPolynomial::new(Z); - let x = vec![F::ONE, F::ONE, F::ONE]; - assert_eq!(m_poly.evaluate(x.as_slice()), TWO); + // evaluation point + let x = vec![F::from(5), F::from(8), F::from(5), F::from(3)]; - let x = vec![F::ONE, F::ZERO, F::ONE]; - assert_eq!(m_poly.evaluate(x.as_slice()), F::ONE); + // check evaluations + assert_eq!( + m_poly.evaluate(x.as_slice()), + m_poly_dense.evaluate(x.as_slice()) + ); } #[test] diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index aba5a98be..5cc723c34 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -24,7 +24,7 @@ use crate::{ }, SumcheckProof, }, - PolyEvalInstance, PolyEvalWitness, SparsePolynomial, + PolyEvalInstance, PolyEvalWitness, }, traits::{ commitment::{CommitmentEngineTrait, CommitmentTrait, Len}, @@ -42,7 +42,7 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::sync::Arc; -use super::polys::masked_eq::MaskedEqPolynomial; +use super::polys::{masked_eq::MaskedEqPolynomial, multilinear::SparsePolynomial}; fn padded(v: &[E::Scalar], n: usize, e: &E::Scalar) -> Vec { let mut v_padded = vec![*e; n]; @@ -930,17 +930,15 @@ impl> RelaxedR1CSSNARKTrait for Relax }; let eval_X = { - // constant term - let poly_X = std::iter::once((0, U.u)) - .chain( - //remaining inputs - (0..U.X.len()) - // filter_map uses the sparsity of the polynomial, if irrelevant - // we should replace by UniPoly - .filter_map(|i| (!U.X[i].is_zero_vartime()).then_some((i + 1, U.X[i]))), - ) - .collect(); - SparsePolynomial::new(vk.num_vars.log_2(), poly_X).evaluate(&rand_sc_unpad[1..]) + // public IO is (u, X) + let X = vec![U.u] + .into_iter() + .chain(U.X.iter().cloned()) + .collect::>(); + + // evaluate the sparse polynomial at rand_sc_unpad[1..] + let poly_X = SparsePolynomial::new(rand_sc_unpad.len() - 1, X); + poly_X.evaluate(&rand_sc_unpad[1..]) }; self.eval_W + factor * rand_sc_unpad[0] * eval_X diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index 636a0d7d8..538647e21 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -32,7 +32,7 @@ use itertools::Itertools as _; use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; -use std::{iter, sync::Arc}; +use std::sync::Arc; /// A type that represents the prover's key #[derive(Debug, Clone)] @@ -328,17 +328,12 @@ impl> RelaxedR1CSSNARKTrait for Relax // verify claim_inner_final let eval_Z = { let eval_X = { - // constant term - let poly_X = iter::once((0, U.u)) - .chain( - //remaining inputs - (0..U.X.len()) - // filter_map uses the sparsity of the polynomial, if irrelevant - // we should replace by UniPoly - .filter_map(|i| (!U.X[i].is_zero_vartime()).then_some((i + 1, U.X[i]))), - ) - .collect(); - SparsePolynomial::new(usize::try_from(vk.S.num_vars.ilog2()).unwrap(), poly_X) + // public IO is (u, X) + let X = vec![U.u] + .into_iter() + .chain(U.X.iter().cloned()) + .collect::>(); + SparsePolynomial::new(usize::try_from(vk.S.num_vars.ilog2()).unwrap(), X) .evaluate(&r_y[1..]) }; (E::Scalar::ONE - r_y[0]) * self.eval_W + r_y[0] * eval_X