From 6959f51f0143f5f3f13cf22697a35263a9d20e5c Mon Sep 17 00:00:00 2001 From: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:18:34 -0600 Subject: [PATCH] [VAN-929] improve performance of bucket interpreter (#82) * Add interpreter functions for non-modifying buckets. This allows certain buckets to be interpreted without first cloning the Env object. * Remove additional Env clones * Remove unused functions from Env * Use bucket ID instead of cloning buckets for map key in unknown index sanitizer --- .../env/extracted_func_env.rs | 6 +- .../src/bucket_interpreter/env/mod.rs | 8 - .../bucket_interpreter/env/standard_env.rs | 4 - .../env/unrolled_block_env.rs | 6 +- circuit_passes/src/bucket_interpreter/mod.rs | 248 ++++++++++-------- .../deterministic_subcomponent_invocation.rs | 3 +- .../passes/loop_unroll/loop_env_recorder.rs | 32 +-- .../src/passes/mapped_to_indexed.rs | 12 +- circuit_passes/src/passes/simplification.rs | 6 +- .../src/passes/unknown_index_sanitization.rs | 69 ++--- 10 files changed, 193 insertions(+), 201 deletions(-) diff --git a/circuit_passes/src/bucket_interpreter/env/extracted_func_env.rs b/circuit_passes/src/bucket_interpreter/env/extracted_func_env.rs index d79011823..6ff04342c 100644 --- a/circuit_passes/src/bucket_interpreter/env/extracted_func_env.rs +++ b/circuit_passes/src/bucket_interpreter/env/extracted_func_env.rs @@ -1,5 +1,5 @@ use std::cell::Ref; -use std::collections::{HashMap, BTreeMap, HashSet}; +use std::collections::{BTreeMap, HashSet}; use std::fmt::{Display, Formatter}; use compiler::circuit_design::function::FunctionCode; use compiler::circuit_design::template::TemplateCode; @@ -251,10 +251,6 @@ impl<'a> ExtractedFuncEnvData<'a> { res } - pub fn get_vars_clone(&self) -> HashMap { - self.base.get_vars_clone() - } - pub fn get_vars_sort(&self) -> BTreeMap { self.base.get_vars_sort() } diff --git a/circuit_passes/src/bucket_interpreter/env/mod.rs b/circuit_passes/src/bucket_interpreter/env/mod.rs index eb7934fa3..1379cb8c6 100644 --- a/circuit_passes/src/bucket_interpreter/env/mod.rs +++ b/circuit_passes/src/bucket_interpreter/env/mod.rs @@ -211,14 +211,6 @@ impl<'a> Env<'a> { } } - pub fn get_vars_clone(&self) -> HashMap { - match self { - Env::Standard(d) => d.get_vars_clone(), - Env::UnrolledBlock(d) => d.get_vars_clone(), - Env::ExtractedFunction(d) => d.get_vars_clone(), - } - } - pub fn get_vars_sort(&self) -> BTreeMap { match self { Env::Standard(d) => d.get_vars_sort(), diff --git a/circuit_passes/src/bucket_interpreter/env/standard_env.rs b/circuit_passes/src/bucket_interpreter/env/standard_env.rs index 5428fecf6..c9cf4bd09 100644 --- a/circuit_passes/src/bucket_interpreter/env/standard_env.rs +++ b/circuit_passes/src/bucket_interpreter/env/standard_env.rs @@ -84,10 +84,6 @@ impl<'a> StandardEnvData<'a> { self.subcmps.get(&subcmp_idx).unwrap().counter_equal_to(value) } - pub fn get_vars_clone(&self) -> HashMap { - self.vars.clone() - } - pub fn get_vars_sort(&self) -> BTreeMap { self.vars.iter().fold(BTreeMap::new(), |mut acc, e| { acc.insert(*e.0, e.1.clone()); diff --git a/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs b/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs index 2a3c3bf26..9f94001c5 100644 --- a/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs +++ b/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs @@ -1,5 +1,5 @@ use std::cell::Ref; -use std::collections::{HashMap, BTreeMap}; +use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; use compiler::circuit_design::function::FunctionCode; use compiler::circuit_design::template::TemplateCode; @@ -94,10 +94,6 @@ impl<'a> UnrolledBlockEnvData<'a> { self.base.subcmp_counter_equal_to(subcmp_idx, value) } - pub fn get_vars_clone(&self) -> HashMap { - self.base.get_vars_clone() - } - pub fn get_vars_sort(&self) -> BTreeMap { self.base.get_vars_sort() } diff --git a/circuit_passes/src/bucket_interpreter/mod.rs b/circuit_passes/src/bucket_interpreter/mod.rs index bc4645a68..89c2501c6 100644 --- a/circuit_passes/src/bucket_interpreter/mod.rs +++ b/circuit_passes/src/bucket_interpreter/mod.rs @@ -32,7 +32,8 @@ pub struct BucketInterpreter<'a, 'd> { p: BigInt, } -pub type R<'a> = Result<(Option, Env<'a>), BadInterp>; +pub type RE<'e> = Result<(Option, Env<'e>), BadInterp>; +pub type RC = Result, BadInterp>; #[inline] pub fn into_result(v: Option, label: S) -> Result { @@ -71,7 +72,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &ValueBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_value_bucket(bucket, env, observe), bucket) } @@ -80,7 +81,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env LoadBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_load_bucket(bucket, env, observe), bucket) } @@ -89,7 +90,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env StoreBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_store_bucket(bucket, env, observe), bucket) } @@ -98,7 +99,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env ComputeBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_compute_bucket(bucket, env, observe), bucket) } @@ -107,7 +108,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env CallBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_call_bucket(bucket, env, observe), bucket) } @@ -116,7 +117,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env BranchBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_branch_bucket(bucket, env, observe), bucket) } @@ -125,7 +126,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env ReturnBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_return_bucket(bucket, env, observe), bucket) } @@ -134,7 +135,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env AssertBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_assert_bucket(bucket, env, observe), bucket) } @@ -143,7 +144,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env LogBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_log_bucket(bucket, env, observe), bucket) } @@ -189,7 +190,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env LoopBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_loop_bucket(bucket, env, observe), bucket) } @@ -198,7 +199,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env CreateCmpBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_create_cmp_bucket(bucket, env, observe), bucket) } @@ -207,7 +208,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env ConstraintBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_constraint_bucket(bucket, env, observe), bucket) } @@ -216,7 +217,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env BlockBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_block_bucket(bucket, env, observe), bucket) } @@ -225,7 +226,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &NopBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_nop_bucket(bucket, env, observe), bucket) } @@ -234,7 +235,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { inst: &'env InstructionPointer, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { add_loc_if_err(self._execute_instruction(inst, env, observe), inst.as_ref()) } @@ -243,7 +244,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { instructions: &'env [InstructionPointer], env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let mut last = (None, env); for inst in instructions { last = self.execute_instruction(inst, last.1, observe)?; @@ -251,6 +252,10 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { Ok(last) } + pub fn compute_instruction(&self, inst: &InstructionPointer, env: &Env, observe: bool) -> RC { + add_loc_if_err(self._compute_instruction(inst, env, observe), inst.as_ref()) + } + /**************************************************************************************************** * Private implemenation * Allows any number of calls to the internal "_execute*bucket" functions without adding source @@ -264,7 +269,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { ) -> Result { match location { LocationRule::Indexed { location, .. } => { - let (idx, _) = self._execute_instruction(location, env.clone(), false)?; + let idx = self.compute_instruction(location, env, false)?; Value::into_u32_result(idx, "index of location") } LocationRule::Mapped { .. } => unreachable!(), @@ -389,48 +394,41 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { Ok((vars, signals, subcmps)) } + fn _compute_value_bucket(&self, bucket: &ValueBucket, _: &Env, _: bool) -> RC { + Ok(Some(match bucket.parse_as { + ValueType::U32 => KnownU32(bucket.value), + ValueType::BigInt => { + let constant = &self.mem.get_field_constant(bucket.value); + KnownBigInt(add_loc_if_err(to_bigint(constant), bucket)?) + } + })) + } + fn _execute_value_bucket<'env>( &self, bucket: &ValueBucket, env: Env<'env>, - _observe: bool, - ) -> R<'env> { - Ok(( - Some(match bucket.parse_as { - ValueType::U32 => KnownU32(bucket.value), - ValueType::BigInt => { - let constant = &self.mem.get_field_constant(bucket.value); - KnownBigInt(add_loc_if_err(to_bigint(constant), bucket)?) - } - }), - env, - )) + observe: bool, + ) -> RE<'env> { + self._compute_value_bucket(bucket, &env, observe).map(|r| (r, env)) } - fn _execute_load_bucket<'env>( - &self, - bucket: &'env LoadBucket, - env: Env<'env>, - observe: bool, - ) -> R<'env> { + fn _compute_load_bucket(&self, bucket: &LoadBucket, env: &Env, observe: bool) -> RC { match &bucket.address_type { AddressType::Variable => { - let continue_observing = if observe { - self.observer.on_location_rule(&bucket.src, &env)? - } else { - false - }; - let (idx, env) = match &bucket.src { + let continue_observing = + if observe { self.observer.on_location_rule(&bucket.src, env)? } else { false }; + let idx = match &bucket.src { LocationRule::Indexed { location, .. } => { - self._execute_instruction(location, env, continue_observing)? + self._compute_instruction(location, env, continue_observing)? } LocationRule::Mapped { .. } => unreachable!(), }; let idx = into_result(idx, "load source variable")?; if idx.is_unknown() { - Ok((Some(Unknown), env)) + Ok(Some(Unknown)) } else { - Ok((Some(env.get_var(idx.get_u32()?)), env)) + Ok(Some(env.get_var(idx.get_u32()?))) } } AddressType::Signal => { @@ -439,61 +437,64 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { } else { false }; - let (idx, env) = match &bucket.src { + let idx = match &bucket.src { LocationRule::Indexed { location, .. } => { - self._execute_instruction(location, env, continue_observing)? + self._compute_instruction(location, env, continue_observing)? } LocationRule::Mapped { .. } => unreachable!(), }; let idx = into_result(idx, "load source signal")?; if idx.is_unknown() { - Ok((Some(Unknown), env)) + Ok(Some(Unknown)) } else { - Ok((Some(env.get_signal(idx.get_u32()?)), env)) + Ok(Some(env.get_signal(idx.get_u32()?))) } } AddressType::SubcmpSignal { cmp_address, .. } => { - let (addr, env) = self._execute_instruction(cmp_address, env, observe)?; + let addr = self._compute_instruction(cmp_address, env, observe)?; let addr = Value::into_u32_result(addr, "load source subcomponent")?; let continue_observing = if observe { self.observer.on_location_rule(&bucket.src, &env)? } else { false }; - let (idx, env) = match &bucket.src { + let idx = match &bucket.src { LocationRule::Indexed { location, .. } => { - let (i, e) = - self._execute_instruction(location, env, continue_observing)?; - (Value::into_u32_result(i, "load source subcomponent indexed signal")?, e) + let i = self._compute_instruction(location, env, continue_observing)?; + Value::into_u32_result(i, "load source subcomponent indexed signal")? } LocationRule::Mapped { signal_code, indexes } => { - let mut acc_env = env; let io_def = - self.mem.get_iodef(&acc_env.get_subcmp_template_id(addr), signal_code); - let map_access = io_def.offset; + self.mem.get_iodef(&env.get_subcmp_template_id(addr), signal_code); if indexes.len() > 0 { let mut indexes_values = vec![]; for i in indexes { - let (val, new_env) = - self._execute_instruction(i, acc_env, continue_observing)?; + let val = self._compute_instruction(i, env, continue_observing)?; indexes_values.push(Value::into_u32_result( val, "load source subcomponent mapped signal", )?); - acc_env = new_env; } - let offset = compute_offset(&indexes_values, &io_def.lengths)?; - (map_access + offset, acc_env) + io_def.offset + compute_offset(&indexes_values, &io_def.lengths)? } else { - (map_access, acc_env) + io_def.offset } } }; - Ok((Some(env.get_subcmp_signal(addr, idx)), env)) + Ok(Some(env.get_subcmp_signal(addr, idx))) } } } + fn _execute_load_bucket<'env>( + &self, + bucket: &LoadBucket, + env: Env<'env>, + observe: bool, + ) -> RE<'env> { + self._compute_load_bucket(bucket, &env, observe).map(|r| (r, env)) + } + fn store_value_in_address<'env>( &self, address: &'env AddressType, @@ -607,7 +608,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env StoreBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let (src, env) = self._execute_instruction(&bucket.src, env, observe)?; let src = into_result(src, "store source value")?; let env = self.store_value_in_address( @@ -620,24 +621,26 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { Ok((None, env)) } - fn _execute_compute_bucket<'env>( - &self, - bucket: &'env ComputeBucket, - env: Env<'env>, - observe: bool, - ) -> R<'env> { + fn _compute_compute_bucket(&self, bucket: &ComputeBucket, env: &Env, observe: bool) -> RC { let mut stack = vec![]; - let mut env = env; for i in &bucket.stack { - let (value, new_env) = self._execute_instruction(i, env, observe)?; - env = new_env; + let value = self._compute_instruction(i, env, observe)?; stack.push(into_result(value, format!("{:?} operand", bucket.op))?); } // If any value of the stack is unknown we just return unknown if stack.iter().any(|v| v.is_unknown()) { - return Ok((Some(Unknown), env)); + return Ok(Some(Unknown)); } - operations::compute_operation(bucket, &stack, &self.p).map(|v| (v, env)) + operations::compute_operation(bucket, &stack, &self.p) + } + + fn _execute_compute_bucket<'env>( + &self, + bucket: &'env ComputeBucket, + env: Env<'env>, + observe: bool, + ) -> RE<'env> { + self._compute_compute_bucket(bucket, &env, observe).map(|r| (r, env)) } fn run_function_extracted<'env>( @@ -645,7 +648,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env CallBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let name = &bucket.symbol; if cfg!(debug_assertions) { println!("Running function {}", name); @@ -707,7 +710,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env CallBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let mut env = env; let res = if bucket.symbol.eq(FR_IDENTITY_ARR_PTR) || bucket.symbol.eq(FR_INDEX_ARR_PTR) { (Some(Unknown), env) @@ -751,7 +754,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env BranchBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let (value, cond, mut env) = self.execute_conditional_bucket( &bucket.cond, &bucket.if_branch, @@ -788,45 +791,53 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { Ok((value, env)) } + fn _compute_return_bucket(&self, bucket: &ReturnBucket, env: &Env, observe: bool) -> RC { + self._compute_instruction(&bucket.value, env, observe) + } + fn _execute_return_bucket<'env>( &self, bucket: &'env ReturnBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { - self._execute_instruction(&bucket.value, env, observe) + ) -> RE<'env> { + self._compute_return_bucket(bucket, &env, observe).map(|r| (r, env)) } - fn _execute_assert_bucket<'env>( - &self, - bucket: &'env AssertBucket, - env: Env<'env>, - observe: bool, - ) -> R<'env> { - let (cond, env) = self._execute_instruction(&bucket.evaluate, env, observe)?; + fn _compute_assert_bucket(&self, bucket: &AssertBucket, env: &Env, observe: bool) -> RC { + let cond = self._compute_instruction(&bucket.evaluate, env, observe)?; let cond = into_result(cond, "assert condition")?; if !cond.is_unknown() && !cond.to_bool(&self.p)? { // Based on 'constraint_generation::execute::treat_result_with_execution_error' Err(BadInterp::error("False assert reached".to_string(), ReportCode::RuntimeError)) } else { - Ok((None, env)) + Ok(None) } } - - fn _execute_log_bucket<'env>( + fn _execute_assert_bucket<'env>( &self, - bucket: &'env LogBucket, + bucket: &'env AssertBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { - let mut env = env; + ) -> RE<'env> { + self._compute_assert_bucket(bucket, &env, observe).map(|r| (r, env)) + } + fn _compute_log_bucket(&self, bucket: &LogBucket, env: &Env, observe: bool) -> RC { for arg in &bucket.argsprint { if let LogBucketArg::LogExp(i) = arg { - let (_, new_env) = self._execute_instruction(i, env, observe)?; - env = new_env + self._compute_instruction(i, env, observe)?; } } - Ok((None, env)) + Ok(None) + } + + fn _execute_log_bucket<'env>( + &self, + bucket: &'env LogBucket, + env: Env<'env>, + observe: bool, + ) -> RE<'env> { + self._compute_log_bucket(bucket, &env, observe).map(|r| (r, env)) } fn _execute_conditional_bucket<'env>( @@ -857,7 +868,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env LoopBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let mut last_value = Some(Unknown); let mut loop_env = env; let mut n_iters = 0; @@ -909,7 +920,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env CreateCmpBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let (cmp_id, env) = self._execute_instruction(&bucket.sub_cmp_id, env, observe)?; let cmp_id = Value::into_u32_result(cmp_id, "ID of subcomponent!")?; let mut env = @@ -928,7 +939,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env ConstraintBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { self._execute_instruction( match bucket { ConstraintBucket::Substitution(i) => i, @@ -944,17 +955,42 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { bucket: &'env BlockBucket, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { self.execute_instructions(&bucket.body, env, observe) } + fn _compute_nop_bucket(&self, _bucket: &NopBucket, _env: &Env, _observe: bool) -> RC { + Ok(None) + } + fn _execute_nop_bucket<'env>( &self, - _bucket: &NopBucket, + bucket: &NopBucket, env: Env<'env>, - _observe: bool, - ) -> R<'env> { - Ok((None, env)) + observe: bool, + ) -> RE<'env> { + self._compute_nop_bucket(bucket, &env, observe).map(|r| (r, env)) + } + + fn _compute_instruction(&self, inst: &InstructionPointer, env: &Env, observe: bool) -> RC { + let continue_observing = + if observe { self.observer.on_instruction(inst, env)? } else { observe }; + match inst.as_ref() { + Instruction::Value(b) => self._compute_value_bucket(b, env, continue_observing), + Instruction::Load(b) => self._compute_load_bucket(b, env, continue_observing), + Instruction::Store(_) => unreachable!("must use '_execute_instruction'"), + Instruction::Compute(b) => self._compute_compute_bucket(b, env, continue_observing), + Instruction::Call(_) => unreachable!("must use '_execute_instruction'"), + Instruction::Branch(_) => unreachable!("must use '_execute_instruction'"), + Instruction::Return(b) => self._compute_return_bucket(b, env, continue_observing), + Instruction::Assert(b) => self._compute_assert_bucket(b, env, continue_observing), + Instruction::Log(b) => self._compute_log_bucket(b, env, continue_observing), + Instruction::Loop(_) => unreachable!("must use '_execute_instruction'"), + Instruction::CreateCmp(_) => unreachable!("must use '_execute_instruction'"), + Instruction::Constraint(_) => unreachable!("must use '_execute_instruction'"), + Instruction::Block(_) => unreachable!("must use '_execute_instruction'"), + Instruction::Nop(b) => self._compute_nop_bucket(b, env, continue_observing), + } } fn _execute_instruction<'env>( @@ -962,7 +998,7 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> { inst: &'env InstructionPointer, env: Env<'env>, observe: bool, - ) -> R<'env> { + ) -> RE<'env> { let continue_observing = if observe { self.observer.on_instruction(inst, &env)? } else { observe }; match inst.as_ref() { diff --git a/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs b/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs index 02dc857f1..5b330cee9 100644 --- a/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs +++ b/circuit_passes/src/passes/deterministic_subcomponent_invocation.rs @@ -45,9 +45,8 @@ impl<'d> DeterministicSubCmpInvokePass<'d> { .. } = address_type { - let env = env.clone(); let interpreter = self.memory.build_interpreter(self.global_data, self); - let (addr, env) = interpreter.execute_instruction(cmp_address, env, false)?; + let addr = interpreter.compute_instruction(cmp_address, env, false)?; let addr = addr .expect("cmp_address instruction in SubcmpSignal must produce a value!") .get_u32()?; diff --git a/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs b/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs index 8f1f647bb..da11dd0f7 100644 --- a/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs +++ b/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs @@ -98,12 +98,6 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { self.vals_per_iteration.borrow_mut().insert(iter, VariableValues::new(env)); } - pub fn get_header_env_clone(&self) -> Env { - let iter = self.get_iter(); - assert!(self.vals_per_iteration.borrow().contains_key(&iter)); - self.vals_per_iteration.borrow().get(&iter).unwrap().env_at_header.clone() - } - pub fn record_reverse_arg_mapping( &self, extract_func: String, @@ -137,7 +131,7 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { fn compute_index_from_inst( &self, - env: Env, + env: &Env, location: &InstructionPointer, ) -> Result { // Evaluate the index using the current environment and using the environment from the @@ -145,15 +139,17 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { // not safe to move the loop body to another function because the index computation may // not give the same result when done at the call site, outside of the new function. let interp = self.mem.build_interpreter(self.global_data, self); - let (idx_loc, _) = interp.execute_instruction(location, env, false)?; - if let Some(idx_loc) = idx_loc { + if let Some(idx_loc) = interp.compute_instruction(location, env, false)? { // NOTE: It's possible for the interpreter to run into problems evaluating the location // using the header Env. For example, a value may not have been defined yet so address // computations on that value could give out of range results for the 'usize' type. // Thus, these errors should be ignored and fall through into the Ok(Unknown) case. + let borrow = self.vals_per_iteration.borrow(); + let vals = borrow.get(&self.get_iter()); + assert!(vals.is_some()); let header_res = - interp.execute_instruction(location, self.get_header_env_clone(), false); - if let Ok((Some(idx_header), _)) = header_res { + interp.compute_instruction(location, &vals.unwrap().env_at_header, false); + if let Ok(Some(idx_header)) = header_res { if Value::eq(&idx_header, &idx_loc) { return Ok(idx_loc); } @@ -162,7 +158,7 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { Ok(Value::Unknown) } - fn compute_index_from_rule(&self, env: Env, loc: &LocationRule) -> Result { + fn compute_index_from_rule(&self, env: &Env, loc: &LocationRule) -> Result { match loc { LocationRule::Mapped { .. } => { //TODO: It's not an array index in this case, at least not immediately but I think it can @@ -175,7 +171,7 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { } } - fn visit_location_rule(&self, env: Env, loc: &LocationRule) -> Result { + fn visit_location_rule(&self, env: &Env, loc: &LocationRule) -> Result { let res = self.compute_index_from_rule(env, loc); if let Ok(Value::Unknown) = res { if DEBUG_LOOP_UNROLL { @@ -194,7 +190,7 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { bucket_id: &BucketId, addr_ty: &AddressType, loc: &LocationRule, - env: Env, + env: &Env, ) -> Result<(), BadInterp> { //NOTE: must record even when Unknown to ensure that Unknown value is not confused with // missing values for an iteration that can be caused by conditionals within the loop. @@ -206,7 +202,7 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { counter_override, } = addr_ty { - let loc_result = self.visit_location_rule(env.clone(), loc)?; + let loc_result = self.visit_location_rule(env, loc)?; let addr_result = self.compute_index_from_inst(env, cmp_address)?; self.record_memloc_at_bucket( bucket_id, @@ -241,7 +237,7 @@ impl Observer> for EnvRecorder<'_, '_> { if let Some(_) = bucket.bounded_fn { todo!(); //not sure if/how to handle that } - self.visit(&bucket.id, &bucket.address_type, &bucket.src, env.clone())?; + self.visit(&bucket.id, &bucket.address_type, &bucket.src, env)?; // For a LoadBucket, there is no need to continue observing inside it and doing // so can actually cause "assert!(bucket_to_args.is_empty())" to fail. See // test "loops/fixed_idx_in_fixed_idx.circom" for an example and explanation. @@ -254,13 +250,13 @@ impl Observer> for EnvRecorder<'_, '_> { if let Some(_) = bucket.bounded_fn { todo!(); //not sure if/how to handle that } - self.visit(&bucket.id, &bucket.dest_address_type, &bucket.dest, env.clone())?; + self.visit(&bucket.id, &bucket.dest_address_type, &bucket.dest, env)?; Ok(self.is_safe_to_move()) //continue observing unless something unsafe has been found } fn on_call_bucket(&self, bucket: &CallBucket, env: &Env) -> Result { if let ReturnType::Final(fd) = &bucket.return_info { - self.visit(&bucket.id, &fd.dest_address_type, &fd.dest, env.clone())?; + self.visit(&bucket.id, &fd.dest_address_type, &fd.dest, env)?; } Ok(self.is_safe_to_move()) //continue observing unless something unsafe has been found } diff --git a/circuit_passes/src/passes/mapped_to_indexed.rs b/circuit_passes/src/passes/mapped_to_indexed.rs index 83e506f1b..c5d1c102c 100644 --- a/circuit_passes/src/passes/mapped_to_indexed.rs +++ b/circuit_passes/src/passes/mapped_to_indexed.rs @@ -45,20 +45,16 @@ impl<'d> MappedToIndexedPass<'d> { ) -> Result { let interpreter = self.memory.build_interpreter(self.global_data, self); - let (resolved_addr, acc_env) = - interpreter.execute_instruction(cmp_address, env.clone(), false)?; - + let resolved_addr = interpreter.compute_instruction(cmp_address, env, false)?; let resolved_addr = Value::into_u32_result(resolved_addr, "subcomponent address")?; - let name = acc_env.get_subcmp_name(resolved_addr).clone(); + let name = env.get_subcmp_name(resolved_addr).clone(); let io_def = - self.memory.get_iodef(&acc_env.get_subcmp_template_id(resolved_addr), &signal_code); + self.memory.get_iodef(&env.get_subcmp_template_id(resolved_addr), &signal_code); let offset = if indexes.len() > 0 { - let mut acc_env = acc_env; let mut indexes_values = vec![]; for i in indexes { - let (val, new_env) = interpreter.execute_instruction(i, acc_env, false)?; + let val = interpreter.compute_instruction(i, env, false)?; indexes_values.push(Value::into_u32_result(val, "subcomponent mapped signal")?); - acc_env = new_env; } io_def.offset + compute_offset(&indexes_values, &io_def.lengths)? } else { diff --git a/circuit_passes/src/passes/simplification.rs b/circuit_passes/src/passes/simplification.rs index 046aa92b9..8cf057d36 100644 --- a/circuit_passes/src/passes/simplification.rs +++ b/circuit_passes/src/passes/simplification.rs @@ -36,9 +36,8 @@ impl<'d> SimplificationPass<'d> { impl Observer> for SimplificationPass<'_> { fn on_compute_bucket(&self, bucket: &ComputeBucket, env: &Env) -> Result { - let env = env.clone(); let interpreter = self.memory.build_interpreter(self.global_data, self); - let (eval, _) = interpreter.execute_compute_bucket(bucket, env, false)?; + let (eval, _) = interpreter.execute_compute_bucket(bucket, env.clone(), false)?; let eval = eval.expect("Compute bucket must produce a value!"); if !eval.is_unknown() { self.compute_replacements.borrow_mut().insert(bucket.id, eval); @@ -49,9 +48,8 @@ impl Observer> for SimplificationPass<'_> { } fn on_call_bucket(&self, bucket: &CallBucket, env: &Env) -> Result { - let env = env.clone(); let interpreter = self.memory.build_interpreter(self.global_data, self); - let (eval, _) = interpreter.execute_call_bucket(bucket, env, false)?; + let (eval, _) = interpreter.execute_call_bucket(bucket, env.clone(), false)?; if let Some(eval) = eval { // Call buckets may not return a value directly if !eval.is_unknown() { diff --git a/circuit_passes/src/passes/unknown_index_sanitization.rs b/circuit_passes/src/passes/unknown_index_sanitization.rs index 711530af5..483e2b983 100644 --- a/circuit_passes/src/passes/unknown_index_sanitization.rs +++ b/circuit_passes/src/passes/unknown_index_sanitization.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, HashSet}; use std::ops::Range; use compiler::circuit_design::template::TemplateCode; use compiler::compiler_interface::Circuit; -use compiler::intermediate_representation::{Instruction, InstructionPointer, new_id}; +use compiler::intermediate_representation::{Instruction, InstructionPointer, new_id, BucketId}; use compiler::intermediate_representation::ir_interface::*; use compiler::num_bigint::BigInt; use code_producers::llvm_elements::array_switch::{get_array_load_name, get_array_store_name}; @@ -13,7 +13,7 @@ use crate::bucket_interpreter::error::{BadInterp, add_loc_if_err}; use crate::bucket_interpreter::memory::PassMemory; use crate::bucket_interpreter::observer::Observer; use crate::bucket_interpreter::operations::compute_operation; -use crate::bucket_interpreter::{R, to_bigint, into_result}; +use crate::bucket_interpreter::{RC, to_bigint, into_result}; use crate::bucket_interpreter::value::Value::{KnownU32, KnownBigInt}; use crate::{ default__get_updated_field_constants, default__name, default__pre_hook_template, @@ -31,51 +31,38 @@ impl<'a> ZeroingInterpreter<'a> { ZeroingInterpreter { constant_fields, p: UsefulConstants::new(prime).get_p().clone() } } - pub fn execute_value_bucket<'env>(&self, bucket: &ValueBucket, env: Env<'env>) -> R<'env> { - Ok(( - Some(match bucket.parse_as { - ValueType::U32 => KnownU32(bucket.value), - ValueType::BigInt => { - let constant = &self.constant_fields[bucket.value]; - KnownBigInt(add_loc_if_err(to_bigint(constant), bucket)?) - } - }), - env, - )) + pub fn compute_value_bucket(&self, bucket: &ValueBucket, _env: &Env) -> RC { + Ok(Some(match bucket.parse_as { + ValueType::U32 => KnownU32(bucket.value), + ValueType::BigInt => { + let constant = &self.constant_fields[bucket.value]; + KnownBigInt(add_loc_if_err(to_bigint(constant), bucket)?) + } + })) } - pub fn execute_load_bucket<'env>(&self, _bucket: &'env LoadBucket, env: Env<'env>) -> R<'env> { - Ok((Some(KnownU32(0)), env)) + pub fn compute_load_bucket(&self, _bucket: &LoadBucket, _env: &Env) -> RC { + Ok(Some(KnownU32(0))) } - pub fn execute_compute_bucket<'env>( - &self, - bucket: &'env ComputeBucket, - env: Env<'env>, - ) -> R<'env> { + pub fn compute_compute_bucket(&self, bucket: &ComputeBucket, env: &Env) -> RC { let mut stack = vec![]; - let mut env = env; for i in &bucket.stack { - let (value, new_env) = self.execute_instruction(i, env)?; - env = new_env; + let value = self.compute_instruction(i, env)?; stack.push(into_result(value, "operand")?); } // If any value of the stack is unknown we just return 0 if stack.iter().any(|v| v.is_unknown()) { - return Ok((Some(KnownU32(0)), env)); + return Ok(Some(KnownU32(0))); } - compute_operation(bucket, &stack, &self.p).map(|v| (v, env)) + compute_operation(bucket, &stack, &self.p) } - pub fn execute_instruction<'env>( - &self, - inst: &'env InstructionPointer, - env: Env<'env>, - ) -> R<'env> { + pub fn compute_instruction(&self, inst: &InstructionPointer, env: &Env) -> RC { match inst.as_ref() { - Instruction::Value(b) => self.execute_value_bucket(b, env), - Instruction::Load(b) => self.execute_load_bucket(b, env), - Instruction::Compute(b) => self.execute_compute_bucket(b, env), + Instruction::Value(b) => self.compute_value_bucket(b, env), + Instruction::Load(b) => self.compute_load_bucket(b, env), + Instruction::Compute(b) => self.compute_compute_bucket(b, env), _ => unreachable!(), } } @@ -85,8 +72,8 @@ pub struct UnknownIndexSanitizationPass<'d> { global_data: &'d RefCell, // Wrapped in a RefCell because the reference to the static analysis is immutable but we need mutability memory: PassMemory, - load_replacements: RefCell>>, - store_replacements: RefCell>>, + load_replacements: RefCell>>, + store_replacements: RefCell>>, scheduled_bounded_loads: RefCell>>, scheduled_bounded_stores: RefCell>>, } @@ -127,7 +114,7 @@ impl<'d> UnknownIndexSanitizationPass<'d> { let mem = &self.memory; let constant_fields = mem.get_field_constants_clone(); let interpreter = ZeroingInterpreter::init(mem.get_prime(), &constant_fields); - let (res, _) = interpreter.execute_instruction(location, env.clone())?; + let res = interpreter.compute_instruction(location, env)?; let offset = match res { Some(KnownU32(base)) => base, @@ -154,7 +141,7 @@ impl<'d> UnknownIndexSanitizationPass<'d> { LocationRule::Indexed { location, .. } => { let mem = &self.memory; let interpreter = mem.build_interpreter(self.global_data, self); - let (r, _) = interpreter.execute_instruction(location, env.clone(), false)?; + let r = interpreter.compute_instruction(location, env, false)?; into_result(r, "indexed location")? } LocationRule::Mapped { .. } => unreachable!(), @@ -174,7 +161,7 @@ impl Observer> for UnknownIndexSanitizationPass<'_> { let location = &bucket.src; if self.is_location_unknown(address, location, env)? { let index_range = self.find_bounds(address, location, env)?; - self.load_replacements.borrow_mut().insert(bucket.clone(), index_range.clone()); + self.load_replacements.borrow_mut().insert(bucket.id, index_range.clone()); self.scheduled_bounded_loads.borrow_mut().insert(index_range); } Ok(true) @@ -185,7 +172,7 @@ impl Observer> for UnknownIndexSanitizationPass<'_> { let location = &bucket.dest; if self.is_location_unknown(address, location, env)? { let index_range = self.find_bounds(address, location, env)?; - self.store_replacements.borrow_mut().insert(bucket.clone(), index_range.clone()); + self.store_replacements.borrow_mut().insert(bucket.id, index_range.clone()); self.scheduled_bounded_stores.borrow_mut().insert(index_range); } Ok(true) @@ -229,7 +216,7 @@ impl CircuitTransformationPass for UnknownIndexSanitizationPass<'_> { } fn transform_load_bucket(&self, bucket: &LoadBucket) -> Result { - let bounded_fn_symbol = match self.load_replacements.borrow().get(bucket) { + let bounded_fn_symbol = match self.load_replacements.borrow().get(&bucket.id) { Some(index_range) => Some(get_array_load_name(index_range)), None => bucket.bounded_fn.clone(), }; @@ -249,7 +236,7 @@ impl CircuitTransformationPass for UnknownIndexSanitizationPass<'_> { &self, bucket: &StoreBucket, ) -> Result { - let bounded_fn_symbol = match self.store_replacements.borrow().get(bucket) { + let bounded_fn_symbol = match self.store_replacements.borrow().get(&bucket.id) { Some(index_range) => Some(get_array_store_name(index_range)), None => bucket.bounded_fn.clone(), };