diff --git a/mvpoly/src/lib.rs b/mvpoly/src/lib.rs index 87d0532862..1b4683948a 100644 --- a/mvpoly/src/lib.rs +++ b/mvpoly/src/lib.rs @@ -300,3 +300,39 @@ pub trait MVPoly: /// variable in each monomial is of maximum degree 1. fn is_multilinear(&self) -> bool; } + +/// Compute the cross terms of a list of polynomials. The polynomials are +/// linearly combined using the power of a combiner, often called `α`. +pub fn compute_combined_cross_terms< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>( + polys: Vec, + eval1: [F; N], + eval2: [F; N], + u1: F, + u2: F, + combiner1: F, + combiner2: F, +) -> HashMap { + // These should never happen as they should be random + // It also makes the code cleaner as we do not need to handle 0^0 + assert!(combiner1 != F::zero()); + assert!(combiner2 != F::zero()); + assert!(u1 != F::zero()); + assert!(u2 != F::zero()); + polys + .into_iter() + .enumerate() + .fold(HashMap::new(), |mut acc, (i, poly)| { + let scalar1 = combiner1.pow([i as u64]); + let scalar2 = combiner2.pow([i as u64]); + let res = poly.compute_cross_terms_scaled(&eval1, &eval2, u1, u2, scalar1, scalar2); + res.iter().for_each(|(p, r)| { + acc.entry(*p).and_modify(|e| *e += r).or_insert(*r); + }); + acc + }) +} diff --git a/mvpoly/tests/monomials.rs b/mvpoly/tests/monomials.rs index b169f8bc01..b81ad623fd 100644 --- a/mvpoly/tests/monomials.rs +++ b/mvpoly/tests/monomials.rs @@ -1,4 +1,5 @@ use ark_ff::{Field, One, UniformRand, Zero}; +use core::cmp::Ordering; use kimchi::circuits::{ berkeley_columns::BerkeleyChallengeTerm, expr::{ConstantExpr, Expr, ExprInner, Variable}, @@ -804,3 +805,62 @@ fn test_cross_terms_scaled() { let scaled_cross_terms = scaled_p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); assert_eq!(cross_terms, scaled_cross_terms); } + +#[test] +fn test_cross_terms_aggregated_polynomial() { + let mut rng = o1_utils::tests::make_test_rng(None); + const M: usize = 20; + let polys: Vec> = (0..M) + .map(|_| unsafe { Sparse::::random(&mut rng, None) }) + .collect(); + + let random_eval1: [Fp; 5] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 5] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let scalar1: Fp = Fp::rand(&mut rng); + let scalar2: Fp = Fp::rand(&mut rng); + + const N: usize = 5 + M; + const D: usize = 4 + 1; + let aggregated_poly: Sparse = { + let vars: [Sparse; M] = std::array::from_fn(|j| { + let mut res = Sparse::::zero(); + let monomial: [usize; N] = std::array::from_fn(|i| if i == 5 + j { 1 } else { 0 }); + res.add_monomial(monomial, Fp::one()); + res + }); + polys + .iter() + .enumerate() + .fold(Sparse::::zero(), |acc, (j, poly)| { + let poly: Result, String> = (*poly).clone().into(); + let poly: Sparse = poly.unwrap(); + poly * vars[j].clone() + acc + }) + }; + + let res = mvpoly::compute_combined_cross_terms( + polys, + random_eval1, + random_eval2, + u1, + u2, + scalar1, + scalar2, + ); + let random_eval1_prime: [Fp; N] = std::array::from_fn(|i| match i.cmp(&5) { + Ordering::Greater => scalar1.pow([(i as u64) - 5_u64]), + Ordering::Less => random_eval1[i], + Ordering::Equal => Fp::one(), + }); + + let random_eval2_prime: [Fp; N] = std::array::from_fn(|i| match i.cmp(&5) { + Ordering::Greater => scalar2.pow([(i as u64) - 5_u64]), + Ordering::Less => random_eval2[i], + Ordering::Equal => Fp::one(), + }); + let cross_terms_aggregated = + aggregated_poly.compute_cross_terms(&random_eval1_prime, &random_eval2_prime, u1, u2); + assert_eq!(res, cross_terms_aggregated); +}