Skip to content

Commit

Permalink
Merge pull request #2883 from o1-labs/dw/cross-terms-aggregated-polyn…
Browse files Browse the repository at this point in the history
…omial

mvpoly: compute the cross terms of a list of polynomials
  • Loading branch information
dannywillems authored Dec 19, 2024
2 parents 163ff27 + 73b4bf4 commit d23dec2
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
36 changes: 36 additions & 0 deletions mvpoly/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,39 @@ pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
/// variable in each monomial is of maximum degree 1.
fn is_multilinear(&self) -> bool;
}

/// Compute the cross terms of a list of polynomials. The polynomials are
/// linearly combined using the power of a combiner, often called `α`.
pub fn compute_combined_cross_terms<
F: PrimeField,
const N: usize,
const D: usize,
T: MVPoly<F, N, D>,
>(
polys: Vec<T>,
eval1: [F; N],
eval2: [F; N],
u1: F,
u2: F,
combiner1: F,
combiner2: F,
) -> HashMap<usize, F> {
// These should never happen as they should be random
// It also makes the code cleaner as we do not need to handle 0^0
assert!(combiner1 != F::zero());
assert!(combiner2 != F::zero());
assert!(u1 != F::zero());
assert!(u2 != F::zero());
polys
.into_iter()
.enumerate()
.fold(HashMap::new(), |mut acc, (i, poly)| {
let scalar1 = combiner1.pow([i as u64]);
let scalar2 = combiner2.pow([i as u64]);
let res = poly.compute_cross_terms_scaled(&eval1, &eval2, u1, u2, scalar1, scalar2);
res.iter().for_each(|(p, r)| {
acc.entry(*p).and_modify(|e| *e += r).or_insert(*r);
});
acc
})
}
60 changes: 60 additions & 0 deletions mvpoly/tests/monomials.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ark_ff::{Field, One, UniformRand, Zero};
use core::cmp::Ordering;
use kimchi::circuits::{
berkeley_columns::BerkeleyChallengeTerm,
expr::{ConstantExpr, Expr, ExprInner, Variable},
Expand Down Expand Up @@ -804,3 +805,62 @@ fn test_cross_terms_scaled() {
let scaled_cross_terms = scaled_p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);
assert_eq!(cross_terms, scaled_cross_terms);
}

#[test]
fn test_cross_terms_aggregated_polynomial() {
let mut rng = o1_utils::tests::make_test_rng(None);
const M: usize = 20;
let polys: Vec<Sparse<Fp, 5, 4>> = (0..M)
.map(|_| unsafe { Sparse::<Fp, 5, 4>::random(&mut rng, None) })
.collect();

let random_eval1: [Fp; 5] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 5] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let scalar1: Fp = Fp::rand(&mut rng);
let scalar2: Fp = Fp::rand(&mut rng);

const N: usize = 5 + M;
const D: usize = 4 + 1;
let aggregated_poly: Sparse<Fp, { N }, { D }> = {
let vars: [Sparse<Fp, N, D>; M] = std::array::from_fn(|j| {
let mut res = Sparse::<Fp, { N }, { D }>::zero();
let monomial: [usize; N] = std::array::from_fn(|i| if i == 5 + j { 1 } else { 0 });
res.add_monomial(monomial, Fp::one());
res
});
polys
.iter()
.enumerate()
.fold(Sparse::<Fp, { N }, { D }>::zero(), |acc, (j, poly)| {
let poly: Result<Sparse<Fp, { N }, { D }>, String> = (*poly).clone().into();
let poly: Sparse<Fp, { N }, { D }> = poly.unwrap();
poly * vars[j].clone() + acc
})
};

let res = mvpoly::compute_combined_cross_terms(
polys,
random_eval1,
random_eval2,
u1,
u2,
scalar1,
scalar2,
);
let random_eval1_prime: [Fp; N] = std::array::from_fn(|i| match i.cmp(&5) {
Ordering::Greater => scalar1.pow([(i as u64) - 5_u64]),
Ordering::Less => random_eval1[i],
Ordering::Equal => Fp::one(),
});

let random_eval2_prime: [Fp; N] = std::array::from_fn(|i| match i.cmp(&5) {
Ordering::Greater => scalar2.pow([(i as u64) - 5_u64]),
Ordering::Less => random_eval2[i],
Ordering::Equal => Fp::one(),
});
let cross_terms_aggregated =
aggregated_poly.compute_cross_terms(&random_eval1_prime, &random_eval2_prime, u1, u2);
assert_eq!(res, cross_terms_aggregated);
}

0 comments on commit d23dec2

Please sign in to comment.