Skip to content

Commit

Permalink
address CR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
tim-hoffman committed Oct 18, 2023
1 parent 5d19490 commit dbfc8a6
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use compiler::intermediate_representation::{InstructionPointer, new_id};
use compiler::intermediate_representation::ir_interface::*;
use super::loop_unroll::body_extractor::AddressOffset;

pub fn new_u32_value(bucket: &dyn ObtainMeta, val: usize) -> InstructionPointer {
pub fn build_u32_value(bucket: &dyn ObtainMeta, val: usize) -> InstructionPointer {
ValueBucket {
id: new_id(),
source_file_id: bucket.get_source_file_id().clone(),
Expand All @@ -17,7 +17,7 @@ pub fn new_u32_value(bucket: &dyn ObtainMeta, val: usize) -> InstructionPointer
.allocate()
}

pub fn new_call(
pub fn build_call(
meta: &dyn ObtainMeta,
name: impl Into<String>,
args: Vec<InstructionPointer>,
Expand All @@ -36,7 +36,7 @@ pub fn new_call(
.allocate()
}

pub fn new_custom_fn_load_bucket(
pub fn build_custom_fn_load_bucket(
bucket: &dyn ObtainMeta,
load_fun: &str,
addr_type: AddressType,
Expand All @@ -54,36 +54,39 @@ pub fn new_custom_fn_load_bucket(
.allocate()
}

pub fn new_storage_ptr_ref(bucket: &dyn ObtainMeta, addr_type: AddressType) -> InstructionPointer {
new_custom_fn_load_bucket(
pub fn build_storage_ptr_ref(
bucket: &dyn ObtainMeta,
addr_type: AddressType,
) -> InstructionPointer {
build_custom_fn_load_bucket(
bucket,
FR_IDENTITY_ARR_PTR,
addr_type,
new_u32_value(bucket, 0), //use index 0 to ref the entire storage array
build_u32_value(bucket, 0), //use index 0 to ref the entire storage array
)
}

//NOTE: When the 'bounded_fn' for LoadBucket is Some(_), the index parameter
// is ignored so we must instead use `FR_INDEX_ARR_PTR` to apply the index.
// Uses of that function can be inlined later.
// NOTE: Must start with `GENERATED_FN_PREFIX` to use `ExtractedFunctionCtx`
pub fn new_indexed_storage_ptr_ref(
pub fn build_indexed_storage_ptr_ref(
bucket: &dyn ObtainMeta,
addr_type: AddressType,
index: AddressOffset,
) -> InstructionPointer {
new_call(
build_call(
bucket,
FR_INDEX_ARR_PTR,
vec![new_storage_ptr_ref(bucket, addr_type), new_u32_value(bucket, index)],
vec![build_storage_ptr_ref(bucket, addr_type), build_u32_value(bucket, index)],
)
}

pub fn new_subcmp_counter_storage_ptr_ref(
pub fn build_subcmp_counter_storage_ptr_ref(
bucket: &dyn ObtainMeta,
sub_cmp_id: InstructionPointer,
) -> InstructionPointer {
new_custom_fn_load_bucket(
build_custom_fn_load_bucket(
bucket,
FR_PTR_CAST_I32_I256,
AddressType::SubcmpSignal {
Expand All @@ -93,33 +96,10 @@ pub fn new_subcmp_counter_storage_ptr_ref(
input_information: InputInformation::NoInput,
counter_override: true,
},
new_u32_value(bucket, usize::MAX), //index is ignored for these
build_u32_value(bucket, usize::MAX), //index is ignored for these
)
}

pub fn new_null_ptr(bucket: &dyn ObtainMeta, null_fun: &str) -> InstructionPointer {
new_call(bucket, null_fun, vec![])
}

pub fn all_same<T>(data: T) -> bool
where
T: Iterator,
T::Item: PartialEq,
{
data.fold((true, None), {
|acc, elem| {
if acc.1.is_some() {
(acc.0 && (acc.1.unwrap() == elem), Some(elem))
} else {
(true, Some(elem))
}
}
})
.0
}

pub fn new_filled_vec<T: Clone>(new_len: usize, value: T) -> Vec<T> {
let mut result = Vec::with_capacity(new_len);
result.resize(new_len, value);
result
pub fn build_null_ptr(bucket: &dyn ObtainMeta, null_fun: &str) -> InstructionPointer {
build_call(bucket, null_fun, vec![])
}
18 changes: 18 additions & 0 deletions circuit_passes/src/passes/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,21 @@ pub fn assert_unique_ids_in_circuit(circuit: &Circuit) {
assert_unique_ids_in_function(function, &mut visited);
}
}

/// Return true iff all elements returned by the given Iterator are equal.
pub fn all_same<T>(data: T) -> bool
where
T: Iterator,
T::Item: PartialEq,
{
data.fold((true, None), {
|acc, elem| {
if acc.1.is_some() {
(acc.0 && (acc.1.unwrap() == elem), Some(elem))
} else {
(true, Some(elem))
}
}
})
.0
}
6 changes: 3 additions & 3 deletions circuit_passes/src/passes/const_arg_deduplication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use compiler::intermediate_representation::{InstructionPointer, BucketId, Update
use compiler::intermediate_representation::ir_interface::*;
use compiler::intermediate_representation::translate::ARRAY_PARAM_STORES;
use crate::bucket_interpreter::memory::PassMemory;
use super::{CircuitTransformationPass, GlobalPassData, utils};
use super::{CircuitTransformationPass, GlobalPassData, builders};

pub struct ConstArgDeduplicationPass<'d> {
_global_data: &'d RefCell<GlobalPassData>,
Expand Down Expand Up @@ -181,10 +181,10 @@ impl CircuitTransformationPass for ConstArgDeduplicationPass<'_> {

// Generate a call to the extracted function
let meta_info: &dyn ObtainMeta = &**const_stores[0];
new_body.push(utils::new_call(
new_body.push(builders::build_call(
meta_info,
self.get_or_create_function_for(meta_info, idx_val_pairs, const_stores),
vec![utils::new_storage_ptr_ref(meta_info, AddressType::Variable)],
vec![builders::build_storage_ptr_ref(meta_info, AddressType::Variable)],
));
}
}
Expand Down
38 changes: 22 additions & 16 deletions circuit_passes/src/passes/loop_unroll/body_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::bucket_interpreter::value::Value;
use crate::passes::loop_unroll::LOOP_BODY_FN_PREFIX;
use crate::passes::loop_unroll::extracted_location_updater::ExtractedFunctionLocationUpdater;
use crate::passes::loop_unroll::loop_env_recorder::EnvRecorder;
use crate::passes::utils;
use crate::passes::{builders, checks};

pub type FuncArgIdx = usize;
pub type AddressOffset = usize;
Expand Down Expand Up @@ -109,6 +109,12 @@ pub struct LoopBodyExtractor {
}

impl LoopBodyExtractor {
fn new_filled_vec<T: Clone>(new_len: usize, value: T) -> Vec<T> {
let mut result = Vec::with_capacity(new_len);
result.resize(new_len, value);
result
}

pub fn get_new_functions(&self) -> Ref<Vec<FunctionCode>> {
self.new_body_functions.borrow()
}
Expand All @@ -132,42 +138,42 @@ impl LoopBodyExtractor {
// instruction to use these pointers as parameters to the function so we must use the
// `bounded_fn` field of the LoadBucket to specify the identity function to perform
// the "loading" (but really it just returns the pointer that was passed in).
let mut args = utils::new_filled_vec(
let mut args = Self::new_filled_vec(
extra_arg_info.num_args,
NopBucket { id: 0 }.allocate(), // garbage fill
);
// Parameter for local vars
args[0] = utils::new_storage_ptr_ref(bucket, AddressType::Variable);
args[0] = builders::build_storage_ptr_ref(bucket, AddressType::Variable);
// Parameter for signals/arena
args[1] = utils::new_storage_ptr_ref(bucket, AddressType::Signal);
args[1] = builders::build_storage_ptr_ref(bucket, AddressType::Signal);
// Additional parameters for subcmps and variant array indexing within the loop
for (loc, ai) in extra_arg_info.get_passing_refs_for_itr(iter_num) {
match loc {
None => match ai {
ArgIndex::Signal(signal) => {
args[signal] = utils::new_null_ptr(bucket, FR_NULL_I256_PTR);
args[signal] = builders::build_null_ptr(bucket, FR_NULL_I256_PTR);
}
ArgIndex::SubCmp { signal, arena, counter } => {
args[signal] = utils::new_null_ptr(bucket, FR_NULL_I256_PTR);
args[arena] = utils::new_null_ptr(bucket, FR_NULL_I256_ARR_PTR);
args[counter] = utils::new_null_ptr(bucket, FR_NULL_I256_PTR);
args[signal] = builders::build_null_ptr(bucket, FR_NULL_I256_PTR);
args[arena] = builders::build_null_ptr(bucket, FR_NULL_I256_ARR_PTR);
args[counter] = builders::build_null_ptr(bucket, FR_NULL_I256_PTR);
}
},
Some((at, val)) => match ai {
ArgIndex::Signal(signal) => {
args[signal] =
utils::new_indexed_storage_ptr_ref(bucket, at.clone(), *val)
builders::build_indexed_storage_ptr_ref(bucket, at.clone(), *val)
}
ArgIndex::SubCmp { signal, arena, counter } => {
// Pass specific signal referenced
args[signal] =
utils::new_indexed_storage_ptr_ref(bucket, at.clone(), *val);
builders::build_indexed_storage_ptr_ref(bucket, at.clone(), *val);
// Pass entire subcomponent arena for calling the 'template_run' function
args[arena] = utils::new_storage_ptr_ref(bucket, at.clone());
args[arena] = builders::build_storage_ptr_ref(bucket, at.clone());
// Pass subcomponent counter reference
if let AddressType::SubcmpSignal { cmp_address, .. } = &at {
//TODO: may only need to add this when is_output=true but have to skip adding the Param too in that case.
args[counter] = utils::new_subcmp_counter_storage_ptr_ref(
args[counter] = builders::build_subcmp_counter_storage_ptr_ref(
bucket,
cmp_address.clone(),
);
Expand All @@ -178,7 +184,7 @@ impl LoopBodyExtractor {
},
}
}
unrolled.push(utils::new_call(bucket, &name, args));
unrolled.push(builders::build_call(bucket, &name, args));

recorder.record_reverse_arg_mapping(
name.clone(),
Expand All @@ -197,7 +203,7 @@ impl LoopBodyExtractor {
// NOTE: must create parameter list before 'bucket_to_args' is modified
// Since the ArgIndex instances could have indices in any random order,
// create the vector of required size and then set elements by index.
let mut params = utils::new_filled_vec(
let mut params = Self::new_filled_vec(
num_args,
Param { name: String::from("EMPTY"), length: vec![usize::MAX] },
);
Expand Down Expand Up @@ -306,7 +312,7 @@ impl LoopBodyExtractor {
column.push(temp.map(|(a, v)| (a.clone(), v.get_u32())));
}
// ASSERT: same AddressType kind for this bucket in every (available) iteration
assert!(utils::all_same(
assert!(checks::all_same(
column.iter().filter_map(|x| x.as_ref()).map(|x| std::mem::discriminant(&x.0))
));

Expand All @@ -315,7 +321,7 @@ impl LoopBodyExtractor {
// Actually, check not only the computed index Value but the AddressType as well to capture when
// it's a SubcmpSignal referencing a different subcomponent (the AddressType::cmp_address field
// was also interpreted within the EnvRecorder so this comparison will be accurate).
if !utils::all_same(column.iter().filter_map(|x| x.as_ref())) {
if !checks::all_same(column.iter().filter_map(|x| x.as_ref())) {
bucket_to_args.insert(*id, ArgIndex::Signal(next_idx));
next_idx += 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use indexmap::IndexMap;
use code_producers::llvm_elements::fr::FR_IDENTITY_ARR_PTR;
use compiler::intermediate_representation::{BucketId, InstructionPointer, new_id};
use compiler::intermediate_representation::ir_interface::*;
use crate::passes::utils::new_u32_value;
use crate::passes::builders::build_u32_value;

use super::body_extractor::ArgIndex;

Expand All @@ -27,14 +27,14 @@ impl ExtractedFunctionLocationUpdater {
// parameters with those. So this has to use SubcmpSignal (it should
// work fine because subcomps will also just be additional params).
bucket.address_type = AddressType::SubcmpSignal {
cmp_address: new_u32_value(bucket, ai.get_signal_idx()),
cmp_address: build_u32_value(bucket, ai.get_signal_idx()),
uniform_parallel_value: None,
counter_override: false,
is_output: false,
input_information: InputInformation::NoInput,
};
bucket.src = LocationRule::Indexed {
location: new_u32_value(bucket, 0), //use index 0 to ref the entire storage array
location: build_u32_value(bucket, 0), //use index 0 to ref the entire storage array
template_header: None,
};
} else {
Expand Down Expand Up @@ -65,7 +65,7 @@ impl ExtractedFunctionLocationUpdater {
context: bucket.context.clone(),
dest_is_output: bucket.dest_is_output,
dest_address_type: AddressType::SubcmpSignal {
cmp_address: new_u32_value(bucket, arena),
cmp_address: build_u32_value(bucket, arena),
uniform_parallel_value: None,
counter_override: false,
is_output: false,
Expand All @@ -81,15 +81,15 @@ impl ExtractedFunctionLocationUpdater {
},
},
dest: LocationRule::Indexed {
location: new_u32_value(bucket, 0), //the value here is ignored by the 'bounded_fn' below
location: build_u32_value(bucket, 0), //the value here is ignored by the 'bounded_fn' below
template_header: match &bucket.dest {
LocationRule::Indexed { template_header, .. } => {
template_header.clone()
}
LocationRule::Mapped { .. } => todo!(),
},
},
src: new_u32_value(bucket, 0), //the value here is ignored at runtime
src: build_u32_value(bucket, 0), //the value here is ignored at runtime
bounded_fn: Some(String::from(FR_IDENTITY_ARR_PTR)), //NOTE: doesn't have enough arguments but it works out
}
.allocate(),
Expand All @@ -101,14 +101,14 @@ impl ExtractedFunctionLocationUpdater {

//Transform this bucket into the normal fixed-index signal reference
bucket.dest_address_type = AddressType::SubcmpSignal {
cmp_address: new_u32_value(bucket, ai.get_signal_idx()),
cmp_address: build_u32_value(bucket, ai.get_signal_idx()),
uniform_parallel_value: None,
counter_override: false,
is_output: false,
input_information: InputInformation::NoInput,
};
bucket.dest = LocationRule::Indexed {
location: new_u32_value(bucket, 0), //use index 0 to ref the entire storage array
location: build_u32_value(bucket, 0), //use index 0 to ref the entire storage array
template_header: None,
};
} else {
Expand Down
2 changes: 1 addition & 1 deletion circuit_passes/src/passes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod mapped_to_indexed;
mod unknown_index_sanitization;
mod checks;
pub mod loop_unroll;
pub mod utils;
pub mod builders;

macro_rules! pre_hook {
($name: ident, $bucket_ty: ty) => {
Expand Down

0 comments on commit dbfc8a6

Please sign in to comment.