Skip to content

Commit

Permalink
feat: Faster computation of L0
Browse files Browse the repository at this point in the history
refactor: parallelize denominator computation
  • Loading branch information
davidnevadoc committed Sep 5, 2024
1 parent bc857a7 commit e3e1737
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions halo2_backend/src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#![allow(clippy::int_plus_one)]

use ff::{BatchInvert, WithSmallOrderMulGroup};
use group::Curve;
use halo2_middleware::ff::{Field, FromUniformBytes};
use halo2_middleware::zal::impls::H2cEngine;
Expand Down Expand Up @@ -107,7 +108,6 @@ where
}

// Compute fixeds

let fixed_polys: Vec<_> = circuit
.preprocessing
.fixed
Expand All @@ -131,13 +131,68 @@ where
.map(Polynomial::new_lagrange_from_vec)
.collect();

// Compute l_0(X)
// TODO: this can be done more efficiently
// https://github.com/privacy-scaling-explorations/halo2/issues/269
let mut l0 = vk.domain.empty_lagrange();
l0[0] = C::Scalar::ONE;
let l0 = vk.domain.lagrange_to_coeff(l0);
let l0 = vk.domain.coeff_to_extended(l0);
// Compute L_0(X) in the extended co-domain.
// L_0(X) the 0th Lagrange polynomial in the original domain.
// Its representation in the original domain H = {1, g, g^2, ..., g^(n-1)}
// is [1, 0, ..., 0].
// We compute its represenation in the extended co-domain
// zH = {z, z*w, z*w^2, ... , z*w^(n*k - 1)}, where k is the extension factor
// of the domain, and z is the extended root such that w^k = g.
// We assume z = F::ZETA, a cubic root the field. This simplifies the computation.
//
// The computation uses the fomula:
// L_i(X) = g^i/n * (X^n -1)/(X-g^i)
// L_0(X) = 1/n * (X^n -1)/(X-1)
let start = std::time::Instant::now();
let l0 = {
let one = C::ScalarExt::ONE;
let zeta = <C::ScalarExt as WithSmallOrderMulGroup<3>>::ZETA;

let n: u64 = 1 << vk.domain.k();
let c = (C::ScalarExt::from(n)).invert().unwrap();
let mut l0 = vec![C::ScalarExt::ZERO; vk.domain.extended_len()];

let w = vk.domain.get_extended_omega();
let wn = w.pow_vartime(&[n]);

Check warning on line 156 in halo2_backend/src/plonk/keygen.rs

View workflow job for this annotation

GitHub Actions / Clippy (beta)

the borrowed expression implements the required traits

warning: the borrowed expression implements the required traits --> halo2_backend/src/plonk/keygen.rs:156:32 | 156 | let wn = w.pow_vartime(&[n]); | ^^^^ help: change this to: `[n]` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args = note: `-W clippy::needless-borrows-for-generic-args` implied by `-W clippy::all` = help: to override `-W clippy::all` add `#[allow(clippy::needless_borrows_for_generic_args)]`

Check warning on line 156 in halo2_backend/src/plonk/keygen.rs

View workflow job for this annotation

GitHub Actions / Clippy (beta)

the borrowed expression implements the required traits

warning: the borrowed expression implements the required traits --> halo2_backend/src/plonk/keygen.rs:156:32 | 156 | let wn = w.pow_vartime(&[n]); | ^^^^ help: change this to: `[n]` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args = note: `-W clippy::needless-borrows-for-generic-args` implied by `-W clippy::all` = help: to override `-W clippy::all` add `#[allow(clippy::needless_borrows_for_generic_args)]`

Check failure on line 156 in halo2_backend/src/plonk/keygen.rs

View workflow job for this annotation

GitHub Actions / Clippy (1.56.1)

the borrowed expression implements the required traits

error: the borrowed expression implements the required traits --> halo2_backend/src/plonk/keygen.rs:156:32 | 156 | let wn = w.pow_vartime(&[n]); | ^^^^ help: change this to: `[n]` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args = note: `-D clippy::needless-borrows-for-generic-args` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::needless_borrows_for_generic_args)]`
let zeta_n = match n % 3 {
1 => zeta,
2 => zeta * zeta,
_ => one,
};

// Compute denominators.
parallelize(&mut l0, |e, mut index| {
let mut acc = zeta * w.pow_vartime(&[index as u64]);

Check warning on line 165 in halo2_backend/src/plonk/keygen.rs

View workflow job for this annotation

GitHub Actions / Clippy (beta)

the borrowed expression implements the required traits

warning: the borrowed expression implements the required traits --> halo2_backend/src/plonk/keygen.rs:165:48 | 165 | let mut acc = zeta * w.pow_vartime(&[index as u64]); | ^^^^^^^^^^^^^^^ help: change this to: `[index as u64]` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args

Check warning on line 165 in halo2_backend/src/plonk/keygen.rs

View workflow job for this annotation

GitHub Actions / Clippy (beta)

the borrowed expression implements the required traits

warning: the borrowed expression implements the required traits --> halo2_backend/src/plonk/keygen.rs:165:48 | 165 | let mut acc = zeta * w.pow_vartime(&[index as u64]); | ^^^^^^^^^^^^^^^ help: change this to: `[index as u64]` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args

Check failure on line 165 in halo2_backend/src/plonk/keygen.rs

View workflow job for this annotation

GitHub Actions / Clippy (1.56.1)

the borrowed expression implements the required traits

error: the borrowed expression implements the required traits --> halo2_backend/src/plonk/keygen.rs:165:48 | 165 | let mut acc = zeta * w.pow_vartime(&[index as u64]); | ^^^^^^^^^^^^^^^ help: change this to: `[index as u64]` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrows_for_generic_args
for e in e {
*e = acc - one;
acc *= w;
index += 1;
}
});
l0.batch_invert();

// Compute numinators.
// C * (zeta * w^i)^n = (C * zeta^n) * w^(i*n)
// We use w^k = g and g^n = 1 to save multiplications.
let k = 1 << (vk.domain.extended_k() - vk.domain.k());
let mut wn_powers = vec![zeta_n * c; k];
for i in 1..k {
wn_powers[i] = wn_powers[i - 1] * wn
}

parallelize(&mut l0, |e, mut index| {
for e in e {
*e *= wn_powers[index % k] - c;
index += 1;
}
});

Polynomial {
values: l0,
_marker: std::marker::PhantomData,
}
};
println!("L0 gen: {:?}", start.elapsed());

// Compute l_blind(X) which evaluates to 1 for each blinding factor row
// and 0 otherwise over the domain.
Expand All @@ -150,6 +205,7 @@ where

// Compute l_last(X) which evaluates to 1 on the first inactive row (just
// before the blinding factors) and 0 otherwise over the domain
// TODO L_0 method could be used here too.
let mut l_last = vk.domain.empty_lagrange();
l_last[params.n() as usize - vk.cs.blinding_factors() - 1] = C::Scalar::ONE;
let l_last = vk.domain.lagrange_to_coeff(l_last);
Expand Down

0 comments on commit e3e1737

Please sign in to comment.