Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VAN-723] Track array load/store info to only generate used functions #63

Merged
merged 6 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions circom/tests/arrays/unknown_index_load_store.circom
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@ pragma circom 2.0.0;

template UnknownIndexLoadStore() {
signal input in;
signal output out[10];
signal output out[8];

var arr1[10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
var arr1[9] = [0, 1, 2, 3, 4, 5, 6, 7, 8];
var arr2[10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
var arr3[10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
var arr3[11] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

out[in] <-- arr2[in];
}

component main = UnknownIndexLoadStore();
component main = UnknownIndexLoadStore();

// CHECK: define void @__array_store__0_to_8([0 x i256]* %0, i32 %1, i256 %2)
// CHECK: define i256 @__array_load__9_to_19([0 x i256]* %0, i32 %1)
// CHECK-NOT: @__array_load__0_to_8
// CHECK-NOT: @__array_store__9_to_19
// CHECK-NOT: @__array_load__20_to_31
// CHECK-NOT: @__array_store__20_to_31
8 changes: 4 additions & 4 deletions circuit_passes/src/passes/conditional_flattening.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl CircuitTransformationPass for ConditionalFlatteningPass<'_> {
"ConditionalFlattening"
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn pre_hook_circuit(&self, circuit: &Circuit) {
self.memory.fill_from_circuit(circuit);
}
Expand All @@ -121,10 +125,6 @@ impl CircuitTransformationPass for ConditionalFlatteningPass<'_> {
self.memory.run_template(self.global_data, self, template);
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn transform_branch_bucket(&self, bucket: &BranchBucket) -> InstructionPointer {
if let Some(side) = self.replacements.borrow().get(&bucket) {
let code = if *side { &bucket.if_branch } else { &bucket.else_branch };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ impl CircuitTransformationPass for DeterministicSubCmpInvokePass<'_> {
"DeterministicSubCmpInvokePass"
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn pre_hook_circuit(&self, circuit: &Circuit) {
self.memory.fill_from_circuit(circuit);
}
Expand All @@ -139,10 +143,6 @@ impl CircuitTransformationPass for DeterministicSubCmpInvokePass<'_> {
self.memory.run_template(self.global_data, self, template);
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn transform_address_type(&self, address: &AddressType) -> AddressType {
let replacements = self.replacements.borrow();
match address {
Expand Down
8 changes: 4 additions & 4 deletions circuit_passes/src/passes/loop_unroll/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ impl CircuitTransformationPass for LoopUnrollPass<'_> {
"LoopUnrollPass"
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn pre_hook_circuit(&self, circuit: &Circuit) {
self.memory.fill_from_circuit(circuit);
}
Expand All @@ -206,10 +210,6 @@ impl CircuitTransformationPass for LoopUnrollPass<'_> {
self.memory.run_template(self.global_data, self, template);
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn transform_loop_bucket(&self, bucket: &LoopBucket) -> InstructionPointer {
if let Some(unrolled_loop) = self.replacements.borrow().get(&bucket.id) {
return self.transform_instruction(unrolled_loop);
Expand Down
15 changes: 13 additions & 2 deletions circuit_passes/src/passes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cell::RefCell;
use std::collections::{HashMap, BTreeMap};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::ops::Range;
use compiler::circuit_design::function::{FunctionCode, FunctionCodeInfo};
use compiler::circuit_design::template::{TemplateCode, TemplateCodeInfo};
use compiler::compiler_interface::Circuit;
Expand Down Expand Up @@ -38,10 +39,12 @@ pub trait CircuitTransformationPass {
self.pre_hook_circuit(&circuit);
let templates = circuit.templates.iter().map(|t| self.transform_template(t)).collect();
let field_tracking = self.get_updated_field_constants();
let bounded_loads = self.get_updated_bounded_array_loads(&circuit.llvm_data.bounded_array_loads);
let bounded_stores = self.get_updated_bounded_array_stores(&circuit.llvm_data.bounded_array_stores);
let mut new_circuit = Circuit {
wasm_producer: circuit.wasm_producer.clone(),
c_producer: circuit.c_producer.clone(),
llvm_data: circuit.llvm_data.clone_with_new_field_tracking(field_tracking),
llvm_data: circuit.llvm_data.clone_with_updates(field_tracking, bounded_loads, bounded_stores),
templates,
functions: circuit.functions.iter().map(|f| self.transform_function(f)).collect(),
};
Expand All @@ -51,6 +54,14 @@ pub trait CircuitTransformationPass {

fn get_updated_field_constants(&self) -> Vec<String>;
iangneal marked this conversation as resolved.
Show resolved Hide resolved

fn get_updated_bounded_array_loads(&self, old_array_loads: &HashSet<Range<usize>>) -> HashSet<Range<usize>> {
old_array_loads.clone()
}

fn get_updated_bounded_array_stores(&self, old_array_stores: &HashSet<Range<usize>>) -> HashSet<Range<usize>> {
old_array_stores.clone()
}

fn transform_template(&self, template: &TemplateCode) -> TemplateCode {
self.pre_hook_template(template);
Box::new(TemplateCodeInfo {
Expand Down
38 changes: 28 additions & 10 deletions circuit_passes/src/passes/unknown_index_sanitization.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashSet};
use std::ops::Range;
use compiler::circuit_design::template::TemplateCode;
use compiler::compiler_interface::Circuit;
use compiler::intermediate_representation::{Instruction, InstructionPointer, new_id};
use compiler::intermediate_representation::ir_interface::*;
use compiler::num_bigint::BigInt;
use code_producers::llvm_elements::array_switch::{get_array_load_symbol, get_array_store_symbol};
use code_producers::llvm_elements::array_switch::{get_array_load_name, get_array_store_name};
use program_structure::constants::UsefulConstants;
use crate::bucket_interpreter::env::Env;
use crate::bucket_interpreter::memory::PassMemory;
Expand Down Expand Up @@ -87,6 +87,8 @@ pub struct UnknownIndexSanitizationPass<'d> {
memory: PassMemory,
load_replacements: RefCell<BTreeMap<LoadBucket, Range<usize>>>,
store_replacements: RefCell<BTreeMap<StoreBucket, Range<usize>>>,
scheduled_bounded_loads: RefCell<HashSet<Range<usize>>>,
scheduled_bounded_stores: RefCell<HashSet<Range<usize>>>,
}

/**
Expand All @@ -99,6 +101,8 @@ impl<'d> UnknownIndexSanitizationPass<'d> {
memory: PassMemory::new(prime, "".to_string(), Default::default()),
load_replacements: Default::default(),
store_replacements: Default::default(),
scheduled_bounded_loads: Default::default(),
scheduled_bounded_stores: Default::default(),
}
}

Expand Down Expand Up @@ -176,7 +180,8 @@ impl InterpreterObserver for UnknownIndexSanitizationPass<'_> {
let location = &bucket.src;
if self.is_location_unknown(address, location, env) {
let index_range = self.find_bounds(address, location, env);
self.load_replacements.borrow_mut().insert(bucket.clone(), index_range);
self.load_replacements.borrow_mut().insert(bucket.clone(), index_range.clone());
self.scheduled_bounded_loads.borrow_mut().insert(index_range);
}
true
}
Expand All @@ -186,7 +191,8 @@ impl InterpreterObserver for UnknownIndexSanitizationPass<'_> {
let location = &bucket.dest;
if self.is_location_unknown(address, location, env) {
let index_range = self.find_bounds(address, location, env);
self.store_replacements.borrow_mut().insert(bucket.clone(), index_range);
self.store_replacements.borrow_mut().insert(bucket.clone(), index_range.clone());
self.scheduled_bounded_stores.borrow_mut().insert(index_range);
}
true
}
Expand Down Expand Up @@ -248,14 +254,30 @@ impl InterpreterObserver for UnknownIndexSanitizationPass<'_> {
}
}

fn do_array_union(a: &HashSet<Range<usize>>, b: &HashSet<Range<usize>>) -> HashSet<Range<usize>> {
a.union(b).map(|e| e.clone()).collect()
}

impl CircuitTransformationPass for UnknownIndexSanitizationPass<'_> {
fn name(&self) -> &str {
"UnknownIndexSanitizationPass"
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn get_updated_bounded_array_loads(&self, old_array_loads: &HashSet<Range<usize>>) -> HashSet<Range<usize>> {
do_array_union(old_array_loads, &self.scheduled_bounded_loads.borrow())
}

fn get_updated_bounded_array_stores(&self, old_array_stores: &HashSet<Range<usize>>) -> HashSet<Range<usize>> {
do_array_union(old_array_stores, &self.scheduled_bounded_stores.borrow())
}

fn transform_load_bucket(&self, bucket: &LoadBucket) -> InstructionPointer {
let bounded_fn_symbol = match self.load_replacements.borrow().get(bucket) {
Some(index_range) => Some(get_array_load_symbol(index_range)),
Some(index_range) => Some(get_array_load_name(index_range)),
None => bucket.bounded_fn.clone(),
};
LoadBucket {
Expand All @@ -272,7 +294,7 @@ impl CircuitTransformationPass for UnknownIndexSanitizationPass<'_> {

fn transform_store_bucket(&self, bucket: &StoreBucket) -> InstructionPointer {
let bounded_fn_symbol = match self.store_replacements.borrow().get(bucket) {
Some(index_range) => Some(get_array_store_symbol(index_range)),
Some(index_range) => Some(get_array_store_name(index_range)),
None => bucket.bounded_fn.clone(),
};
StoreBucket {
Expand All @@ -290,10 +312,6 @@ impl CircuitTransformationPass for UnknownIndexSanitizationPass<'_> {
.allocate()
}

fn get_updated_field_constants(&self) -> Vec<String> {
self.memory.get_field_constants_clone()
}

fn pre_hook_circuit(&self, circuit: &Circuit) {
self.memory.fill_from_circuit(circuit);
}
Expand Down
19 changes: 13 additions & 6 deletions code_producers/src/llvm_elements/array_switch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::Range;
use std::{ops::Range, collections::HashSet};
use inkwell::types::PointerType;
use crate::llvm_elements::LLVMIRProducer;
use super::types::bigint_type;
Expand Down Expand Up @@ -129,15 +129,22 @@ pub fn array_ptr_ty<'a>(producer: &dyn LLVMIRProducer<'a>) -> PointerType<'a> {
bigint_ty.array_type(0).ptr_type(Default::default())
}

pub fn load_array_switch<'a>(producer: &dyn LLVMIRProducer<'a>, index_range: &Range<usize>) {
array_switch::create_array_load_fn(producer, index_range);
array_switch::create_array_store_fn(producer, index_range);
pub fn load_array_load_fns<'a>(producer: &dyn LLVMIRProducer<'a>, scheduled_array_loads: &HashSet<Range<usize>>) {
for range in scheduled_array_loads {
array_switch::create_array_load_fn(producer, range);
}
}

pub fn load_array_stores_fns<'a>(producer: &dyn LLVMIRProducer<'a>, scheduled_array_stores: &HashSet<Range<usize>>) {
for range in scheduled_array_stores {
array_switch::create_array_store_fn(producer, range);
}
}

pub fn get_array_load_symbol(index_range: &Range<usize>) -> String {
pub fn get_array_load_name(index_range: &Range<usize>) -> String {
array_switch::get_load_symbol(index_range)
}

pub fn get_array_store_symbol(index_range: &Range<usize>) -> String {
pub fn get_array_store_name(index_range: &Range<usize>) -> String {
array_switch::get_store_symbol(index_range)
}
8 changes: 6 additions & 2 deletions code_producers/src/llvm_elements/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::ops::Range;
use std::rc::Rc;
Expand Down Expand Up @@ -100,16 +100,20 @@ pub struct LLVMCircuitData {
pub signal_index_mapping: HashMap<String, IndexMapping>,
pub variable_index_mapping: HashMap<String, IndexMapping>,
pub component_index_mapping: HashMap<String, IndexMapping>,
pub bounded_array_loads: HashSet<Range<usize>>,
pub bounded_array_stores: HashSet<Range<usize>>,
}

impl LLVMCircuitData {
pub fn clone_with_new_field_tracking(&self, field_tracking: Vec<String>) -> Self {
pub fn clone_with_updates(&self, field_tracking: Vec<String>, array_loads: HashSet<Range<usize>>, array_stores: HashSet<Range<usize>>) -> Self {
LLVMCircuitData {
field_tracking,
io_map: self.io_map.clone(),
signal_index_mapping: self.signal_index_mapping.clone(),
variable_index_mapping: self.variable_index_mapping.clone(),
component_index_mapping: self.component_index_mapping.clone(),
bounded_array_loads: array_loads,
bounded_array_stores: array_stores,
}
}
}
Expand Down
34 changes: 10 additions & 24 deletions compiler/src/circuit_design/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::io::Write;
use super::function::{FunctionCode, FunctionCodeInfo};
use super::template::{TemplateCode, TemplateCodeInfo};
Expand All @@ -7,8 +7,8 @@ use crate::hir::very_concrete_program::VCP;
use crate::intermediate_representation::ir_interface::ObtainMeta;
use crate::translating_traits::*;
use code_producers::c_elements::*;
use code_producers::llvm_elements::array_switch::{load_array_stores_fns, load_array_load_fns};
use code_producers::llvm_elements::*;
use code_producers::llvm_elements::array_switch::load_array_switch;
use code_producers::llvm_elements::fr::load_fr;
use code_producers::llvm_elements::functions::{
create_function, FunctionLLVMIRProducer, ExtractedFunctionLLVMIRProducer,
Expand Down Expand Up @@ -55,23 +55,9 @@ impl WriteLLVMIR for Circuit {
load_fr(producer);
load_stdlib(producer);

// Generate all the switch functions
let mut ranges = HashSet::new();
let mappings = [
&self.llvm_data.signal_index_mapping,
&self.llvm_data.variable_index_mapping,
&self.llvm_data.component_index_mapping,
];
for mapping in mappings {
for range_mapping in mapping.values() {
for range in range_mapping.values() {
ranges.insert(range);
}
}
}
for range in ranges {
load_array_switch(producer, range);
}
// Code for bounded array switch functions
load_array_load_fns(producer, &self.llvm_data.bounded_array_loads);
load_array_stores_fns(producer, &self.llvm_data.bounded_array_stores);

// Declare all the functions
let mut funcs = HashMap::new();
Expand Down Expand Up @@ -229,7 +215,7 @@ impl WriteWasm for Circuit {
code.append(&mut code_aux);

code_aux = get_input_size_generator(&producer);
code.append(&mut code_aux);
code.append(&mut code_aux);

code_aux = get_witness_size_generator(&producer);
code.append(&mut code_aux);
Expand Down Expand Up @@ -369,7 +355,7 @@ impl WriteWasm for Circuit {
code = merge_code(code_aux);
writer.write_all(code.as_bytes()).map_err(|_| {})?;
writer.flush().map_err(|_| {})?;

code_aux = get_witness_size_generator(&producer);
code = merge_code(code_aux);
writer.write_all(code.as_bytes()).map_err(|_| {})?;
Expand Down Expand Up @@ -460,7 +446,7 @@ impl WriteC for Circuit {
std::mem::drop(function_headers);

let (func_list_no_parallel, func_list_parallel) = generate_function_list(
producer,
producer,
producer.get_template_instance_list()
);

Expand All @@ -476,7 +462,7 @@ impl WriteC for Circuit {
"uint get_main_input_signal_start() {{return {};}}\n",
producer.get_number_of_main_outputs()
));

code.push(format!(
"uint get_main_input_signal_no() {{return {};}}\n",
producer.get_number_of_main_inputs()
Expand All @@ -503,7 +489,7 @@ impl WriteC for Circuit {
producer.get_io_map().len()
));
//code.append(&mut generate_message_list_def(producer, producer.get_message_list()));

// Functions to release the memory
let mut release_component_code = generate_function_release_memory_component();
code.append(&mut release_component_code);
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/circuit_design/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ impl ToString for TemplateCodeInfo {

impl WriteLLVMIR for TemplateCodeInfo {
fn produce_llvm_ir<'ctx, 'prod>(&self, producer: &'prod dyn LLVMIRProducer<'ctx>) -> Option<LLVMInstruction<'ctx>> {
println!("Generating code for {}", self.header);
if cfg!(debug_assertions) {
println!("Generating code for {}", self.header);
}
let void = void_type(producer);
let n_signals = self.number_of_inputs + self.number_of_outputs + self.number_of_intermediates;
let template_struct = create_template_struct(producer, n_signals);
Expand Down
Loading