From e931d5882dee83826404a318deb8a900c5a55f0c Mon Sep 17 00:00:00 2001 From: "adria0.eth" <5526331+adria0@users.noreply.github.com> Date: Tue, 28 May 2024 08:39:22 +0200 Subject: [PATCH 1/7] feat: improve ZAL - Pending coverage - Do not rely on random for tests - Move performance tests to specific test --- halo2_middleware/Cargo.toml | 6 +- halo2_middleware/src/zal.rs | 149 ++++++++++++++++++++---------------- 2 files changed, 87 insertions(+), 68 deletions(-) diff --git a/halo2_middleware/Cargo.toml b/halo2_middleware/Cargo.toml index 9af93add3b..1c24c77c69 100644 --- a/halo2_middleware/Cargo.toml +++ b/halo2_middleware/Cargo.toml @@ -35,10 +35,8 @@ rayon = "1.8" ark-std = { version = "0.3" } proptest = "1" group = "0.13" -rand_core = { version = "0.6", default-features = false } - -[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] -getrandom = { version = "0.2", features = ["js"] } +rand_xorshift = "0.3.0" +rand_core = "0.6.4" [lib] bench = false diff --git a/halo2_middleware/src/zal.rs b/halo2_middleware/src/zal.rs index 28d9491ce6..e2610b6b96 100644 --- a/halo2_middleware/src/zal.rs +++ b/halo2_middleware/src/zal.rs @@ -257,94 +257,115 @@ mod test { use ark_std::{end_timer, start_timer}; use ff::Field; use group::{Curve, Group}; - use rand_core::OsRng; + use rand_core::SeedableRng; + use rand_xorshift::XorShiftRng; - fn run_msm_zal_default(min_k: usize, max_k: usize) { - let points = (0..1 << max_k) - .map(|_| C::Curve::random(OsRng)) + fn gen_points_scalars(k: usize) -> (Vec, Vec) { + let rand = || XorShiftRng::from_seed([1; 16]); + let points = (0..1 << k) + .map(|_| C::Curve::random(rand())) .collect::>(); - let mut affine_points = vec![C::identity(); 1 << max_k]; + let mut affine_points = vec![C::identity(); 1 << k]; C::Curve::batch_normalize(&points[..], &mut affine_points[..]); let points = affine_points; - let scalars = (0..1 << max_k) - .map(|_| C::Scalar::random(OsRng)) + let scalars = (0..1 << k) + .map(|_| C::Scalar::random(rand())) .collect::>(); - for k in min_k..=max_k { - let points = &points[..1 << k]; - let scalars = &scalars[..1 << k]; + (points, scalars) + } - let t0 = start_timer!(|| format!("freestanding msm k={}", k)); - let e0 = best_multiexp(scalars, points); - end_timer!(t0); + fn run_msm_zal_default(points: &[C], scalars: &[C::Scalar], k: usize) { + let points = &points[..1 << k]; + let scalars = &scalars[..1 << k]; - let engine = PlonkEngineConfig::build_default::(); - let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); - let e1 = engine.msm_backend.msm(scalars, points); - end_timer!(t1); + let t0 = start_timer!(|| format!("freestanding msm k={}", k)); + let e0 = best_multiexp(scalars, points); + end_timer!(t0); - assert_eq!(e0, e1); + let engine = PlonkEngineConfig::build_default::(); + let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); + let e1 = engine.msm_backend.msm(scalars, points); + end_timer!(t1); - // Caching API - // ----------- - let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); - let base_descriptor = engine.msm_backend.get_base_descriptor(points); - let e2 = engine - .msm_backend - .msm_with_cached_base(scalars, &base_descriptor); - end_timer!(t2); + assert_eq!(e0, e1); - assert_eq!(e0, e2) - } + // Caching API + // ----------- + let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); + let base_descriptor = engine.msm_backend.get_base_descriptor(points); + let e2 = engine + .msm_backend + .msm_with_cached_base(scalars, &base_descriptor); + end_timer!(t2); + assert_eq!(e0, e2); + + let t3 = start_timer!(|| format!("H2cEngine msm cached coeffs k={}", k)); + let coeffs_descriptor = engine.msm_backend.get_coeffs_descriptor(scalars); + let e3 = engine + .msm_backend + .msm_with_cached_scalars(&coeffs_descriptor, points); + end_timer!(t3); + assert_eq!(e0, e3); + + let t4 = start_timer!(|| format!("H2cEngine msm cached inputs k={}", k)); + let e4 = engine + .msm_backend + .msm_with_cached_inputs(&coeffs_descriptor, &base_descriptor); + end_timer!(t4); + assert_eq!(e0, e4); } - fn run_msm_zal_custom(min_k: usize, max_k: usize) { - let points = (0..1 << max_k) - .map(|_| C::Curve::random(OsRng)) - .collect::>(); - let mut affine_points = vec![C::identity(); 1 << max_k]; - C::Curve::batch_normalize(&points[..], &mut affine_points[..]); - let points = affine_points; + fn run_msm_zal_custom(points: &[C], scalars: &[C::Scalar], k: usize) { + let points = &points[..1 << k]; + let scalars = &scalars[..1 << k]; - let scalars = (0..1 << max_k) - .map(|_| C::Scalar::random(OsRng)) - .collect::>(); + let t0 = start_timer!(|| format!("freestanding msm k={}", k)); + let e0 = best_multiexp(scalars, points); + end_timer!(t0); + + let engine = PlonkEngineConfig::new() + .set_curve::() + .set_msm(H2cEngine::new()) + .build(); + let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); + let e1 = engine.msm_backend.msm(scalars, points); + end_timer!(t1); + + assert_eq!(e0, e1); + + // Caching API + // ----------- + let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); + let base_descriptor = engine.msm_backend.get_base_descriptor(points); + let e2 = engine + .msm_backend + .msm_with_cached_base(scalars, &base_descriptor); + end_timer!(t2); + + assert_eq!(e0, e2) + } + + #[test] + #[ignore] + fn test_performance_h2c_msm_zal() { + let (min_k, max_k) = (3, 14); + let (points, scalars) = gen_points_scalars::(max_k); for k in min_k..=max_k { let points = &points[..1 << k]; let scalars = &scalars[..1 << k]; - let t0 = start_timer!(|| format!("freestanding msm k={}", k)); - let e0 = best_multiexp(scalars, points); - end_timer!(t0); - - let engine = PlonkEngineConfig::new() - .set_curve::() - .set_msm(H2cEngine::new()) - .build(); - let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); - let e1 = engine.msm_backend.msm(scalars, points); - end_timer!(t1); - - assert_eq!(e0, e1); - - // Caching API - // ----------- - let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); - let base_descriptor = engine.msm_backend.get_base_descriptor(points); - let e2 = engine - .msm_backend - .msm_with_cached_base(scalars, &base_descriptor); - end_timer!(t2); - - assert_eq!(e0, e2) + run_msm_zal_default(points, scalars, k); + run_msm_zal_custom(points, scalars, k); } } #[test] fn test_msm_zal() { - run_msm_zal_default::(3, 14); - run_msm_zal_custom::(3, 14); + let (points, scalars) = gen_points_scalars::(4); + run_msm_zal_default(&points, &scalars, 4); + run_msm_zal_custom(&points, &scalars, 4); } } From 80e69b969966a6b912d645a52fbd46a6a02b2e0a Mon Sep 17 00:00:00 2001 From: David Nevado Date: Thu, 30 May 2024 10:11:37 +0200 Subject: [PATCH 2/7] Quick fix: random generator in ZAL MSM test (#344) fix: random point and scalar generation Also increased the size of the test MSM from k=4 to k=12. --- halo2_middleware/src/zal.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/halo2_middleware/src/zal.rs b/halo2_middleware/src/zal.rs index e2610b6b96..5d376e3e5b 100644 --- a/halo2_middleware/src/zal.rs +++ b/halo2_middleware/src/zal.rs @@ -261,16 +261,17 @@ mod test { use rand_xorshift::XorShiftRng; fn gen_points_scalars(k: usize) -> (Vec, Vec) { - let rand = || XorShiftRng::from_seed([1; 16]); + let mut rng = XorShiftRng::seed_from_u64(3141592u64); + let points = (0..1 << k) - .map(|_| C::Curve::random(rand())) + .map(|_| C::Curve::random(&mut rng)) .collect::>(); let mut affine_points = vec![C::identity(); 1 << k]; C::Curve::batch_normalize(&points[..], &mut affine_points[..]); let points = affine_points; let scalars = (0..1 << k) - .map(|_| C::Scalar::random(rand())) + .map(|_| C::Scalar::random(&mut rng)) .collect::>(); (points, scalars) @@ -364,8 +365,9 @@ mod test { #[test] fn test_msm_zal() { - let (points, scalars) = gen_points_scalars::(4); - run_msm_zal_default(&points, &scalars, 4); - run_msm_zal_custom(&points, &scalars, 4); + const MSM_SIZE: usize = 12; + let (points, scalars) = gen_points_scalars::(MSM_SIZE); + run_msm_zal_default(&points, &scalars, MSM_SIZE); + run_msm_zal_custom(&points, &scalars, MSM_SIZE); } } From b4d1c4c9c2cd4fdeab2b66b6872cdabbf700fcc9 Mon Sep 17 00:00:00 2001 From: guorong009 Date: Mon, 3 Jun 2024 21:07:10 +0800 Subject: [PATCH 3/7] improve: apply tachyon optimizations(1) (#342) * feat: remove "permutation_product_coset" from "Committed" * fix: update "h_commitments" part in "vanishing/prover.rs" * feat: remove the "Constructed" from permutation proving --- halo2_backend/src/plonk/evaluation.rs | 40 +++++++++++-------- halo2_backend/src/plonk/permutation/prover.rs | 37 ++--------------- halo2_backend/src/plonk/prover.rs | 4 +- halo2_backend/src/plonk/vanishing/prover.rs | 22 +++++----- 4 files changed, 41 insertions(+), 62 deletions(-) diff --git a/halo2_backend/src/plonk/evaluation.rs b/halo2_backend/src/plonk/evaluation.rs index ef5c05f1c4..09d8b452d3 100644 --- a/halo2_backend/src/plonk/evaluation.rs +++ b/halo2_backend/src/plonk/evaluation.rs @@ -408,8 +408,16 @@ impl Evaluator { let chunk_len = pk.vk.cs.degree() - 2; let delta_start = beta * C::Scalar::ZETA; - let first_set = sets.first().unwrap(); - let last_set = sets.last().unwrap(); + let permutation_product_cosets: Vec< + Polynomial, + > = sets + .iter() + .map(|set| domain.coeff_to_extended(set.permutation_product_poly.clone())) + .collect(); + + let first_set_permutation_product_coset = + permutation_product_cosets.first().unwrap(); + let last_set_permutation_product_coset = permutation_product_cosets.last().unwrap(); // Permutation constraints parallelize(&mut values, |values, start| { @@ -422,22 +430,21 @@ impl Evaluator { // Enforce only for the first set. // l_0(X) * (1 - z_0(X)) = 0 *value = *value * y - + ((one - first_set.permutation_product_coset[idx]) * l0[idx]); + + ((one - first_set_permutation_product_coset[idx]) * l0[idx]); // Enforce only for the last set. // l_last(X) * (z_l(X)^2 - z_l(X)) = 0 *value = *value * y - + ((last_set.permutation_product_coset[idx] - * last_set.permutation_product_coset[idx] - - last_set.permutation_product_coset[idx]) + + ((last_set_permutation_product_coset[idx] + * last_set_permutation_product_coset[idx] + - last_set_permutation_product_coset[idx]) * l_last[idx]); // Except for the first set, enforce. // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0 - for (set_idx, set) in sets.iter().enumerate() { + for set_idx in 0..sets.len() { if set_idx != 0 { *value = *value * y - + ((set.permutation_product_coset[idx] - - permutation.sets[set_idx - 1].permutation_product_coset - [r_last]) + + ((permutation_product_cosets[set_idx][idx] + - permutation_product_cosets[set_idx - 1][r_last]) * l0[idx]); } } @@ -447,12 +454,13 @@ impl Evaluator { // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) // ) let mut current_delta = delta_start * beta_term; - for ((set, columns), cosets) in sets - .iter() - .zip(p.columns.chunks(chunk_len)) - .zip(pk.permutation.cosets.chunks(chunk_len)) + for ((permutation_product_coset, columns), cosets) in + permutation_product_cosets + .iter() + .zip(p.columns.chunks(chunk_len)) + .zip(pk.permutation.cosets.chunks(chunk_len)) { - let mut left = set.permutation_product_coset[r_next]; + let mut left = permutation_product_coset[r_next]; for (values, permutation) in columns .iter() .map(|&column| match column.column_type { @@ -465,7 +473,7 @@ impl Evaluator { left *= values[idx] + beta * permutation[idx] + gamma; } - let mut right = set.permutation_product_coset[idx]; + let mut right = permutation_product_coset[idx]; for values in columns.iter().map(|&column| match column.column_type { Any::Advice => &advice[column.index], Any::Fixed => &fixed[column.index], diff --git a/halo2_backend/src/plonk/permutation/prover.rs b/halo2_backend/src/plonk/permutation/prover.rs index c80ce2102d..585fa3ab47 100644 --- a/halo2_backend/src/plonk/permutation/prover.rs +++ b/halo2_backend/src/plonk/permutation/prover.rs @@ -13,7 +13,7 @@ use crate::{ plonk::{self, permutation::ProvingKey, ChallengeBeta, ChallengeGamma, ChallengeX, Error}, poly::{ commitment::{Blind, Params}, - Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, ProverQuery, + Coeff, LagrangeCoeff, Polynomial, ProverQuery, }, transcript::{EncodedChallenge, TranscriptWrite}, }; @@ -25,7 +25,6 @@ use halo2_middleware::poly::Rotation; pub(crate) struct CommittedSet { pub(crate) permutation_product_poly: Polynomial, - pub(crate) permutation_product_coset: Polynomial, permutation_product_blind: Blind, } @@ -33,17 +32,8 @@ pub(crate) struct Committed { pub(crate) sets: Vec>, } -pub(crate) struct ConstructedSet { - permutation_product_poly: Polynomial, - permutation_product_blind: Blind, -} - -pub(crate) struct Constructed { - sets: Vec>, -} - pub(crate) struct Evaluated { - constructed: Constructed, + constructed: Committed, } #[allow(clippy::too_many_arguments)] @@ -177,17 +167,13 @@ pub(in crate::plonk) fn permutation_commit< .commit_lagrange(&engine.msm_backend, &z, blind) .to_affine(); let permutation_product_blind = blind; - let z = domain.lagrange_to_coeff(z); - let permutation_product_poly = z.clone(); - - let permutation_product_coset = domain.coeff_to_extended(z); + let permutation_product_poly = domain.lagrange_to_coeff(z); // Hash the permutation product commitment transcript.write_point(permutation_product_commitment)?; sets.push(CommittedSet { permutation_product_poly, - permutation_product_coset, permutation_product_blind, }); } @@ -195,21 +181,6 @@ pub(in crate::plonk) fn permutation_commit< Ok(Committed { sets }) } -impl Committed { - pub(in crate::plonk) fn construct(self) -> Constructed { - Constructed { - sets: self - .sets - .iter() - .map(|set| ConstructedSet { - permutation_product_poly: set.permutation_product_poly.clone(), - permutation_product_blind: set.permutation_product_blind, - }) - .collect(), - } - } -} - impl super::ProvingKey { pub(in crate::plonk) fn open( &self, @@ -236,7 +207,7 @@ impl super::ProvingKey { } } -impl Constructed { +impl Committed { pub(in crate::plonk) fn evaluate, T: TranscriptWrite>( self, pk: &plonk::ProvingKey, diff --git a/halo2_backend/src/plonk/prover.rs b/halo2_backend/src/plonk/prover.rs index 2e4bc173d5..009298ed73 100644 --- a/halo2_backend/src/plonk/prover.rs +++ b/halo2_backend/src/plonk/prover.rs @@ -804,9 +804,7 @@ impl< let permutations_evaluated: Vec> = permutations_commited .into_iter() - .map(|permutation| -> Result<_, _> { - permutation.construct().evaluate(pk, x, self.transcript) - }) + .map(|permutation| -> Result<_, _> { permutation.evaluate(pk, x, self.transcript) }) .collect::, _>>()?; // Evaluate the lookups, if any, at omega^i x. diff --git a/halo2_backend/src/plonk/vanishing/prover.rs b/halo2_backend/src/plonk/vanishing/prover.rs index fb0e5b5e76..96ce797ee4 100644 --- a/halo2_backend/src/plonk/vanishing/prover.rs +++ b/halo2_backend/src/plonk/vanishing/prover.rs @@ -131,18 +131,20 @@ impl Committed { .collect(); // Compute commitments to each h(X) piece - let h_commitments_projective: Vec<_> = h_pieces - .iter() - .zip(h_blinds.iter()) - .map(|(h_piece, blind)| params.commit(&engine.msm_backend, h_piece, *blind)) - .collect(); - let mut h_commitments = vec![C::identity(); h_commitments_projective.len()]; - C::Curve::batch_normalize(&h_commitments_projective, &mut h_commitments); - let h_commitments = h_commitments; + let h_commitments = { + let h_commitments_projective: Vec<_> = h_pieces + .iter() + .zip(h_blinds.iter()) + .map(|(h_piece, blind)| params.commit(&engine.msm_backend, h_piece, *blind)) + .collect(); + let mut h_commitments = vec![C::identity(); h_commitments_projective.len()]; + C::Curve::batch_normalize(&h_commitments_projective, &mut h_commitments); + h_commitments + }; // Hash each h(X) piece - for c in h_commitments.iter() { - transcript.write_point(*c)?; + for c in h_commitments { + transcript.write_point(c)?; } Ok(Constructed { From d52ebca29cb47582f4622e4fe6f736cbef2978fa Mon Sep 17 00:00:00 2001 From: guorong009 Date: Thu, 6 Jun 2024 17:56:57 +0800 Subject: [PATCH 4/7] patch: include `shuffles` in transcript, even when empty (#348) * feat: include "shuffles" in "PinnedConstraintSystem::fmt" * chore: fix the "plonk_api" test --- halo2_backend/src/plonk/circuit.rs | 8 +++----- halo2_proofs/tests/plonk_api.rs | 1 + 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/halo2_backend/src/plonk/circuit.rs b/halo2_backend/src/plonk/circuit.rs index c36bd7b9b4..d9a53edb4e 100644 --- a/halo2_backend/src/plonk/circuit.rs +++ b/halo2_backend/src/plonk/circuit.rs @@ -275,11 +275,9 @@ impl<'a, F: Field> std::fmt::Debug for PinnedConstraintSystem<'a, F> { .field("instance_queries", self.instance_queries) .field("fixed_queries", self.fixed_queries) .field("permutation", self.permutation) - .field("lookups", self.lookups); - if !self.shuffles.is_empty() { - debug_struct.field("shuffles", self.shuffles); - } - debug_struct.field("minimum_degree", self.minimum_degree); + .field("lookups", self.lookups) + .field("shuffles", self.shuffles) + .field("minimum_degree", self.minimum_degree); debug_struct.finish() } } diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs index e44484e514..c7512285fe 100644 --- a/halo2_proofs/tests/plonk_api.rs +++ b/halo2_proofs/tests/plonk_api.rs @@ -1112,6 +1112,7 @@ fn plonk_api() { ], }, ], + shuffles: [], minimum_degree: None, }, fixed_commitments: [ From e86b7f75d2a08cb82ff2147ae5c89296021b5747 Mon Sep 17 00:00:00 2001 From: guorong009 Date: Fri, 7 Jun 2024 15:43:05 +0800 Subject: [PATCH 5/7] patch: include more fields to transcript (#351) * patch: add multi-phase related challenge fields within transcript * chore: fix the "plonk_api" test * patch: add the Debug attribute --- halo2_backend/src/plonk/circuit.rs | 29 ++--------------------------- halo2_proofs/tests/plonk_api.rs | 9 +++++++++ 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/halo2_backend/src/plonk/circuit.rs b/halo2_backend/src/plonk/circuit.rs index d9a53edb4e..c57b91b76e 100644 --- a/halo2_backend/src/plonk/circuit.rs +++ b/halo2_backend/src/plonk/circuit.rs @@ -238,6 +238,8 @@ impl<'a, F: Field> std::fmt::Debug for PinnedGates<'a, F> { } /// Represents the minimal parameters that determine a `ConstraintSystem`. +#[allow(dead_code)] +#[derive(Debug)] pub(crate) struct PinnedConstraintSystem<'a, F: Field> { num_fixed_columns: &'a usize, num_advice_columns: &'a usize, @@ -255,33 +257,6 @@ pub(crate) struct PinnedConstraintSystem<'a, F: Field> { minimum_degree: &'a Option, } -impl<'a, F: Field> std::fmt::Debug for PinnedConstraintSystem<'a, F> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut debug_struct = f.debug_struct("PinnedConstraintSystem"); - debug_struct - .field("num_fixed_columns", self.num_fixed_columns) - .field("num_advice_columns", self.num_advice_columns) - .field("num_instance_columns", self.num_instance_columns); - // Only show multi-phase related fields if it's used. - if *self.num_challenges > 0 { - debug_struct - .field("num_challenges", self.num_challenges) - .field("advice_column_phase", self.advice_column_phase) - .field("challenge_phase", self.challenge_phase); - } - debug_struct - .field("gates", &self.gates) - .field("advice_queries", self.advice_queries) - .field("instance_queries", self.instance_queries) - .field("fixed_queries", self.fixed_queries) - .field("permutation", self.permutation) - .field("lookups", self.lookups) - .field("shuffles", self.shuffles) - .field("minimum_degree", self.minimum_degree); - debug_struct.finish() - } -} - // Cost functions: arguments required degree /// Returns the minimum circuit degree required by the permutation argument. diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs index c7512285fe..9436aff387 100644 --- a/halo2_proofs/tests/plonk_api.rs +++ b/halo2_proofs/tests/plonk_api.rs @@ -674,6 +674,15 @@ fn plonk_api() { num_fixed_columns: 7, num_advice_columns: 5, num_instance_columns: 1, + num_challenges: 0, + advice_column_phase: [ + 0, + 0, + 0, + 0, + 0, + ], + challenge_phase: [], gates: [ Sum( Sum( From edffc1e538ef7e1040394a304c9edc9469937bc7 Mon Sep 17 00:00:00 2001 From: Eduard S Date: Mon, 10 Jun 2024 08:33:15 +0200 Subject: [PATCH 6/7] feat: add halo2_debug package (#346) --- Cargo.toml | 1 + halo2_backend/src/plonk/circuit.rs | 10 +- halo2_debug/Cargo.toml | 27 ++ halo2_debug/src/display.rs | 351 ++++++++++++++++++ halo2_debug/src/lib.rs | 1 + .../src/plonk/circuit/constraint_system.rs | 10 + halo2_middleware/src/circuit.rs | 59 ++- halo2_middleware/src/expression.rs | 2 +- halo2_proofs/Cargo.toml | 1 + halo2_proofs/tests/compress_selectors.rs | 64 ++++ 10 files changed, 510 insertions(+), 16 deletions(-) create mode 100644 halo2_debug/Cargo.toml create mode 100644 halo2_debug/src/display.rs create mode 100644 halo2_debug/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 233bf95ad0..1e75bcb87e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "halo2_frontend", "halo2_middleware", "halo2_backend", + "halo2_debug", "p3_frontend", ] resolver = "2" diff --git a/halo2_backend/src/plonk/circuit.rs b/halo2_backend/src/plonk/circuit.rs index c57b91b76e..a754bf4bf8 100644 --- a/halo2_backend/src/plonk/circuit.rs +++ b/halo2_backend/src/plonk/circuit.rs @@ -25,6 +25,12 @@ pub enum VarBack { Challenge(ChallengeMid), } +impl std::fmt::Display for VarBack { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + impl Variable for VarBack { fn degree(&self) -> usize { match self { @@ -40,8 +46,8 @@ impl Variable for VarBack { } } - fn write_identifier(&self, _writer: &mut W) -> std::io::Result<()> { - unimplemented!("unused method") + fn write_identifier(&self, writer: &mut W) -> std::io::Result<()> { + write!(writer, "{}", self) } } diff --git a/halo2_debug/Cargo.toml b/halo2_debug/Cargo.toml new file mode 100644 index 0000000000..d86ea3755b --- /dev/null +++ b/halo2_debug/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "halo2_debug" +version = "0.3.0" +authors = [ + "Privacy Scaling Explorations team", +] +edition = "2021" +rust-version = "1.66.0" +description = """ +Halo2 Debug. This package contains utilities for debugging and testing within +the halo2 ecosystem. +""" +license = "MIT OR Apache-2.0" +repository = "https://github.com/privacy-scaling-explorations/halo2" +documentation = "https://privacy-scaling-explorations.github.io/halo2/" +categories = ["cryptography"] +keywords = ["halo", "proofs", "zkp", "zkSNARKs"] + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] + +[dependencies] +ff = "0.13" +halo2curves = { version = "0.6.1", default-features = false } +num-bigint = "0.4.5" +halo2_middleware = { path = "../halo2_middleware" } diff --git a/halo2_debug/src/display.rs b/halo2_debug/src/display.rs new file mode 100644 index 0000000000..f4feba48f6 --- /dev/null +++ b/halo2_debug/src/display.rs @@ -0,0 +1,351 @@ +use ff::PrimeField; +use halo2_middleware::circuit::{ColumnMid, VarMid}; +use halo2_middleware::expression::{Expression, Variable}; +use halo2_middleware::{lookup, shuffle}; +use num_bigint::BigUint; +use std::collections::HashMap; +use std::fmt; + +/// Wrapper type over `PrimeField` that implements Display with nice output. +/// - If the value is a power of two, format it as `2^k` +/// - If the value is smaller than 2^16, format it in decimal +/// - If the value is bigger than congruent -2^16, format it in decimal as the negative congruent +/// (between -2^16 and 0). +/// - Else format it in hex without leading zeros. +pub struct FDisp<'a, F: PrimeField>(pub &'a F); + +impl fmt::Display for FDisp<'_, F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let v = (*self.0).to_repr(); + let v = v.as_ref(); + let v = BigUint::from_bytes_le(v); + let v_bits = v.bits(); + if v_bits >= 8 && v.count_ones() == 1 { + write!(f, "2^{}", v.trailing_zeros().unwrap_or_default()) + } else if v_bits < 16 { + write!(f, "{}", v) + } else { + let v_neg = (F::ZERO - self.0).to_repr(); + let v_neg = v_neg.as_ref(); + let v_neg = BigUint::from_bytes_le(v_neg); + let v_neg_bits = v_neg.bits(); + if v_neg_bits < 16 { + write!(f, "-{}", v_neg) + } else { + write!(f, "0x{:x}", v) + } + } + } +} + +/// Wrapper type over `Expression` that implements Display with nice output. +/// The formatting of the `Expression::Variable` case is parametrized with the second field, which +/// take as auxiliary value the third field. +/// Use the constructor `expr_disp` to format variables using their `Display` implementation. +/// Use the constructor `expr_disp_names` for an `Expression` with `V: VarMid` to format column +/// queries according to their string names. +pub struct ExprDisp<'a, F: PrimeField, V: Variable, A>( + /// Expression to display + pub &'a Expression, + /// `V: Variable` formatter method that uses auxiliary value + pub fn(&V, &mut fmt::Formatter<'_>, a: &A) -> fmt::Result, + /// Auxiliary value to be passed to the `V: Variable` formatter + pub &'a A, +); + +fn var_fmt_default(v: &V, f: &mut fmt::Formatter<'_>, _: &()) -> fmt::Result { + write!(f, "{}", v) +} + +fn var_fmt_names( + v: &VarMid, + f: &mut fmt::Formatter<'_>, + names: &HashMap, +) -> fmt::Result { + if let VarMid::Query(q) = v { + if let Some(name) = names.get(&ColumnMid::new(q.column_type, q.column_index)) { + return write!(f, "{}", name); + } + } + write!(f, "{}", v) +} + +/// ExprDisp constructor that formats viariables using their `Display` implementation. +pub fn expr_disp(e: &Expression) -> ExprDisp { + ExprDisp(e, var_fmt_default, &()) +} + +/// ExprDisp constructor for an `Expression` with `V: VarMid` that formats column queries according +/// to their string names. +pub fn expr_disp_names<'a, F: PrimeField>( + e: &'a Expression, + names: &'a HashMap, +) -> ExprDisp<'a, F, VarMid, HashMap> { + ExprDisp(e, var_fmt_names, names) +} + +impl fmt::Display for ExprDisp<'_, F, V, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let is_sum = |e: &Expression| -> bool { matches!(e, Expression::Sum(_, _)) }; + let fmt_expr = + |e: &Expression, f: &mut fmt::Formatter<'_>, parens: bool| -> fmt::Result { + if parens { + write!(f, "(")?; + } + write!(f, "{}", ExprDisp(e, self.1, self.2))?; + if parens { + write!(f, ")")?; + } + Ok(()) + }; + + match self.0 { + Expression::Constant(c) => write!(f, "{}", FDisp(c)), + Expression::Var(v) => self.1(v, f, self.2), + Expression::Negated(a) => { + write!(f, "-")?; + fmt_expr(a, f, is_sum(a)) + } + Expression::Sum(a, b) => { + fmt_expr(a, f, false)?; + if let Expression::Negated(neg) = &**b { + write!(f, " - ")?; + fmt_expr(neg, f, is_sum(neg)) + } else { + write!(f, " + ")?; + fmt_expr(b, f, false) + } + } + Expression::Product(a, b) => { + fmt_expr(a, f, is_sum(a))?; + write!(f, " * ")?; + fmt_expr(b, f, is_sum(b)) + } + } + } +} + +/// Wrapper type over `lookup::Argument` that implements Display with nice output. +/// The formatting of the `Expression::Variable` case is parametrized with the second field, which +/// take as auxiliary value the third field. +/// Use the constructor `lookup_arg_disp` to format variables using their `Display` implementation. +/// Use the constructor `lookup_arg_disp_names` for a lookup of `Expression` with `V: VarMid` that +/// formats column queries according to their string names. +pub struct LookupArgDisp<'a, F: PrimeField, V: Variable, A>( + /// Lookup argument to display + pub &'a lookup::Argument, + /// `V: Variable` formatter method that uses auxiliary value + pub fn(&V, &mut fmt::Formatter<'_>, a: &A) -> fmt::Result, + /// Auxiliary value to be passed to the `V: Variable` formatter + pub &'a A, +); + +/// LookupArgDisp constructor that formats viariables using their `Display` implementation. +pub fn lookup_arg_disp( + a: &lookup::Argument, +) -> LookupArgDisp { + LookupArgDisp(a, var_fmt_default, &()) +} + +/// LookupArgDisp constructor for a lookup of `Expression` with `V: VarMid` that formats column +/// queries according to their string names. +pub fn lookup_arg_disp_names<'a, F: PrimeField>( + a: &'a lookup::Argument, + names: &'a HashMap, +) -> LookupArgDisp<'a, F, VarMid, HashMap> { + LookupArgDisp(a, var_fmt_names, names) +} + +impl fmt::Display for LookupArgDisp<'_, F, V, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for (i, expr) in self.0.input_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "] in [")?; + for (i, expr) in self.0.table_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "]")?; + Ok(()) + } +} + +/// Wrapper type over `shuffle::Argument` that implements Display with nice output. +/// The formatting of the `Expression::Variable` case is parametrized with the second field, which +/// take as auxiliary value the third field. +/// Use the constructor `shuffle_arg_disp` to format variables using their `Display` +/// implementation. +/// Use the constructor `shuffle_arg_disp_names` for a shuffle of `Expression` with `V: VarMid` +/// that formats column queries according to their string names. +pub struct ShuffleArgDisp<'a, F: PrimeField, V: Variable, A>( + /// Shuffle argument to display + pub &'a shuffle::Argument, + /// `V: Variable` formatter method that uses auxiliary value + pub fn(&V, &mut fmt::Formatter<'_>, a: &A) -> fmt::Result, + /// Auxiliary value to be passed to the `V: Variable` formatter + pub &'a A, +); + +/// ShuffleArgDisp constructor that formats viariables using their `Display` implementation. +pub fn shuffle_arg_disp( + a: &shuffle::Argument, +) -> ShuffleArgDisp { + ShuffleArgDisp(a, var_fmt_default, &()) +} + +/// ShuffleArgDisp constructor for a shuffle of `Expression` with `V: VarMid` that formats column +/// queries according to their string names. +pub fn shuffle_arg_disp_names<'a, F: PrimeField>( + a: &'a shuffle::Argument, + names: &'a HashMap, +) -> ShuffleArgDisp<'a, F, VarMid, HashMap> { + ShuffleArgDisp(a, var_fmt_names, names) +} + +impl fmt::Display for ShuffleArgDisp<'_, F, V, A> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for (i, expr) in self.0.input_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "] shuff [")?; + for (i, expr) in self.0.shuffle_expressions.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", ExprDisp(expr, self.1, self.2))?; + } + write!(f, "]")?; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use ff::Field; + use halo2_middleware::circuit::{Any, QueryMid, VarMid}; + use halo2_middleware::poly::Rotation; + use halo2curves::bn256::Fr; + + #[test] + fn test_lookup_shuffle_arg_disp() { + type E = Expression; + let a0 = VarMid::Query(QueryMid::new(Any::Advice, 0, Rotation(0))); + let a1 = VarMid::Query(QueryMid::new(Any::Advice, 1, Rotation(0))); + let f0 = VarMid::Query(QueryMid::new(Any::Fixed, 0, Rotation(0))); + let a0: E = Expression::Var(a0); + let a1: E = Expression::Var(a1); + let f0: E = Expression::Var(f0); + + let names = [ + (ColumnMid::new(Any::Advice, 0), "a".to_string()), + (ColumnMid::new(Any::Advice, 1), "b".to_string()), + (ColumnMid::new(Any::Fixed, 0), "s".to_string()), + ] + .into_iter() + .collect(); + + let arg = lookup::Argument { + name: "lookup".to_string(), + input_expressions: vec![f0.clone() * a0.clone(), f0.clone() * a1.clone()], + table_expressions: vec![f0.clone(), f0.clone() * (a0.clone() + a1.clone())], + }; + assert_eq!( + "[f0 * a0, f0 * a1] in [f0, f0 * (a0 + a1)]", + format!("{}", lookup_arg_disp(&arg)) + ); + assert_eq!( + "[s * a, s * b] in [s, s * (a + b)]", + format!("{}", lookup_arg_disp_names(&arg, &names)) + ); + + let arg = shuffle::Argument { + name: "shuffle".to_string(), + input_expressions: vec![f0.clone() * a0.clone(), f0.clone() * a1.clone()], + shuffle_expressions: vec![f0.clone(), f0.clone() * (a0.clone() + a1.clone())], + }; + assert_eq!( + "[f0 * a0, f0 * a1] shuff [f0, f0 * (a0 + a1)]", + format!("{}", shuffle_arg_disp(&arg)) + ); + assert_eq!( + "[s * a, s * b] shuff [s, s * (a + b)]", + format!("{}", shuffle_arg_disp_names(&arg, &names)) + ); + } + + #[test] + fn test_expr_disp() { + type E = Expression; + let a0 = VarMid::Query(QueryMid::new(Any::Advice, 0, Rotation(0))); + let a1 = VarMid::Query(QueryMid::new(Any::Advice, 1, Rotation(0))); + let a0: E = Expression::Var(a0); + let a1: E = Expression::Var(a1); + + let e = a0.clone() + a1.clone(); + assert_eq!("a0 + a1", format!("{}", expr_disp(&e))); + let e = a0.clone() + a1.clone() + a0.clone(); + assert_eq!("a0 + a1 + a0", format!("{}", expr_disp(&e))); + + let e = a0.clone() * a1.clone(); + assert_eq!("a0 * a1", format!("{}", expr_disp(&e))); + let e = a0.clone() * a1.clone() * a0.clone(); + assert_eq!("a0 * a1 * a0", format!("{}", expr_disp(&e))); + + let e = a0.clone() - a1.clone(); + assert_eq!("a0 - a1", format!("{}", expr_disp(&e))); + let e = (a0.clone() - a1.clone()) - a0.clone(); + assert_eq!("a0 - a1 - a0", format!("{}", expr_disp(&e))); + let e = a0.clone() - (a1.clone() - a0.clone()); + assert_eq!("a0 - (a1 - a0)", format!("{}", expr_disp(&e))); + + let e = a0.clone() * a1.clone() + a0.clone(); + assert_eq!("a0 * a1 + a0", format!("{}", expr_disp(&e))); + let e = a0.clone() * (a1.clone() + a0.clone()); + assert_eq!("a0 * (a1 + a0)", format!("{}", expr_disp(&e))); + + let e = a0.clone() + a1.clone(); + let names = [ + (ColumnMid::new(Any::Advice, 0), "a".to_string()), + (ColumnMid::new(Any::Advice, 1), "b".to_string()), + ] + .into_iter() + .collect(); + assert_eq!("a + b", format!("{}", expr_disp_names(&e, &names))); + } + + #[test] + fn test_f_disp() { + let v = Fr::ZERO; + assert_eq!("0", format!("{}", FDisp(&v))); + + let v = Fr::ONE; + assert_eq!("1", format!("{}", FDisp(&v))); + + let v = Fr::from(12345u64); + assert_eq!("12345", format!("{}", FDisp(&v))); + + let v = Fr::from(0x10000); + assert_eq!("2^16", format!("{}", FDisp(&v))); + + let v = Fr::from(0x12345); + assert_eq!("0x12345", format!("{}", FDisp(&v))); + + let v = -Fr::ONE; + assert_eq!("-1", format!("{}", FDisp(&v))); + + let v = -Fr::from(12345u64); + assert_eq!("-12345", format!("{}", FDisp(&v))); + } +} diff --git a/halo2_debug/src/lib.rs b/halo2_debug/src/lib.rs new file mode 100644 index 0000000000..8754563b71 --- /dev/null +++ b/halo2_debug/src/lib.rs @@ -0,0 +1 @@ +pub mod display; diff --git a/halo2_frontend/src/plonk/circuit/constraint_system.rs b/halo2_frontend/src/plonk/circuit/constraint_system.rs index 7a393a6ac0..cabe718042 100644 --- a/halo2_frontend/src/plonk/circuit/constraint_system.rs +++ b/halo2_frontend/src/plonk/circuit/constraint_system.rs @@ -790,6 +790,16 @@ impl ConstraintSystem { /// Annotate an Instance column. pub fn annotate_lookup_any_column(&mut self, column: T, annotation: A) + where + A: Fn() -> AR, + AR: Into, + T: Into>, + { + self.annotate_column(column, annotation) + } + + /// Annotate a column. + pub fn annotate_column(&mut self, column: T, annotation: A) where A: Fn() -> AR, AR: Into, diff --git a/halo2_middleware/src/circuit.rs b/halo2_middleware/src/circuit.rs index 98c2891b74..4c8d4ee0d9 100644 --- a/halo2_middleware/src/circuit.rs +++ b/halo2_middleware/src/circuit.rs @@ -34,6 +34,16 @@ pub struct QueryMid { pub rotation: Rotation, } +impl QueryMid { + pub fn new(column_type: Any, column_index: usize, rotation: Rotation) -> Self { + Self { + column_index, + column_type, + rotation, + } + } +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum VarMid { /// This is a generic column query @@ -42,6 +52,28 @@ pub enum VarMid { Challenge(ChallengeMid), } +impl fmt::Display for VarMid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VarMid::Query(query) => { + match query.column_type { + Any::Fixed => write!(f, "f")?, + Any::Advice => write!(f, "a")?, + Any::Instance => write!(f, "i")?, + }; + write!(f, "{}", query.column_index)?; + if query.rotation.0 != 0 { + write!(f, "[{}]", query.rotation.0)?; + } + Ok(()) + } + VarMid::Challenge(challenge) => { + write!(f, "ch{}", challenge.index()) + } + } + } +} + impl Variable for VarMid { fn degree(&self) -> usize { match self { @@ -58,19 +90,7 @@ impl Variable for VarMid { } fn write_identifier(&self, writer: &mut W) -> std::io::Result<()> { - match self { - VarMid::Query(query) => { - match query.column_type { - Any::Fixed => write!(writer, "fixed")?, - Any::Advice => write!(writer, "advice")?, - Any::Instance => write!(writer, "instance")?, - }; - write!(writer, "[{}][{}]", query.column_index, query.rotation.0) - } - VarMid::Challenge(challenge) => { - write!(writer, "challenge[{}]", challenge.index()) - } - } + write!(writer, "{}", self) } } @@ -136,6 +156,19 @@ pub struct ConstraintSystemMid { pub minimum_degree: Option, } +impl ConstraintSystemMid { + /// Returns the number of phases + pub fn phases(&self) -> usize { + let max_phase = self + .advice_column_phase + .iter() + .copied() + .max() + .unwrap_or_default(); + max_phase as usize + 1 + } +} + /// Data that needs to be preprocessed from a circuit #[derive(Debug, Clone)] pub struct Preprocessing { diff --git a/halo2_middleware/src/expression.rs b/halo2_middleware/src/expression.rs index 217d26c321..1cbb557ed4 100644 --- a/halo2_middleware/src/expression.rs +++ b/halo2_middleware/src/expression.rs @@ -3,7 +3,7 @@ use core::ops::{Add, Mul, Neg, Sub}; use ff::Field; use std::iter::{Product, Sum}; -pub trait Variable: Clone + Copy + std::fmt::Debug + Eq + PartialEq { +pub trait Variable: Clone + Copy + std::fmt::Debug + std::fmt::Display + Eq + PartialEq { /// Degree that an expression would have if it was only this variable. fn degree(&self) -> usize; diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index a68f422934..4302dd7276 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -61,6 +61,7 @@ gumdrop = "0.8" proptest = "1" dhat = "0.3.2" serde_json = "1" +halo2_debug = { path = "../halo2_debug" } [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/halo2_proofs/tests/compress_selectors.rs b/halo2_proofs/tests/compress_selectors.rs index ec87354fc2..b34a099151 100644 --- a/halo2_proofs/tests/compress_selectors.rs +++ b/halo2_proofs/tests/compress_selectors.rs @@ -3,6 +3,8 @@ use std::marker::PhantomData; use ff::PrimeField; +use halo2_debug::display::expr_disp_names; +use halo2_frontend::circuit::compile_circuit; use halo2_frontend::plonk::Error; use halo2_proofs::circuit::{Cell, Layouter, SimpleFloorPlanner, Value}; use halo2_proofs::poly::Rotation; @@ -10,6 +12,7 @@ use halo2_proofs::poly::Rotation; use halo2_backend::transcript::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }; +use halo2_middleware::circuit::{Any, ColumnMid}; use halo2_middleware::zal::impls::{H2cEngine, PlonkEngineConfig}; use halo2_proofs::arithmetic::Field; use halo2_proofs::plonk::{ @@ -101,12 +104,16 @@ impl MyCircuitChip { let l = meta.advice_column(); let r = meta.advice_column(); let o = meta.advice_column(); + meta.annotate_column(l, || "l"); + meta.annotate_column(r, || "r"); + meta.annotate_column(o, || "o"); let s_add = meta.selector(); let s_mul = meta.selector(); let s_cubed = meta.selector(); let PI = meta.instance_column(); + meta.annotate_column(PI, || "pi"); meta.enable_equality(l); meta.enable_equality(r); @@ -435,6 +442,63 @@ How the `compress_selectors` works in `MyCircuit` under the hood: */ +#[test] +fn test_compress_gates() { + let k = 4; + let circuit: MyCircuit = MyCircuit { + x: Value::known(Fr::one()), + y: Value::known(Fr::one()), + constant: Fr::one(), + }; + + // Without compression + + let (mut compress_false, _, _) = compile_circuit(k, &circuit, false).unwrap(); + + let names = &mut compress_false.cs.general_column_annotations; + names.insert(ColumnMid::new(Any::Fixed, 0), "s_add".to_string()); + names.insert(ColumnMid::new(Any::Fixed, 1), "s_mul".to_string()); + names.insert(ColumnMid::new(Any::Fixed, 2), "s_cubed".to_string()); + let cs = &compress_false.cs; + let names = &cs.general_column_annotations; + assert_eq!(3, cs.gates.len()); + assert_eq!( + "s_add * (l + r - o)", + format!("{}", expr_disp_names(&cs.gates[0].poly, names)) + ); + assert_eq!( + "s_mul * (l * r - o)", + format!("{}", expr_disp_names(&cs.gates[1].poly, names)) + ); + assert_eq!( + "s_cubed * (l * l * l - o)", + format!("{}", expr_disp_names(&cs.gates[2].poly, names)) + ); + + // With compression + + let (mut compress_true, _, _) = compile_circuit(k, &circuit, true).unwrap(); + + let names = &mut compress_true.cs.general_column_annotations; + names.insert(ColumnMid::new(Any::Fixed, 0), "s_add_mul".to_string()); + names.insert(ColumnMid::new(Any::Fixed, 1), "s_cubed".to_string()); + let cs = &compress_true.cs; + let names = &cs.general_column_annotations; + assert_eq!(3, cs.gates.len()); + assert_eq!( + "s_add_mul * (2 - s_add_mul) * (l + r - o)", + format!("{}", expr_disp_names(&cs.gates[0].poly, names)) + ); + assert_eq!( + "s_add_mul * (1 - s_add_mul) * (l * r - o)", + format!("{}", expr_disp_names(&cs.gates[1].poly, names)) + ); + assert_eq!( + "s_cubed * (l * l * l - o)", + format!("{}", expr_disp_names(&cs.gates[2].poly, names)) + ); +} + #[test] fn test_success() { // vk & pk keygen both WITH compress From 32599e898a62b21647c55e93f689e454b093c6fd Mon Sep 17 00:00:00 2001 From: David Nevado Date: Tue, 11 Jun 2024 11:16:22 +0200 Subject: [PATCH 7/7] Use vectors instead of slices for PI (#353) feat!: Use Vectors insead of slices for PI Instances were being passed as a triple slice of field elements: &[&[&[F]]] in many functions. It has been replaced for `&[Vec>]`. --- halo2_backend/src/plonk/prover.rs | 18 +++++---------- halo2_backend/src/plonk/verifier.rs | 11 ++++++---- halo2_backend/src/plonk/verifier/batch.rs | 10 +-------- halo2_frontend/src/circuit.rs | 6 ++--- halo2_proofs/benches/plonk.rs | 4 ++-- halo2_proofs/src/plonk/prover.rs | 12 +++++----- halo2_proofs/tests/compress_selectors.rs | 10 +++------ halo2_proofs/tests/frontend_backend_split.rs | 23 ++++++-------------- halo2_proofs/tests/plonk_api.rs | 21 +++++++----------- halo2_proofs/tests/serialization.rs | 6 ++--- halo2_proofs/tests/shuffle.rs | 4 ++-- halo2_proofs/tests/shuffle_api.rs | 4 ++-- halo2_proofs/tests/vector-ops-unblinded.rs | 5 +++-- p3_frontend/tests/common/mod.rs | 5 ++--- 14 files changed, 55 insertions(+), 84 deletions(-) diff --git a/halo2_backend/src/plonk/prover.rs b/halo2_backend/src/plonk/prover.rs index 009298ed73..af72d863d8 100644 --- a/halo2_backend/src/plonk/prover.rs +++ b/halo2_backend/src/plonk/prover.rs @@ -68,9 +68,7 @@ impl< engine: PlonkEngine, params: &'params Scheme::ParamsProver, pk: &'a ProvingKey, - // TODO: If this was a vector the usage would be simpler - // https://github.com/privacy-scaling-explorations/halo2/issues/265 - instance: &[&[Scheme::Scalar]], + instance: Vec>, rng: R, transcript: &'a mut T, ) -> Result @@ -90,9 +88,7 @@ impl< pub fn new( params: &'params Scheme::ParamsProver, pk: &'a ProvingKey, - // TODO: If this was a vector the usage would be simpler - // https://github.com/privacy-scaling-explorations/halo2/issues/265 - instance: &[&[Scheme::Scalar]], + instance: Vec>, rng: R, transcript: &'a mut T, ) -> Result, Error> @@ -175,9 +171,7 @@ impl< engine: PlonkEngine, params: &'params Scheme::ParamsProver, pk: &'a ProvingKey, - // TODO: If this was a vector the usage would be simpler. - // https://github.com/privacy-scaling-explorations/halo2/issues/265 - circuits_instances: &[&[&[Scheme::Scalar]]], + circuits_instances: &[Vec>], rng: R, transcript: &'a mut T, ) -> Result @@ -201,7 +195,7 @@ impl< // commit_instance_fn is a helper function to return the polynomials (and its commitments) of // instance columns while updating the transcript. let mut commit_instance_fn = - |instance: &[&[Scheme::Scalar]]| -> Result, Error> { + |instance: &[Vec]| -> Result, Error> { // Create a lagrange polynomial for each instance column let instance_values = instance @@ -905,9 +899,7 @@ impl< pub fn new( params: &'params Scheme::ParamsProver, pk: &'a ProvingKey, - // TODO: If this was a vector the usage would be simpler. - // https://github.com/privacy-scaling-explorations/halo2/issues/265 - circuits_instances: &[&[&[Scheme::Scalar]]], + circuits_instances: &[Vec>], rng: R, transcript: &'a mut T, ) -> Result, Error> diff --git a/halo2_backend/src/plonk/verifier.rs b/halo2_backend/src/plonk/verifier.rs index d06224dcf5..53d9da4181 100644 --- a/halo2_backend/src/plonk/verifier.rs +++ b/halo2_backend/src/plonk/verifier.rs @@ -34,7 +34,7 @@ pub fn verify_proof_single<'params, Scheme, V, E, T, Strategy>( params: &'params Scheme::ParamsVerifier, vk: &VerifyingKey, strategy: Strategy, - instance: &[&[Scheme::Scalar]], + instance: Vec>, transcript: &mut T, ) -> Result where @@ -60,7 +60,7 @@ pub fn verify_proof< params: &'params Scheme::ParamsVerifier, vk: &VerifyingKey, strategy: Strategy, - instances: &[&[&[Scheme::Scalar]]], + instances: &[Vec>], transcript: &mut T, ) -> Result where @@ -301,9 +301,12 @@ where .instance_queries .iter() .map(|(column, rotation)| { - let instances = instances[column.index]; + let instances = &instances[column.index]; let offset = (max_rotation - rotation.0) as usize; - compute_inner_product(instances, &l_i_s[offset..offset + instances.len()]) + compute_inner_product( + instances.as_slice(), + &l_i_s[offset..offset + instances.len()], + ) }) .collect::>() }) diff --git a/halo2_backend/src/plonk/verifier/batch.rs b/halo2_backend/src/plonk/verifier/batch.rs index f33a5bf5a0..54b06450d9 100644 --- a/halo2_backend/src/plonk/verifier/batch.rs +++ b/halo2_backend/src/plonk/verifier/batch.rs @@ -99,7 +99,6 @@ where // `is_zero() == false` then this argument won't be able to interfere with it // to make it true, with high probability. acc.scale(C::Scalar::random(OsRng)); - acc.add_msm(&msm); acc } @@ -109,16 +108,9 @@ where .into_par_iter() .enumerate() .map(|(i, item)| { - let instances: Vec> = item - .instances - .iter() - .map(|i| i.iter().map(|c| &c[..]).collect()) - .collect(); - let instances: Vec<_> = instances.iter().map(|i| &i[..]).collect(); - let strategy = BatchStrategy::new(params); let mut transcript = Blake2bRead::init(&item.proof[..]); - verify_proof(params, vk, strategy, &instances, &mut transcript).map_err(|e| { + verify_proof(params, vk, strategy, &item.instances, &mut transcript).map_err(|e| { tracing::debug!("Batch item {} failed verification: {}", i, e); e }) diff --git a/halo2_frontend/src/circuit.rs b/halo2_frontend/src/circuit.rs index 5c6a75c731..4f67378545 100644 --- a/halo2_frontend/src/circuit.rs +++ b/halo2_frontend/src/circuit.rs @@ -117,7 +117,7 @@ struct WitnessCollection<'a, F: Field> { advice_column_phase: &'a Vec, advice: Vec>>, challenges: &'a HashMap, - instances: &'a [&'a [F]], + instances: &'a [Vec], usable_rows: RangeTo, } @@ -259,7 +259,7 @@ pub struct WitnessCalculator<'a, F: Field, ConcreteCircuit: Circuit> { circuit: &'a ConcreteCircuit, config: &'a ConcreteCircuit::Config, cs: &'a ConstraintSystem, - instances: &'a [&'a [F]], + instances: &'a [Vec], next_phase: u8, } @@ -270,7 +270,7 @@ impl<'a, F: Field, ConcreteCircuit: Circuit> WitnessCalculator<'a, F, Concret circuit: &'a ConcreteCircuit, config: &'a ConcreteCircuit::Config, cs: &'a ConstraintSystem, - instances: &'a [&'a [F]], + instances: &'a [Vec], ) -> Self { let n = 2usize.pow(k); let unusable_rows_start = n - (cs.blinding_factors() + 1); diff --git a/halo2_proofs/benches/plonk.rs b/halo2_proofs/benches/plonk.rs index ac531f1c53..9827fc5aa4 100644 --- a/halo2_proofs/benches/plonk.rs +++ b/halo2_proofs/benches/plonk.rs @@ -291,7 +291,7 @@ fn criterion_benchmark(c: &mut Criterion) { params, pk, &[circuit], - &[&[]], + &[vec![vec![]]], rng, &mut transcript, ) @@ -302,7 +302,7 @@ fn criterion_benchmark(c: &mut Criterion) { fn verifier(params: &ParamsIPA, vk: &VerifyingKey, proof: &[u8]) { let strategy = SingleStrategy::new(params); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); - assert!(verify_proof(params, vk, strategy, &[&[]], &mut transcript).is_ok()); + assert!(verify_proof(params, vk, strategy, &[vec![vec![]]], &mut transcript).is_ok()); } let k_range = 8..=16; diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index 133df4da40..21caf757fd 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -30,7 +30,7 @@ pub fn create_proof_with_engine< params: &'params Scheme::ParamsProver, pk: &ProvingKey, circuits: &[ConcreteCircuit], - instances: &[&[&[Scheme::Scalar]]], + instances: &[Vec>], rng: R, transcript: &mut T, ) -> Result<(), Error> @@ -51,7 +51,9 @@ where let mut witness_calcs: Vec<_> = circuits .iter() .enumerate() - .map(|(i, circuit)| WitnessCalculator::new(params.k(), circuit, &config, &cs, instances[i])) + .map(|(i, circuit)| { + WitnessCalculator::new(params.k(), circuit, &config, &cs, instances[i].as_slice()) + }) .collect(); let mut prover = Prover::::new_with_engine( engine, params, pk, instances, rng, transcript, @@ -84,7 +86,7 @@ pub fn create_proof< params: &'params Scheme::ParamsProver, pk: &ProvingKey, circuits: &[ConcreteCircuit], - instances: &[&[&[Scheme::Scalar]]], + instances: &[Vec>], rng: R, transcript: &mut T, ) -> Result<(), Error> @@ -160,7 +162,7 @@ fn test_create_proof() { ¶ms, &pk, &[MyCircuit, MyCircuit], - &[&[], &[]], + &[vec![], vec![]], OsRng, &mut transcript, ) @@ -220,7 +222,7 @@ fn test_create_proof_custom() { ¶ms, &pk, &[MyCircuit, MyCircuit], - &[&[], &[]], + &[vec![], vec![]], OsRng, &mut transcript, ) diff --git a/halo2_proofs/tests/compress_selectors.rs b/halo2_proofs/tests/compress_selectors.rs index b34a099151..5362757295 100644 --- a/halo2_proofs/tests/compress_selectors.rs +++ b/halo2_proofs/tests/compress_selectors.rs @@ -372,11 +372,7 @@ fn test_mycircuit( // Proving #[allow(clippy::useless_vec)] - let instances = vec![vec![Fr::one(), Fr::from_u128(3)]]; - let instances_slice: &[&[Fr]] = &(instances - .iter() - .map(|instance| instance.as_slice()) - .collect::>()); + let instances = vec![vec![vec![Fr::one(), Fr::from_u128(3)]]]; let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); create_proof_with_engine::, ProverSHPLONK<'_, Bn256>, _, _, _, _, _>( @@ -384,7 +380,7 @@ fn test_mycircuit( ¶ms, &pk, &[circuit], - &[instances_slice], + instances.as_slice(), &mut rng, &mut transcript, )?; @@ -399,7 +395,7 @@ fn test_mycircuit( &verifier_params, &vk, strategy, - &[instances_slice], + instances.as_slice(), &mut verifier_transcript, ) .map_err(halo2_proofs::plonk::Error::Backend) diff --git a/halo2_proofs/tests/frontend_backend_split.rs b/halo2_proofs/tests/frontend_backend_split.rs index 127d8552e0..5c965191b2 100644 --- a/halo2_proofs/tests/frontend_backend_split.rs +++ b/halo2_proofs/tests/frontend_backend_split.rs @@ -523,11 +523,7 @@ fn test_mycircuit_full_legacy() { println!("Keygen: {:?}", start.elapsed()); // Proving - let instances = circuit.instances(); - let instances_slice: &[&[Fr]] = &(instances - .iter() - .map(|instance| instance.as_slice()) - .collect::>()); + let instances = vec![circuit.instances()]; let start = Instant::now(); let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); @@ -535,7 +531,7 @@ fn test_mycircuit_full_legacy() { ¶ms, &pk, &[circuit], - &[instances_slice], + instances.as_slice(), &mut rng, &mut transcript, ) @@ -554,7 +550,7 @@ fn test_mycircuit_full_legacy() { &verifier_params, &vk, strategy, - &[instances_slice], + instances.as_slice(), &mut verifier_transcript, ) .expect("verify succeeds"); @@ -585,16 +581,11 @@ fn test_mycircuit_full_split() { println!("Keygen: {:?}", start.elapsed()); drop(compiled_circuit); + let instances = circuit.instances(); // Proving println!("Proving..."); - let instances = circuit.instances(); - let instances_slice: &[&[Fr]] = &(instances - .iter() - .map(|instance| instance.as_slice()) - .collect::>()); - let start = Instant::now(); - let mut witness_calc = WitnessCalculator::new(k, &circuit, &config, &cs, instances_slice); + let mut witness_calc = WitnessCalculator::new(k, &circuit, &config, &cs, &instances); let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); let mut prover = ProverSingle::< KZGCommitmentScheme, @@ -607,7 +598,7 @@ fn test_mycircuit_full_split() { engine, ¶ms, &pk, - instances_slice, +instances.clone(), &mut rng, &mut transcript, ) @@ -634,7 +625,7 @@ fn test_mycircuit_full_split() { &verifier_params, &vk, strategy, - instances_slice, + instances, &mut verifier_transcript, ) .expect("verify succeeds"); diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs index 9436aff387..9ddd6de68a 100644 --- a/halo2_proofs/tests/plonk_api.rs +++ b/halo2_proofs/tests/plonk_api.rs @@ -489,7 +489,7 @@ fn plonk_api() { where Scheme::Scalar: Ord + WithSmallOrderMulGroup<3> + FromUniformBytes<64>, { - let (a, instance, lookup_table) = common!(Scheme); + let (a, instance_val, lookup_table) = common!(Scheme); let circuit: MyCircuit = MyCircuit { a: Value::known(a), @@ -498,19 +498,20 @@ fn plonk_api() { let mut transcript = T::init(vec![]); + let instance = [vec![vec![instance_val]], vec![vec![instance_val]]]; create_plonk_proof_with_engine::( engine, params, pk, &[circuit.clone(), circuit.clone()], - &[&[&[instance]], &[&[instance]]], + &instance, rng, &mut transcript, ) .expect("proof generation should not fail"); // Check this circuit is satisfied. - let prover = match MockProver::run(K, &circuit, vec![vec![instance]]) { + let prover = match MockProver::run(K, &circuit, vec![vec![instance_val]]) { Ok(prover) => prover, Err(e) => panic!("{e:?}"), }; @@ -553,20 +554,14 @@ fn plonk_api() { ) where Scheme::Scalar: Ord + WithSmallOrderMulGroup<3> + FromUniformBytes<64>, { - let (_, instance, _) = common!(Scheme); - let pubinputs = [instance]; + let (_, instance_val, _) = common!(Scheme); let mut transcript = T::init(proof); + let instance = [vec![vec![instance_val]], vec![vec![instance_val]]]; let strategy = Strategy::new(params_verifier); - let strategy = verify_plonk_proof( - params_verifier, - vk, - strategy, - &[&[&pubinputs[..]], &[&pubinputs[..]]], - &mut transcript, - ) - .unwrap(); + let strategy = + verify_plonk_proof(params_verifier, vk, strategy, &instance, &mut transcript).unwrap(); assert!(strategy.finalize()); } diff --git a/halo2_proofs/tests/serialization.rs b/halo2_proofs/tests/serialization.rs index 93e98989e0..93dfca7dff 100644 --- a/halo2_proofs/tests/serialization.rs +++ b/halo2_proofs/tests/serialization.rs @@ -156,7 +156,7 @@ fn test_serialization() { std::fs::remove_file("serialization-test.pk").unwrap(); - let instances: &[&[Fr]] = &[&[circuit.0]]; + let instances: Vec>> = vec![vec![vec![circuit.0]]]; let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -169,7 +169,7 @@ fn test_serialization() { ¶ms, &pk, &[circuit], - &[instances], + instances.as_slice(), OsRng, &mut transcript, ) @@ -189,7 +189,7 @@ fn test_serialization() { &verifier_params, pk.get_vk(), strategy, - &[instances], + instances.as_slice(), &mut transcript ) .is_ok()); diff --git a/halo2_proofs/tests/shuffle.rs b/halo2_proofs/tests/shuffle.rs index 7ecfb49edc..0b27a3509c 100644 --- a/halo2_proofs/tests/shuffle.rs +++ b/halo2_proofs/tests/shuffle.rs @@ -287,7 +287,7 @@ fn test_prover( ¶ms, &pk, &[circuit], - &[&[]], + &[vec![]], OsRng, &mut transcript, ) @@ -304,7 +304,7 @@ fn test_prover( ¶ms, pk.get_vk(), strategy, - &[&[]], + &[vec![]], &mut transcript, ) .map(|strategy| strategy.finalize()) diff --git a/halo2_proofs/tests/shuffle_api.rs b/halo2_proofs/tests/shuffle_api.rs index e7034e6f36..a5c1167081 100644 --- a/halo2_proofs/tests/shuffle_api.rs +++ b/halo2_proofs/tests/shuffle_api.rs @@ -163,7 +163,7 @@ where ¶ms, &pk, &[circuit], - &[&[]], + &[vec![]], OsRng, &mut transcript, ) @@ -180,7 +180,7 @@ where ¶ms, pk.get_vk(), strategy, - &[&[]], + &[vec![]], &mut transcript, ) .map(|strategy| strategy.finalize()) diff --git a/halo2_proofs/tests/vector-ops-unblinded.rs b/halo2_proofs/tests/vector-ops-unblinded.rs index 01c24fef4d..73aa58ab69 100644 --- a/halo2_proofs/tests/vector-ops-unblinded.rs +++ b/halo2_proofs/tests/vector-ops-unblinded.rs @@ -479,6 +479,7 @@ where let vk = keygen_vk(¶ms, &circuit).unwrap(); let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let instances = vec![vec![instances]]; let proof = { let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); @@ -486,7 +487,7 @@ where ¶ms, &pk, &[circuit], - &[&[&instances]], + &instances, OsRng, &mut transcript, ) @@ -503,7 +504,7 @@ where ¶ms, pk.get_vk(), strategy, - &[&[&instances]], + &instances, &mut transcript, ) .map(|strategy| strategy.finalize()) diff --git a/p3_frontend/tests/common/mod.rs b/p3_frontend/tests/common/mod.rs index 62062584a4..a7182f36f7 100644 --- a/p3_frontend/tests/common/mod.rs +++ b/p3_frontend/tests/common/mod.rs @@ -87,7 +87,6 @@ pub(crate) fn setup_prove_verify( // Proving println!("Proving..."); let start = Instant::now(); - let vec_slices: Vec<&[Fr]> = pis.iter().map(|pi| &pi[..]).collect(); let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); let mut prover = ProverSingle::< KZGCommitmentScheme, @@ -96,7 +95,7 @@ pub(crate) fn setup_prove_verify( _, _, H2cEngine, - >::new(¶ms, &pk, &vec_slices, &mut rng, &mut transcript) + >::new(¶ms, &pk, pis.to_vec(), &mut rng, &mut transcript) .unwrap(); println!("phase 0"); prover.commit_phase(0, witness).unwrap(); @@ -115,7 +114,7 @@ pub(crate) fn setup_prove_verify( &verifier_params, &vk, strategy, - &vec_slices, + pis.to_vec(), &mut verifier_transcript, ) .expect("verify succeeds");