diff --git a/language/bytecode_verifier/invalid_mutations/src/bounds.rs b/language/bytecode_verifier/invalid_mutations/src/bounds.rs index 164d3f63319a..d5064fccf964 100644 --- a/language/bytecode_verifier/invalid_mutations/src/bounds.rs +++ b/language/bytecode_verifier/invalid_mutations/src/bounds.rs @@ -10,9 +10,9 @@ use std::collections::BTreeMap; use vm::{ errors::{VMStaticViolation, VerificationError}, file_format::{ - AddressPoolIndex, CompiledModule, FieldDefinitionIndex, FunctionHandleIndex, - FunctionSignatureIndex, LocalsSignatureIndex, ModuleHandleIndex, StringPoolIndex, - StructHandleIndex, TableIndex, TypeSignatureIndex, + AddressPoolIndex, CompiledModule, CompiledModuleMut, FieldDefinitionIndex, + FunctionHandleIndex, FunctionSignatureIndex, LocalsSignatureIndex, ModuleHandleIndex, + StringPoolIndex, StructHandleIndex, TableIndex, TypeSignatureIndex, }, internals::ModuleIndex, views::{ModuleView, SignatureTokenView}, @@ -162,8 +162,8 @@ impl AsRef for OutOfBoundsMutation { } } -pub struct ApplyOutOfBoundsContext<'a> { - module: &'a mut CompiledModule, +pub struct ApplyOutOfBoundsContext { + module: CompiledModuleMut, // This is an Option because it gets moved out in apply before apply_one is called. Rust // doesn't let you call another con-consuming method after a partial move out. mutations: Option>, @@ -174,14 +174,14 @@ pub struct ApplyOutOfBoundsContext<'a> { locals_sig_structs: Vec<(LocalsSignatureIndex, usize)>, } -impl<'a> ApplyOutOfBoundsContext<'a> { - pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { - let type_sig_structs: Vec<_> = Self::type_sig_structs(module).collect(); - let function_sig_structs: Vec<_> = Self::function_sig_structs(module).collect(); - let locals_sig_structs: Vec<_> = Self::locals_sig_structs(module).collect(); +impl ApplyOutOfBoundsContext { + pub fn new(module: CompiledModule, mutations: Vec) -> Self { + let type_sig_structs: Vec<_> = Self::type_sig_structs(&module).collect(); + let function_sig_structs: Vec<_> = Self::function_sig_structs(&module).collect(); + let locals_sig_structs: Vec<_> = Self::locals_sig_structs(&module).collect(); Self { - module, + module: module.into_inner(), mutations: Some(mutations), type_sig_structs, function_sig_structs, @@ -189,7 +189,7 @@ impl<'a> ApplyOutOfBoundsContext<'a> { } } - pub fn apply(mut self) -> Vec { + pub fn apply(mut self) -> (CompiledModuleMut, Vec) { // This is a map from (source kind, dest kind) to the actual mutations -- this is done to // figure out how many mutations to do for a particular pair, which is required for // pick_slice_idxs below. @@ -212,7 +212,7 @@ impl<'a> ApplyOutOfBoundsContext<'a> { // to get the lifetimes right :) results.extend(self.apply_one(src_kind, dst_kind, mutations)); } - results + (self.module, results) } fn apply_one( diff --git a/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs b/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs index 0c944b25ba8b..4d6e64ce42f5 100644 --- a/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs +++ b/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs @@ -7,7 +7,7 @@ use std::collections::BTreeMap; use vm::{ errors::{VMStaticViolation, VerificationError}, file_format::{ - AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeOffset, CompiledModule, + AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeOffset, CompiledModuleMut, FieldDefinitionIndex, FunctionHandleIndex, LocalIndex, StringPoolIndex, StructDefinitionIndex, TableIndex, }, @@ -43,7 +43,7 @@ impl AsRef for CodeUnitBoundsMutation { } pub struct ApplyCodeUnitBoundsContext<'a> { - module: &'a mut CompiledModule, + module: &'a mut CompiledModuleMut, // This is so apply_one can be called after mutations has been iterated on. mutations: Option>, } @@ -82,7 +82,7 @@ macro_rules! locals_bytecode { } impl<'a> ApplyCodeUnitBoundsContext<'a> { - pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + pub fn new(module: &'a mut CompiledModuleMut, mutations: Vec) -> Self { Self { module, mutations: Some(mutations), diff --git a/language/bytecode_verifier/invalid_mutations/src/signature.rs b/language/bytecode_verifier/invalid_mutations/src/signature.rs index 99698d1c2c6f..8faed3027d94 100644 --- a/language/bytecode_verifier/invalid_mutations/src/signature.rs +++ b/language/bytecode_verifier/invalid_mutations/src/signature.rs @@ -9,7 +9,7 @@ use proptest_helpers::{pick_slice_idxs, RepeatVec}; use std::collections::BTreeMap; use vm::{ errors::{VMStaticViolation, VerificationError}, - file_format::{CompiledModule, SignatureToken}, + file_format::{CompiledModuleMut, SignatureToken}, internals::ModuleIndex, IndexKind, SignatureTokenKind, }; @@ -38,12 +38,12 @@ impl AsRef for DoubleRefMutation { /// Context for applying a list of `DoubleRefMutation` instances. pub struct ApplySignatureDoubleRefContext<'a> { - module: &'a mut CompiledModule, + module: &'a mut CompiledModuleMut, mutations: Vec, } impl<'a> ApplySignatureDoubleRefContext<'a> { - pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + pub fn new(module: &'a mut CompiledModuleMut, mutations: Vec) -> Self { Self { module, mutations } } @@ -134,12 +134,12 @@ impl AsRef for FieldRefMutation { /// Context for applying a list of `FieldRefMutation` instances. pub struct ApplySignatureFieldRefContext<'a> { - module: &'a mut CompiledModule, + module: &'a mut CompiledModuleMut, mutations: Vec, } impl<'a> ApplySignatureFieldRefContext<'a> { - pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + pub fn new(module: &'a mut CompiledModuleMut, mutations: Vec) -> Self { Self { module, mutations } } diff --git a/language/bytecode_verifier/src/check_duplication.rs b/language/bytecode_verifier/src/check_duplication.rs index d72d3f713246..2f15a3ba631d 100644 --- a/language/bytecode_verifier/src/check_duplication.rs +++ b/language/bytecode_verifier/src/check_duplication.rs @@ -14,7 +14,7 @@ use vm::{ errors::{VMStaticViolation, VerificationError}, file_format::{ CompiledModule, FieldDefinitionIndex, FunctionHandleIndex, ModuleHandleIndex, - StructHandleIndex, + StructHandleIndex, TableIndex, }, IndexKind, }; @@ -138,9 +138,14 @@ impl<'a> DuplicationChecker<'a> { break; } let next_start_field_index = start_field_index + struct_def.field_count as usize; - if !(start_field_index..next_start_field_index) - .all(|i| struct_def.struct_handle == self.module.field_defs[i].struct_) - { + let all_fields_match = (start_field_index..next_start_field_index).all(|i| { + struct_def.struct_handle + == self + .module + .field_def_at(FieldDefinitionIndex::new(i as TableIndex)) + .struct_ + }); + if !all_fields_match { idx_opt = Some(idx); break; } @@ -152,7 +157,7 @@ impl<'a> DuplicationChecker<'a> { idx, err: VMStaticViolation::InconsistentFields, }); - } else if start_field_index != self.module.field_defs.len() { + } else if start_field_index != self.module.field_defs().len() { errors.push(VerificationError { kind: IndexKind::FieldDefinition, idx: start_field_index, @@ -188,7 +193,7 @@ impl<'a> DuplicationChecker<'a> { // implemented. let implemented_struct_handles: HashSet = self.module.struct_defs().map(|x| x.struct_handle).collect(); - if let Some(idx) = (0..self.module.struct_handles.len()).position(|x| { + if let Some(idx) = (0..self.module.struct_handles().len()).position(|x| { let y = StructHandleIndex::new(x as u16); self.module.struct_handle_at(y).module == ModuleHandleIndex::new(CompiledModule::IMPLEMENTED_MODULE_INDEX) @@ -204,7 +209,7 @@ impl<'a> DuplicationChecker<'a> { // implemented. let implemented_function_handles: HashSet = self.module.function_defs().map(|x| x.function).collect(); - if let Some(idx) = (0..self.module.function_handles.len()).position(|x| { + if let Some(idx) = (0..self.module.function_handles().len()).position(|x| { let y = FunctionHandleIndex::new(x as u16); self.module.function_handle_at(y).module == ModuleHandleIndex::new(CompiledModule::IMPLEMENTED_MODULE_INDEX) diff --git a/language/bytecode_verifier/src/struct_defs.rs b/language/bytecode_verifier/src/struct_defs.rs index 1570936db511..89be7ce1c73e 100644 --- a/language/bytecode_verifier/src/struct_defs.rs +++ b/language/bytecode_verifier/src/struct_defs.rs @@ -61,7 +61,7 @@ impl<'a> StructDefGraphBuilder<'a> { let mut handle_to_def = BTreeMap::new(); // the mapping from struct definitions to struct handles is already checked to be 1-1 by // DuplicationChecker - for (idx, struct_def) in module.struct_defs.iter().enumerate() { + for (idx, struct_def) in module.struct_defs().enumerate() { let sh_idx = struct_def.struct_handle; handle_to_def.insert(sh_idx, StructDefinitionIndex::new(idx as TableIndex)); } @@ -75,7 +75,7 @@ impl<'a> StructDefGraphBuilder<'a> { pub fn build(self) -> Graph { let mut graph = Graph::new(); - let struct_def_count = self.module.struct_defs.len(); + let struct_def_count = self.module.struct_defs().len(); let nodes: Vec<_> = (0..struct_def_count) .map(|idx| graph.add_node(StructDefinitionIndex::new(idx as TableIndex))) diff --git a/language/bytecode_verifier/src/verifier.rs b/language/bytecode_verifier/src/verifier.rs index 2145e2094f67..05284abb689a 100644 --- a/language/bytecode_verifier/src/verifier.rs +++ b/language/bytecode_verifier/src/verifier.rs @@ -11,7 +11,6 @@ use std::collections::BTreeMap; use types::language_storage::CodeKey; use vm::{ access::{BaseAccess, ScriptAccess}, - checks::BoundsChecker, errors::{VMStaticViolation, VerificationError}, file_format::{CompiledModule, CompiledScript}, resolver::Resolver, @@ -19,15 +18,14 @@ use vm::{ IndexKind, }; -/// Verification of a module is performed through a sequnence of checks. -/// There is a partial order on the checs. For example, bounds checking must precede all other -/// checks and duplication check must precede the structural recursion check. In general, later -/// checks are more expensive. +/// Verification of a module is performed through a sequence of checks. +/// +/// There is a partial order on the checks. For example, the duplication check must precede the +/// structural recursion check. In general, later checks are more expensive. pub fn verify_module(module: CompiledModule) -> (CompiledModule, Vec) { - let mut errors = BoundsChecker::new(&module).verify(); - if errors.is_empty() { - errors.append(&mut DuplicationChecker::new(&module).verify()); - } + // All CompiledModule instances are statically guaranteed to be bounds checked, so there's no + // need for more checking. + let mut errors = DuplicationChecker::new(&module).verify(); if errors.is_empty() { errors.append(&mut SignatureChecker::new(&module).verify()); errors.append(&mut ResourceTransitiveChecker::new(&module).verify()); diff --git a/language/bytecode_verifier/tests/bounds_tests.rs b/language/bytecode_verifier/tests/bounds_tests.rs index 670efea0c461..37e490b5c0f4 100644 --- a/language/bytecode_verifier/tests/bounds_tests.rs +++ b/language/bytecode_verifier/tests/bounds_tests.rs @@ -10,15 +10,15 @@ use types::{account_address::AccountAddress, byte_array::ByteArray}; use vm::{ checks::BoundsChecker, errors::{VMStaticViolation, VerificationError}, + file_format::{CompiledModule, CompiledModuleMut}, proptest_types::CompiledModuleStrategyGen, - CompiledModule, IndexKind, + IndexKind, }; proptest! { #[test] - fn valid_bounds(module in CompiledModule::valid_strategy(20)) { - let bounds_checker = BoundsChecker::new(&module); - prop_assert_eq!(bounds_checker.verify(), vec![]); + fn valid_bounds(_module in CompiledModule::valid_strategy(20)) { + // valid_strategy will panic if there are any bounds check issues. } } @@ -30,9 +30,8 @@ proptest! { fn valid_bounds_no_members() { let mut gen = CompiledModuleStrategyGen::new(20); gen.member_count(0); - proptest!(|(module in gen.generate())| { - let bounds_checker = BoundsChecker::new(&module); - prop_assert_eq!(bounds_checker.verify(), vec![]) + proptest!(|(_module in gen.generate())| { + // gen.generate() will panic if there are any bounds check issues. }); } @@ -42,9 +41,8 @@ proptest! { module in CompiledModule::valid_strategy(20), oob_mutations in vec(OutOfBoundsMutation::strategy(), 0..40), ) { - let mut module = module; - let mut expected_violations = { - let oob_context = ApplyOutOfBoundsContext::new(&mut module, oob_mutations); + let (module, mut expected_violations) = { + let oob_context = ApplyOutOfBoundsContext::new(module, oob_mutations); oob_context.apply() }; expected_violations.sort(); @@ -60,7 +58,7 @@ proptest! { module in CompiledModule::valid_strategy(20), mutations in vec(CodeUnitBoundsMutation::strategy(), 0..40), ) { - let mut module = module; + let mut module = module.into_inner(); let mut expected_violations = { let context = ApplyCodeUnitBoundsContext::new(&mut module, mutations); context.apply() @@ -81,7 +79,7 @@ proptest! { ) { // If there are no module handles, the only other things that can be stored are intrinsic // data. - let mut module = CompiledModule::default(); + let mut module = CompiledModuleMut::default(); module.string_pool = string_pool; module.address_pool = address_pool; module.byte_array_pool = byte_array_pool; @@ -108,8 +106,7 @@ proptest! { /// Make sure that garbage inputs don't crash the bounds checker. #[test] - fn garbage_inputs(module in any_with::(16)) { - let bounds_checker = BoundsChecker::new(&module); - bounds_checker.verify(); + fn garbage_inputs(module in any_with::(16)) { + let _ = module.freeze(); } } diff --git a/language/bytecode_verifier/tests/duplication_tests.rs b/language/bytecode_verifier/tests/duplication_tests.rs index 1e329d56bddf..2eb8d2cc6874 100644 --- a/language/bytecode_verifier/tests/duplication_tests.rs +++ b/language/bytecode_verifier/tests/duplication_tests.rs @@ -3,12 +3,11 @@ use bytecode_verifier::DuplicationChecker; use proptest::prelude::*; -use vm::{checks::BoundsChecker, file_format::CompiledModule}; +use vm::file_format::CompiledModule; proptest! { #[test] fn valid_duplication(module in CompiledModule::valid_strategy(20)) { - prop_assert!(BoundsChecker::new(&module).verify().is_empty()); let duplication_checker = DuplicationChecker::new(&module); prop_assert!(!duplication_checker.verify().is_empty()); } diff --git a/language/bytecode_verifier/tests/resources_tests.rs b/language/bytecode_verifier/tests/resources_tests.rs index 43da53dea6d3..7423357831c6 100644 --- a/language/bytecode_verifier/tests/resources_tests.rs +++ b/language/bytecode_verifier/tests/resources_tests.rs @@ -3,12 +3,11 @@ use bytecode_verifier::ResourceTransitiveChecker; use proptest::prelude::*; -use vm::{checks::BoundsChecker, file_format::CompiledModule}; +use vm::file_format::CompiledModule; proptest! { #[test] fn valid_resource_transitivity(module in CompiledModule::valid_strategy(20)) { - prop_assert!(BoundsChecker::new(&module).verify().is_empty()); let resource_checker = ResourceTransitiveChecker::new(&module); prop_assert!(resource_checker.verify().is_empty()); } diff --git a/language/bytecode_verifier/tests/signature_tests.rs b/language/bytecode_verifier/tests/signature_tests.rs index 8b6082f07726..dcfc52d2af7e 100644 --- a/language/bytecode_verifier/tests/signature_tests.rs +++ b/language/bytecode_verifier/tests/signature_tests.rs @@ -7,12 +7,11 @@ use invalid_mutations::signature::{ FieldRefMutation, }; use proptest::{collection::vec, prelude::*}; -use vm::{checks::BoundsChecker, errors::VMStaticViolation, file_format::CompiledModule}; +use vm::{errors::VMStaticViolation, file_format::CompiledModule}; proptest! { #[test] fn valid_signatures(module in CompiledModule::valid_strategy(20)) { - prop_assert!(BoundsChecker::new(&module).verify().is_empty()); let signature_checker = SignatureChecker::new(&module); prop_assert_eq!(signature_checker.verify(), vec![]); } @@ -22,14 +21,14 @@ proptest! { module in CompiledModule::valid_strategy(20), mutations in vec(DoubleRefMutation::strategy(), 0..40), ) { - let mut module = module; + let mut module = module.into_inner(); let mut expected_violations = { let context = ApplySignatureDoubleRefContext::new(&mut module, mutations); context.apply() }; expected_violations.sort(); + let module = module.freeze().expect("should satisfy bounds checker"); - prop_assert!(BoundsChecker::new(&module).verify().is_empty()); let signature_checker = SignatureChecker::new(&module); let actual_violations = signature_checker.verify(); @@ -52,14 +51,14 @@ proptest! { module in CompiledModule::valid_strategy(20), mutations in vec(FieldRefMutation::strategy(), 0..40), ) { - let mut module = module; + let mut module = module.into_inner(); let mut expected_violations = { let context = ApplySignatureFieldRefContext::new(&mut module, mutations); context.apply() }; expected_violations.sort(); + let module = module.freeze().expect("should satisfy bounds checker"); - prop_assert!(BoundsChecker::new(&module).verify().is_empty()); let signature_checker = SignatureChecker::new(&module); let mut actual_violations = signature_checker.verify(); diff --git a/language/bytecode_verifier/tests/struct_defs_tests.rs b/language/bytecode_verifier/tests/struct_defs_tests.rs index 987343529395..1b820f5b33bb 100644 --- a/language/bytecode_verifier/tests/struct_defs_tests.rs +++ b/language/bytecode_verifier/tests/struct_defs_tests.rs @@ -3,12 +3,11 @@ use bytecode_verifier::RecursiveStructDefChecker; use proptest::prelude::*; -use vm::{checks::BoundsChecker, file_format::CompiledModule}; +use vm::file_format::CompiledModule; proptest! { #[test] fn valid_recursive_struct_defs(module in CompiledModule::valid_strategy(20)) { - prop_assert!(BoundsChecker::new(&module).verify().is_empty()); let recursive_checker = RecursiveStructDefChecker::new(&module); prop_assert!(recursive_checker.verify().is_empty()); } diff --git a/language/compiler/src/compiler.rs b/language/compiler/src/compiler.rs index 3b0492c50e57..41a0881aaf88 100644 --- a/language/compiler/src/compiler.rs +++ b/language/compiler/src/compiler.rs @@ -21,15 +21,16 @@ use std::{ }; use types::{account_address::AccountAddress, byte_array::ByteArray}; use vm::{ + access::{BaseAccess, ModuleAccess}, errors::VerificationError, file_format::{ - AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeUnit, CompiledModule, CompiledProgram, - CompiledScriptMut, FieldDefinition, FieldDefinitionIndex, FunctionDefinition, - FunctionDefinitionIndex, FunctionHandle, FunctionHandleIndex, FunctionSignature, - FunctionSignatureIndex, LocalsSignature, LocalsSignatureIndex, MemberCount, ModuleHandle, - ModuleHandleIndex, SignatureToken, StringPoolIndex, StructDefinition, - StructDefinitionIndex, StructHandle, StructHandleIndex, TableIndex, TypeSignature, - TypeSignatureIndex, SELF_MODULE_NAME, + AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeUnit, CompiledModule, + CompiledModuleMut, CompiledProgram, CompiledScriptMut, FieldDefinition, + FieldDefinitionIndex, FunctionDefinition, FunctionDefinitionIndex, FunctionHandle, + FunctionHandleIndex, FunctionSignature, FunctionSignatureIndex, LocalsSignature, + LocalsSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, SignatureToken, + StringPoolIndex, StructDefinition, StructDefinitionIndex, StructHandle, StructHandleIndex, + TableIndex, TypeSignature, TypeSignatureIndex, SELF_MODULE_NAME, }, printers::TableAccess, }; @@ -357,16 +358,9 @@ impl<'a> CompilationScope<'a> { let module = &self.modules[*module_index as usize]; - let fh_idx = match module.function_defs.get(fd_idx.0 as usize) { - None => bail!( - "No function definition index {} in function definition table", - fd_idx - ), - Some(function_def) => function_def.function, - }; - - let fh = module.get_function_at(fh_idx)?; - module.get_function_signature_at(fh.signature) + let fh_idx = module.function_def_at(*fd_idx).function; + let fh = module.function_handle_at(fh_idx); + Ok(module.function_signature_at(fh.signature)) } } @@ -379,11 +373,11 @@ struct ModuleScope<'a> { field_definitions: HashMap, function_definitions: HashMap, // the module being compiled - pub module: CompiledModule, + pub module: CompiledModuleMut, } impl<'a> ModuleScope<'a> { - fn new(module: CompiledModule, modules: &[CompiledModule]) -> ModuleScope { + fn new(module: CompiledModuleMut, modules: &[CompiledModule]) -> ModuleScope { ModuleScope { compilation_scope: CompilationScope::new(modules), struct_definitions: HashMap::new(), @@ -963,7 +957,7 @@ pub fn compile_module( module: &ModuleDefinition, modules: &[CompiledModule], ) -> Result { - let compiled_module = CompiledModule::default(); + let compiled_module = CompiledModuleMut::default(); let scope = ModuleScope::new(compiled_module, modules); let mut compiler = Compiler::new(scope); let addr_idx = compiler.make_address(&address)?; @@ -999,8 +993,11 @@ pub fn compile_module( FunctionBody::Native => (), } } - - Ok(compiler.scope.module) + compiler + .scope + .module + .freeze() + .map_err(|errs| InternalCompilerError::BoundsCheckErrors(errs).into()) } /// Compile a module and invoke the bytecode verifier on it @@ -1117,11 +1114,11 @@ impl Compiler { SignatureToken::Struct(sh_idx) => { let (defining_module_name, name, is_resource) = { let module = self.scope.get_imported_module(module_name)?; - let struct_handle = module.get_struct_at(sh_idx)?; - let defining_module_handle = module.get_module_at(struct_handle.module)?; + let struct_handle = module.struct_handle_at(sh_idx); + let defining_module_handle = module.module_handle_at(struct_handle.module); ( - module.get_string_at(defining_module_handle.name)?, - module.get_string_at(struct_handle.name)?.clone(), + module.string_at(defining_module_handle.name), + module.string_at(struct_handle.name).to_string(), struct_handle.is_resource, ) }; @@ -2193,16 +2190,18 @@ impl Compiler { let target_module = self.scope.get_imported_module(module.name_ref())?; let mut idx = 0; - while idx < target_module.function_defs.len() { - let fh_idx = target_module.function_defs[idx].function; - let fh = target_module.get_function_at(fh_idx)?; - let func_name = target_module.get_string_at(fh.name)?; + while idx < target_module.function_defs().len() { + let fh_idx = target_module + .function_def_at(FunctionDefinitionIndex::new(idx as TableIndex)) + .function; + let fh = target_module.function_handle_at(fh_idx); + let func_name = target_module.string_at(fh.name); if func_name == name.name_ref() { break; } idx += 1; } - if idx == target_module.function_defs.len() { + if idx == target_module.function_defs().len() { bail!( "Cannot find function `{}' in module `{}'", name.name_ref(), diff --git a/language/compiler/src/unit_tests/expression_tests.rs b/language/compiler/src/unit_tests/expression_tests.rs index dde39a7cfdb3..febc6f580695 100644 --- a/language/compiler/src/unit_tests/expression_tests.rs +++ b/language/compiler/src/unit_tests/expression_tests.rs @@ -183,8 +183,8 @@ module Test { } }", ); - let compiled_script = compile_module_string(&code).unwrap(); - assert!(compiled_script.struct_handles.len() == 1); + let compiled_module = compile_module_string(&code).unwrap(); + assert!(compiled_module.struct_handles().len() == 1); } #[test] diff --git a/language/vm/cost_synthesis/src/module_generator.rs b/language/vm/cost_synthesis/src/module_generator.rs index 69efd8e87e12..a14e0f61fc81 100644 --- a/language/vm/cost_synthesis/src/module_generator.rs +++ b/language/vm/cost_synthesis/src/module_generator.rs @@ -14,13 +14,14 @@ use types::{account_address::AccountAddress, byte_array::ByteArray, language_sto use vm::{ access::*, file_format::{ - AddressPoolIndex, Bytecode, CodeUnit, CompiledModule, FieldDefinition, + AddressPoolIndex, Bytecode, CodeUnit, CompiledModule, CompiledModuleMut, FieldDefinition, FieldDefinitionIndex, FunctionDefinition, FunctionHandle, FunctionHandleIndex, FunctionSignature, FunctionSignatureIndex, LocalsSignature, LocalsSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, SignatureToken, StringPoolIndex, StructDefinition, StructHandle, StructHandleIndex, TableIndex, TypeSignature, TypeSignatureIndex, }, + internals::ModuleIndex, }; /// A wrapper around a `CompiledModule` containing information needed for generation. @@ -36,7 +37,7 @@ pub struct ModuleBuilder { gen: StdRng, /// The current module being built. - module: CompiledModule, + module: CompiledModuleMut, /// The minimum size of the tables in the generated module. table_size: TableIndex, @@ -248,20 +249,20 @@ impl ModuleBuilder { // We have half/half inter- and intra-module calls. let number_of_cross_calls = self.table_size; for _ in 0..number_of_cross_calls { - let non_self_module_handle_idx = self.gen.gen_range(1, module_table_size) as TableIndex; - let callee_module_handle = self - .module - .module_handle_at(ModuleHandleIndex::new(non_self_module_handle_idx)); - let address = *self.module.address_at(callee_module_handle.address); - let name = self.module.string_at(callee_module_handle.name); + let non_self_module_handle_idx = self.gen.gen_range(1, module_table_size); + let callee_module_handle = &self.module.module_handles[non_self_module_handle_idx]; + let address = self.module.address_pool[callee_module_handle.address.into_index()]; + let name = &self.module.string_pool[callee_module_handle.name.into_index()]; let code_key = CodeKey::new(address, name.to_string()); let callee_module = self .known_modules .get(&code_key) .expect("[Module Lookup] Unable to get module from known_modules."); - let callee_function_handle_idx = - self.gen.gen_range(0, callee_module.function_handles.len()) as TableIndex; + let callee_function_handle_idx = self + .gen + .gen_range(0, callee_module.function_handles().len()) + as TableIndex; let callee_function_handle = callee_module .function_handle_at(FunctionHandleIndex::new(callee_function_handle_idx)); let callee_type_sig = callee_module @@ -273,7 +274,7 @@ impl ModuleBuilder { let callee_name_idx = self.module.string_pool.len() as TableIndex; let callee_type_sig_idx = self.module.function_signatures.len() as TableIndex; let func_handle = FunctionHandle { - module: ModuleHandleIndex::new(non_self_module_handle_idx), + module: ModuleHandleIndex::new(non_self_module_handle_idx as TableIndex), name: StringPoolIndex::new(callee_name_idx), signature: FunctionSignatureIndex::new(callee_type_sig_idx), }; @@ -333,16 +334,17 @@ impl ModuleBuilder { self.with_random_functions(); self.with_structs(); let module = std::mem::replace(&mut self.module, Self::default_module_with_types()); + let module = module.freeze().expect("should satisfy bounds checker"); self.known_modules .insert(module.self_code_key(), module.clone()); module } - // This method generates a default (empty) `CompiledModule` but with base types. This way we + // This method generates a default (empty) `CompiledModuleMut` but with base types. This way we // can point to them when generating structs/functions etc. - fn default_module_with_types() -> CompiledModule { + fn default_module_with_types() -> CompiledModuleMut { use SignatureToken::*; - let mut module = CompiledModule::default(); + let mut module = CompiledModuleMut::default(); module.type_signatures = vec![Bool, U64, String, ByteArray, Address] .into_iter() .map(TypeSignature) diff --git a/language/vm/cost_synthesis/src/stack_generator.rs b/language/vm/cost_synthesis/src/stack_generator.rs index d4008a1f72c1..4fc8ef53a759 100644 --- a/language/vm/cost_synthesis/src/stack_generator.rs +++ b/language/vm/cost_synthesis/src/stack_generator.rs @@ -17,11 +17,10 @@ use vm::{ assert_ok, file_format::{ AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeOffset, FieldDefinitionIndex, - FunctionDefinition, FunctionDefinitionIndex, FunctionHandleIndex, LocalIndex, ModuleHandle, - SignatureToken, StringPoolIndex, StructDefinition, StructDefinitionIndex, + FunctionDefinition, FunctionDefinitionIndex, FunctionHandleIndex, LocalIndex, MemberCount, + ModuleHandle, SignatureToken, StringPoolIndex, StructDefinition, StructDefinitionIndex, StructHandleIndex, TableIndex, }, - internals::ModuleIndex, }; use vm_runtime::{ code_cache::module_cache::ModuleCache, execution_stack::ExecutionStack, @@ -209,8 +208,9 @@ where let len: usize = self.gen.gen_range(1, MAX_STRING_SIZE); (0..len).map(|_| self.gen.gen::()).collect::() } else { - let string = - self.root_module.module.string_pool[self.string_pool_index as usize].clone(); + let string = self.root_module.module.as_inner().string_pool + [self.string_pool_index as usize] + .clone(); self.string_pool_index = self .string_pool_index .checked_sub(1) @@ -223,7 +223,8 @@ where if !self.points_to_module_data() || is_padding { AccountAddress::random() } else { - let address = self.root_module.module.address_pool[self.address_pool_index as usize]; + let address = + self.root_module.module.as_inner().address_pool[self.address_pool_index as usize]; self.address_pool_index = self .address_pool_index .checked_sub(1) @@ -237,23 +238,23 @@ where } fn next_string_idx(&mut self) -> StringPoolIndex { - let len = self.root_module.module.string_pool.len(); + let len = self.root_module.module.string_pool().len(); StringPoolIndex::new(self.gen.gen_range(0, len) as TableIndex) } fn next_address_idx(&mut self) -> AddressPoolIndex { - let len = self.root_module.module.address_pool.len(); + let len = self.root_module.module.address_pool().len(); AddressPoolIndex::new(self.gen.gen_range(0, len) as TableIndex) } fn next_bytearray_idx(&mut self) -> ByteArrayPoolIndex { - let len = self.root_module.module.byte_array_pool.len(); + let len = self.root_module.module.byte_array_pool().len(); ByteArrayPoolIndex::new(self.gen.gen_range(0, len) as TableIndex) } fn next_function_handle_idx(&mut self) -> FunctionHandleIndex { let table_idx = - self.next_bounded_index(self.root_module.module.function_handles.len() as TableIndex); + self.next_bounded_index(self.root_module.module.function_handles().len() as TableIndex); FunctionHandleIndex::new(table_idx) } @@ -469,10 +470,12 @@ where .module .struct_def_at(self.resolve_struct_handle(struct_handle_idx).2); let num_fields = struct_definition.field_count as usize; - let index = struct_definition.fields.into_index(); - let fields = &self.root_module.module.field_defs[index..index + num_fields]; + let index = struct_definition.fields; + let fields = self + .root_module + .module + .field_def_range(num_fields as MemberCount, index); let mutvals = fields - .iter() .map(|field| { self.resolve_to_value( self.root_module @@ -543,15 +546,17 @@ where ) } Pack(_struct_def_idx) => { - let struct_def_bound = self.root_module.module.struct_defs.len() as TableIndex; + let struct_def_bound = self.root_module.module.struct_defs().len() as TableIndex; let random_struct_idx = StructDefinitionIndex::new(self.next_bounded_index(struct_def_bound)); let struct_definition = self.root_module.module.struct_def_at(random_struct_idx); let num_fields = struct_definition.field_count as usize; - let index = struct_definition.fields.into_index(); - let fields = &self.root_module.module.field_defs[index..index + num_fields]; + let index = struct_definition.fields; + let fields = self + .root_module + .module + .field_def_range(num_fields as MemberCount, index); let stack: Stack = fields - .iter() .map(|field| { let ty = self .root_module @@ -572,7 +577,7 @@ where ) } Unpack(_struct_def_idx) => { - let struct_def_bound = self.root_module.module.struct_defs.len() as TableIndex; + let struct_def_bound = self.root_module.module.struct_defs().len() as TableIndex; let random_struct_idx = StructDefinitionIndex::new(self.next_bounded_index(struct_def_bound)); let struct_handle_idx = self @@ -593,7 +598,7 @@ where } BorrowField(_) => { // First grab a random struct - let struct_def_bound = self.root_module.module.struct_defs.len() as TableIndex; + let struct_def_bound = self.root_module.module.struct_defs().len() as TableIndex; let random_struct_idx = StructDefinitionIndex::new(self.next_bounded_index(struct_def_bound)); let struct_definition = self.root_module.module.struct_def_at(random_struct_idx); diff --git a/language/vm/src/access.rs b/language/vm/src/access.rs index 6b70c099aefb..d4d09720f58c 100644 --- a/language/vm/src/access.rs +++ b/language/vm/src/access.rs @@ -10,12 +10,12 @@ use types::{account_address::AccountAddress, byte_array::ByteArray, language_sto use crate::{ errors::VMStaticViolation, file_format::{ - AddressPoolIndex, ByteArrayPoolIndex, CompiledModule, CompiledScript, FieldDefinition, - FieldDefinitionIndex, FunctionDefinition, FunctionDefinitionIndex, FunctionHandle, - FunctionHandleIndex, FunctionSignature, FunctionSignatureIndex, LocalsSignature, - LocalsSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, StringPoolIndex, - StructDefinition, StructDefinitionIndex, StructHandle, StructHandleIndex, TypeSignature, - TypeSignatureIndex, + AddressPoolIndex, ByteArrayPoolIndex, CompiledModule, CompiledModuleMut, CompiledScript, + FieldDefinition, FieldDefinitionIndex, FunctionDefinition, FunctionDefinitionIndex, + FunctionHandle, FunctionHandleIndex, FunctionSignature, FunctionSignatureIndex, + LocalsSignature, LocalsSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, + StringPoolIndex, StructDefinition, StructDefinitionIndex, StructHandle, StructHandleIndex, + TypeSignature, TypeSignatureIndex, }, internals::ModuleIndex, IndexKind, @@ -155,78 +155,9 @@ macro_rules! impl_base_access { }; } -// impl_base_access!(CompiledModule); +impl_base_access!(CompiledModule); impl_base_access!(CompiledScript); -// XXX this is temporary while CompiledModule gets moved to dual mutable/immutable versions. -impl BaseAccess for CompiledModule { - fn module_handle_at(&self, idx: ModuleHandleIndex) -> &ModuleHandle { - &self.module_handles[idx.into_index()] - } - - fn struct_handle_at(&self, idx: StructHandleIndex) -> &StructHandle { - &self.struct_handles[idx.into_index()] - } - - fn function_handle_at(&self, idx: FunctionHandleIndex) -> &FunctionHandle { - &self.function_handles[idx.into_index()] - } - - fn type_signature_at(&self, idx: TypeSignatureIndex) -> &TypeSignature { - &self.type_signatures[idx.into_index()] - } - - fn function_signature_at(&self, idx: FunctionSignatureIndex) -> &FunctionSignature { - &self.function_signatures[idx.into_index()] - } - - fn locals_signature_at(&self, idx: LocalsSignatureIndex) -> &LocalsSignature { - &self.locals_signatures[idx.into_index()] - } - - fn string_at(&self, idx: StringPoolIndex) -> &str { - self.string_pool[idx.into_index()].as_str() - } - - fn byte_array_at(&self, idx: ByteArrayPoolIndex) -> &ByteArray { - &self.byte_array_pool[idx.into_index()] - } - - fn address_at(&self, idx: AddressPoolIndex) -> &AccountAddress { - &self.address_pool[idx.into_index()] - } - - fn module_handles(&self) -> slice::Iter { - self.module_handles[..].iter() - } - fn struct_handles(&self) -> slice::Iter { - self.struct_handles[..].iter() - } - fn function_handles(&self) -> slice::Iter { - self.function_handles[..].iter() - } - - fn type_signatures(&self) -> slice::Iter { - self.type_signatures[..].iter() - } - fn function_signatures(&self) -> slice::Iter { - self.function_signatures[..].iter() - } - fn locals_signatures(&self) -> slice::Iter { - self.locals_signatures[..].iter() - } - - fn byte_array_pool(&self) -> slice::Iter { - self.byte_array_pool[..].iter() - } - fn address_pool(&self) -> slice::Iter { - self.address_pool[..].iter() - } - fn string_pool(&self) -> slice::Iter { - self.string_pool[..].iter() - } -} - impl ModuleAccess for CompiledModule { fn self_code_key(&self) -> CodeKey { self.self_code_key() @@ -237,27 +168,27 @@ impl ModuleAccess for CompiledModule { } fn struct_def_at(&self, idx: StructDefinitionIndex) -> &StructDefinition { - &self.struct_defs[idx.into_index()] + &self.as_inner().struct_defs[idx.into_index()] } fn field_def_at(&self, idx: FieldDefinitionIndex) -> &FieldDefinition { - &self.field_defs[idx.into_index()] + &self.as_inner().field_defs[idx.into_index()] } fn function_def_at(&self, idx: FunctionDefinitionIndex) -> &FunctionDefinition { - &self.function_defs[idx.into_index()] + &self.as_inner().function_defs[idx.into_index()] } fn struct_defs(&self) -> slice::Iter { - self.struct_defs[..].iter() + self.as_inner().struct_defs[..].iter() } fn field_defs(&self) -> slice::Iter { - self.field_defs[..].iter() + self.as_inner().field_defs[..].iter() } fn function_defs(&self) -> slice::Iter { - self.function_defs[..].iter() + self.as_inner().function_defs[..].iter() } fn field_def_range( @@ -268,7 +199,7 @@ impl ModuleAccess for CompiledModule { let first_field = first_field.0 as usize; let field_count = field_count as usize; let last_field = first_field + field_count; - self.field_defs[first_field..last_field].iter() + self.as_inner().field_defs[first_field..last_field].iter() } } @@ -278,7 +209,7 @@ impl ScriptAccess for CompiledScript { } } -impl CompiledModule { +impl CompiledModuleMut { #[inline] pub(crate) fn check_field_range( &self, diff --git a/language/vm/src/checks/bounds.rs b/language/vm/src/checks/bounds.rs index b4b2695914a9..1d9ab8ca0e4b 100644 --- a/language/vm/src/checks/bounds.rs +++ b/language/vm/src/checks/bounds.rs @@ -4,7 +4,7 @@ use crate::{ errors::{VMStaticViolation, VerificationError}, file_format::{ - Bytecode, CompiledModule, FieldDefinition, FunctionDefinition, FunctionHandle, + Bytecode, CompiledModuleMut, FieldDefinition, FunctionDefinition, FunctionHandle, FunctionSignature, LocalsSignature, ModuleHandle, SignatureToken, StructDefinition, StructHandle, TypeSignature, }, @@ -13,11 +13,11 @@ use crate::{ }; pub struct BoundsChecker<'a> { - module: &'a CompiledModule, + module: &'a CompiledModuleMut, } impl<'a> BoundsChecker<'a> { - pub fn new(module: &'a CompiledModule) -> Self { + pub fn new(module: &'a CompiledModuleMut) -> Self { Self { module } } @@ -108,7 +108,7 @@ impl<'a> BoundsChecker<'a> { fn verify_impl( kind: IndexKind, iter: impl Iterator, - module: &CompiledModule, + module: &CompiledModuleMut, ) -> Vec { iter.enumerate() .map(move |(idx, elem)| { @@ -122,7 +122,7 @@ impl<'a> BoundsChecker<'a> { } pub trait BoundsCheck { - fn check_bounds(&self, module: &CompiledModule) -> Vec; + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec; } #[inline] @@ -141,7 +141,7 @@ where impl BoundsCheck for &ModuleHandle { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { vec![ check_bounds_impl(&module.address_pool, self.address), check_bounds_impl(&module.string_pool, self.name), @@ -154,7 +154,7 @@ impl BoundsCheck for &ModuleHandle { impl BoundsCheck for &StructHandle { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { vec![ check_bounds_impl(&module.module_handles, self.module), check_bounds_impl(&module.string_pool, self.name), @@ -167,7 +167,7 @@ impl BoundsCheck for &StructHandle { impl BoundsCheck for &FunctionHandle { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { vec![ check_bounds_impl(&module.module_handles, self.module), check_bounds_impl(&module.string_pool, self.name), @@ -181,7 +181,7 @@ impl BoundsCheck for &FunctionHandle { impl BoundsCheck for &StructDefinition { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { vec![ check_bounds_impl(&module.struct_handles, self.struct_handle), module.check_field_range(self.field_count, self.fields), @@ -194,7 +194,7 @@ impl BoundsCheck for &StructDefinition { impl BoundsCheck for &FieldDefinition { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { vec![ check_bounds_impl(&module.struct_handles, self.struct_), check_bounds_impl(&module.string_pool, self.name), @@ -208,7 +208,7 @@ impl BoundsCheck for &FieldDefinition { impl BoundsCheck for &FunctionDefinition { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { vec![ check_bounds_impl(&module.function_handles, self.function), if self.is_native() { @@ -225,14 +225,14 @@ impl BoundsCheck for &FunctionDefinition { impl BoundsCheck for &TypeSignature { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { self.0.check_bounds(module).into_iter().collect() } } impl BoundsCheck for &FunctionSignature { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { self.return_types .iter() .filter_map(|token| token.check_bounds(module)) @@ -247,7 +247,7 @@ impl BoundsCheck for &FunctionSignature { impl BoundsCheck for &LocalsSignature { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Vec { + fn check_bounds(&self, module: &CompiledModuleMut) -> Vec { self.0 .iter() .filter_map(|token| token.check_bounds(module)) @@ -257,7 +257,7 @@ impl BoundsCheck for &LocalsSignature { impl SignatureToken { #[inline] - fn check_bounds(&self, module: &CompiledModule) -> Option { + fn check_bounds(&self, module: &CompiledModuleMut) -> Option { match self.struct_index() { Some(sh_idx) => check_bounds_impl(&module.struct_handles, sh_idx), None => None, @@ -268,7 +268,7 @@ impl SignatureToken { impl FunctionDefinition { // This is implemented separately because it depends on the locals signature index being // checked. - fn check_code_unit_bounds(&self, module: &CompiledModule) -> Vec { + fn check_code_unit_bounds(&self, module: &CompiledModuleMut) -> Vec { if self.is_native() { return vec![]; } diff --git a/language/vm/src/deserializer.rs b/language/vm/src/deserializer.rs index 3477561a63f7..58586723a1b0 100644 --- a/language/vm/src/deserializer.rs +++ b/language/vm/src/deserializer.rs @@ -1,7 +1,7 @@ // Copyright (c) The Libra Core Contributors // SPDX-License-Identifier: Apache-2.0 -use crate::{checks::BoundsChecker, errors::*, file_format::*, file_format_common::*}; +use crate::{errors::*, file_format::*, file_format_common::*}; use byteorder::{LittleEndian, ReadBytesExt}; use std::{ collections::HashSet, @@ -28,25 +28,18 @@ impl CompiledScriptMut { } impl CompiledModule { - /// Deserialize a &[u8] slice into a module (`CompiledModule`) - pub fn deserialize(binary: &[u8]) -> BinaryLoaderResult { - let compiled_module = Self::deserialize_no_check_bounds(binary)?; - compiled_module.check_bounds() + /// Deserialize a &[u8] slice into a `CompiledModule` instance. + pub fn deserialize(binary: &[u8]) -> BinaryLoaderResult { + let deserialized = CompiledModuleMut::deserialize_no_check_bounds(binary)?; + deserialized.freeze().map_err(|_| BinaryError::Malformed) } +} +impl CompiledModuleMut { // exposed as a public function to enable testing the deserializer - pub fn deserialize_no_check_bounds(binary: &[u8]) -> BinaryLoaderResult { + pub fn deserialize_no_check_bounds(binary: &[u8]) -> BinaryLoaderResult { deserialize_compiled_module(binary) } - - /// Checks that all indexes are in bound in this `CompiledModule`. - pub fn check_bounds(self) -> BinaryLoaderResult { - if BoundsChecker::new(&self).verify().is_empty() { - Ok(self) - } else { - Err(BinaryError::Malformed) - } - } } /// Table info: table type, offset where the table content starts from, count of bytes for @@ -81,7 +74,7 @@ fn deserialize_compiled_script(binary: &[u8]) -> BinaryLoaderResult BinaryLoaderResult { +fn deserialize_compiled_module(binary: &[u8]) -> BinaryLoaderResult { let binary_len = binary.len() as u64; let mut cursor = Cursor::new(binary); let table_count = check_binary(&mut cursor)?; @@ -245,7 +238,7 @@ impl CommonTables for CompiledScriptMut { } } -impl CommonTables for CompiledModule { +impl CommonTables for CompiledModuleMut { fn get_module_handles(&mut self) -> &mut Vec { &mut self.module_handles } @@ -291,9 +284,9 @@ fn build_compiled_script(binary: &[u8], tables: &[Table]) -> BinaryLoaderResult< Ok(script) } -/// Builds and returns a `CompiledModule`. -fn build_compiled_module(binary: &[u8], tables: &[Table]) -> BinaryLoaderResult { - let mut module = CompiledModule::default(); +/// Builds and returns a `CompiledModuleMut`. +fn build_compiled_module(binary: &[u8], tables: &[Table]) -> BinaryLoaderResult { + let mut module = CompiledModuleMut::default(); build_common_tables(binary, tables, &mut module)?; build_module_tables(binary, tables, &mut module)?; Ok(module) @@ -343,11 +336,11 @@ fn build_common_tables( Ok(()) } -/// Builds tables related to a `CompiledModule`. +/// Builds tables related to a `CompiledModuleMut`. fn build_module_tables( binary: &[u8], tables: &[Table], - module: &mut CompiledModule, + module: &mut CompiledModuleMut, ) -> BinaryLoaderResult<()> { for table in tables { match table.kind { diff --git a/language/vm/src/file_format.rs b/language/vm/src/file_format.rs index e59f726a1748..a83aeda36ab6 100644 --- a/language/vm/src/file_format.rs +++ b/language/vm/src/file_format.rs @@ -1084,7 +1084,7 @@ impl CompiledScript { /// If a `CompiledScript` has been bounds checked, the corresponding `CompiledModule` can be /// assumed to pass the bounds checker as well. pub fn into_module(self) -> CompiledModule { - self.0.into_module() + CompiledModule(self.0.into_module()) } } @@ -1093,18 +1093,13 @@ impl CompiledScriptMut { /// consistency. This includes bounds checks but no others. pub fn freeze(self) -> Result> { let fake_module = self.into_module(); - let errors = BoundsChecker::new(&fake_module).verify(); - if errors.is_empty() { - Ok(fake_module.into_script()) - } else { - Err(errors) - } + Ok(fake_module.freeze()?.into_script()) } /// Converts a `CompiledScriptMut` to a `CompiledModule` for code that wants a uniform view /// of both. - pub fn into_module(self) -> CompiledModule { - CompiledModule { + pub fn into_module(self) -> CompiledModuleMut { + CompiledModuleMut { module_handles: self.module_handles, struct_handles: self.struct_handles, function_handles: self.function_handles, @@ -1130,8 +1125,13 @@ impl CompiledScriptMut { /// It is a unit of code that can be used by transactions or other modules. /// /// A module is published as a single entry and it is retrieved as a single blob. -#[derive(Clone, Default, Eq, PartialEq, Debug)] -pub struct CompiledModule { +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct CompiledModule(CompiledModuleMut); + +/// A mutable version of `CompiledModule`. Converting to a `CompiledModule` requires this to pass +/// the bounds checker. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct CompiledModuleMut { /// Handles to external modules and self at position 0. pub module_handles: Vec, /// Handles to external and internal types. @@ -1166,7 +1166,7 @@ pub struct CompiledModule { // Need a custom implementation of Arbitrary because as of proptest-derive 0.1.1, the derivation // doesn't work for structs with more than 10 fields. -impl Arbitrary for CompiledModule { +impl Arbitrary for CompiledModuleMut { type Strategy = BoxedStrategy; /// The size of the compiled module. type Parameters = usize; @@ -1201,7 +1201,7 @@ impl Arbitrary for CompiledModule { (string_pool, byte_array_pool, address_pool), (struct_defs, field_defs, function_defs), )| { - CompiledModule { + CompiledModuleMut { module_handles, struct_handles, function_handles, @@ -1221,12 +1221,58 @@ impl Arbitrary for CompiledModule { } } +impl CompiledModuleMut { + /// Returns the count of a specific `IndexKind` + pub fn kind_count(&self, kind: IndexKind) -> usize { + match kind { + IndexKind::ModuleHandle => self.module_handles.len(), + IndexKind::StructHandle => self.struct_handles.len(), + IndexKind::FunctionHandle => self.function_handles.len(), + IndexKind::StructDefinition => self.struct_defs.len(), + IndexKind::FieldDefinition => self.field_defs.len(), + IndexKind::FunctionDefinition => self.function_defs.len(), + IndexKind::TypeSignature => self.type_signatures.len(), + IndexKind::FunctionSignature => self.function_signatures.len(), + IndexKind::LocalsSignature => self.locals_signatures.len(), + IndexKind::StringPool => self.string_pool.len(), + IndexKind::ByteArrayPool => self.byte_array_pool.len(), + IndexKind::AddressPool => self.address_pool.len(), + // XXX these two don't seem to belong here + other @ IndexKind::LocalPool | other @ IndexKind::CodeDefinition => { + panic!("invalid kind for count: {:?}", other) + } + } + } + + /// Converts this instance into `CompiledModule` after verifying it for basic internal + /// consistency. This includes bounds checks but no others. + pub fn freeze(self) -> Result> { + let errors = BoundsChecker::new(&self).verify(); + if errors.is_empty() { + Ok(CompiledModule(self)) + } else { + Err(errors) + } + } +} + impl CompiledModule { /// By convention, the index of the module being implemented is 0. pub const IMPLEMENTED_MODULE_INDEX: u16 = 0; fn self_handle(&self) -> &ModuleHandle { - &self.module_handles[Self::IMPLEMENTED_MODULE_INDEX as usize] + &self.module_handle_at(ModuleHandleIndex::new(Self::IMPLEMENTED_MODULE_INDEX)) + } + + /// Returns a reference to the inner `CompiledModuleMut`. + pub fn as_inner(&self) -> &CompiledModuleMut { + &self.0 + } + + /// Converts this instance into the inner `CompiledModuleMut`. Converting back to a + /// `CompiledModule` would require it to be verified again. + pub fn into_inner(self) -> CompiledModuleMut { + self.0 } /// Returns the name of the module. @@ -1239,6 +1285,11 @@ impl CompiledModule { self.address_at(self.self_handle().address) } + /// Returns the number of items of a specific `IndexKind`. + pub fn kind_count(&self, kind: IndexKind) -> usize { + self.as_inner().kind_count(kind) + } + /// Returns the code key of `module_handle` pub fn code_key_for_handle(&self, module_handle: &ModuleHandle) -> CodeKey { CodeKey::new( @@ -1252,47 +1303,24 @@ impl CompiledModule { self.code_key_for_handle(self.self_handle()) } - /// Returns the count of a specific `IndexKind` - pub fn kind_count(&self, kind: IndexKind) -> usize { - match kind { - IndexKind::ModuleHandle => self.module_handles.len(), - IndexKind::StructHandle => self.struct_handles.len(), - IndexKind::FunctionHandle => self.function_handles.len(), - IndexKind::StructDefinition => self.struct_defs.len(), - IndexKind::FieldDefinition => self.field_defs.len(), - IndexKind::FunctionDefinition => self.function_defs.len(), - IndexKind::TypeSignature => self.type_signatures.len(), - IndexKind::FunctionSignature => self.function_signatures.len(), - IndexKind::LocalsSignature => self.locals_signatures.len(), - IndexKind::StringPool => self.string_pool.len(), - IndexKind::ByteArrayPool => self.byte_array_pool.len(), - IndexKind::AddressPool => self.address_pool.len(), - // XXX these two don't seem to belong here - other @ IndexKind::LocalPool | other @ IndexKind::CodeDefinition => { - panic!("invalid kind for count: {:?}", other) - } - } - } - /// This function should only be called on an instance of CompiledModule obtained by invoking /// into_module on some instance of CompiledScript. This function is the inverse of /// into_module, i.e., script.into_module().into_script() == script. - pub fn into_script(mut self) -> CompiledScript { - let main = self.function_defs.remove(0); - // XXX this assumes that CompiledModule instances have already been bounds checked. Encode - // this assumption in the type system. + pub fn into_script(self) -> CompiledScript { + let mut inner = self.into_inner(); + let main = inner.function_defs.remove(0); CompiledScript(CompiledScriptMut { - module_handles: self.module_handles, - struct_handles: self.struct_handles, - function_handles: self.function_handles, + module_handles: inner.module_handles, + struct_handles: inner.struct_handles, + function_handles: inner.function_handles, - type_signatures: self.type_signatures, - function_signatures: self.function_signatures, - locals_signatures: self.locals_signatures, + type_signatures: inner.type_signatures, + function_signatures: inner.function_signatures, + locals_signatures: inner.locals_signatures, - string_pool: self.string_pool, - byte_array_pool: self.byte_array_pool, - address_pool: self.address_pool, + string_pool: inner.string_pool, + byte_array_pool: inner.byte_array_pool, + address_pool: inner.address_pool, main, }) diff --git a/language/vm/src/printers.rs b/language/vm/src/printers.rs index aa7470fb5d88..c1903d150d47 100644 --- a/language/vm/src/printers.rs +++ b/language/vm/src/printers.rs @@ -144,7 +144,7 @@ impl TableAccess for CompiledScriptMut { } } -impl TableAccess for CompiledModule { +impl TableAccess for CompiledModuleMut { fn get_field_def_at(&self, idx: FieldDefinitionIndex) -> Result<&FieldDefinition> { match self.field_defs.get(idx.0 as usize) { None => bail!("bad field definition index {}", idx), @@ -292,97 +292,98 @@ impl fmt::Display for CompiledScript { impl fmt::Display for CompiledModule { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let inner = self.as_inner(); writeln!(f, "CompiledModule: {{")?; write!(f, "Module Handles: [")?; - for module_handle in &self.module_handles { + for module_handle in &inner.module_handles { write!(f, "\n\t")?; - display_module_handle(module_handle, self, f)?; + display_module_handle(module_handle, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Struct Handles: [")?; - for struct_handle in &self.struct_handles { + for struct_handle in &inner.struct_handles { write!(f, "\n\t")?; - display_struct_handle(struct_handle, self, f)?; + display_struct_handle(struct_handle, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Function Handles: [")?; - for function_handle in &self.function_handles { + for function_handle in &inner.function_handles { write!(f, "\n\t")?; - display_function_handle(function_handle, self, f)?; + display_function_handle(function_handle, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Struct Definitions: [")?; - for struct_def in &self.struct_defs { + for struct_def in &inner.struct_defs { write!(f, "\n\t{{")?; - display_struct_definition(struct_def, self, f)?; + display_struct_definition(struct_def, inner, f)?; let f_start_idx = struct_def.fields; let f_end_idx = f_start_idx.0 as u16 + struct_def.field_count; for idx in f_start_idx.0 as u16..f_end_idx { - let field_def = match self.field_defs.get(idx as usize) { + let field_def = match inner.field_defs.get(idx as usize) { None => panic!("bad field definition index {}", idx), Some(f) => f, }; write!(f, "\n\t\t")?; - display_field_definition(field_def, self, f)?; + display_field_definition(field_def, inner, f)?; } write!(f, "}},")?; } writeln!(f, "]")?; write!(f, "Field Definitions: [")?; - for field_def in &self.field_defs { + for field_def in &inner.field_defs { write!(f, "\n\t")?; - display_field_definition(field_def, self, f)?; + display_field_definition(field_def, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Function Definitions: [")?; - for function_def in &self.function_defs { + for function_def in &inner.function_defs { write!(f, "\n\t")?; - display_function_definition(function_def, self, f)?; + display_function_definition(function_def, inner, f)?; if function_def.flags & CodeUnit::NATIVE == 0 { - display_code(&function_def.code, self, "\n\t\t", f)?; + display_code(&function_def.code, inner, "\n\t\t", f)?; } write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Type Signatures: [")?; - for signature in &self.type_signatures { + for signature in &inner.type_signatures { write!(f, "\n\t")?; - display_type_signature(signature, self, f)?; + display_type_signature(signature, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Function Signatures: [")?; - for signature in &self.function_signatures { + for signature in &inner.function_signatures { write!(f, "\n\t")?; - display_function_signature(signature, self, f)?; + display_function_signature(signature, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Locals Signatures: [")?; - for signature in &self.locals_signatures { + for signature in &inner.locals_signatures { write!(f, "\n\t")?; - display_locals_signature(signature, self, f)?; + display_locals_signature(signature, inner, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Strings: [")?; - for string in &self.string_pool { + for string in &inner.string_pool { write!(f, "\n\t{},", string)?; } writeln!(f, "]")?; write!(f, "ByteArrays: [")?; - for byte_array in &self.byte_array_pool { + for byte_array in &inner.byte_array_pool { write!(f, "\n\t")?; display_byte_array(byte_array, f)?; write!(f, ",")?; } writeln!(f, "]")?; write!(f, "Addresses: [")?; - for address in &self.address_pool { + for address in &inner.address_pool { write!(f, "\n\t")?; display_address(address, f)?; write!(f, ",")?; diff --git a/language/vm/src/proptest_types.rs b/language/vm/src/proptest_types.rs index 6552a38382dc..6d24e4f73de9 100644 --- a/language/vm/src/proptest_types.rs +++ b/language/vm/src/proptest_types.rs @@ -4,10 +4,10 @@ //! Utilities for property-based testing. use crate::file_format::{ - AddressPoolIndex, CompiledModule, FieldDefinition, FieldDefinitionIndex, FunctionHandle, - FunctionSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, SignatureToken, - StringPoolIndex, StructDefinition, StructHandle, StructHandleIndex, TableIndex, TypeSignature, - TypeSignatureIndex, + AddressPoolIndex, CompiledModule, CompiledModuleMut, FieldDefinition, FieldDefinitionIndex, + FunctionHandle, FunctionSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, + SignatureToken, StringPoolIndex, StructDefinition, StructHandle, StructHandleIndex, TableIndex, + TypeSignature, TypeSignatureIndex, }; use proptest::{ collection::{vec, SizeRange}, @@ -307,7 +307,7 @@ impl CompiledModuleStrategyGen { assert_eq!(function_handles_len, function_handles.len()); // Put it all together. - CompiledModule { + CompiledModuleMut { module_handles, struct_handles, function_handles, @@ -324,6 +324,8 @@ impl CompiledModuleStrategyGen { byte_array_pool, address_pool, } + .freeze() + .expect("valid modules should satisfy the bounds checker") }, ) } diff --git a/language/vm/src/serializer.rs b/language/vm/src/serializer.rs index 0ee2bc90cdf6..d5609532c93f 100644 --- a/language/vm/src/serializer.rs +++ b/language/vm/src/serializer.rs @@ -38,6 +38,16 @@ impl CompiledScriptMut { impl CompiledModule { /// Serializes a `CompiledModule` into a binary. The mutable `Vec` will contain the /// binary blob on return. + pub fn serialize(&self, binary: &mut Vec) -> Result<()> { + self.as_inner().serialize(binary) + } +} + +impl CompiledModuleMut { + /// Serializes this into a binary format. + /// + /// This is intended mainly for test code. Production code will typically use + /// [`CompiledModule::serialize`]. pub fn serialize(&self, binary: &mut Vec) -> Result<()> { let mut ser = ModuleSerializer::new(1, 0); let mut temp: Vec = Vec::new(); @@ -167,7 +177,7 @@ impl CommonTables for CompiledScriptMut { } } -impl CommonTables for CompiledModule { +impl CommonTables for CompiledModuleMut { fn get_module_handles(&self) -> &[ModuleHandle] { &self.module_handles } @@ -814,7 +824,7 @@ impl ModuleSerializer { } } - fn serialize(&mut self, binary: &mut Vec, module: &CompiledModule) -> Result<()> { + fn serialize(&mut self, binary: &mut Vec, module: &CompiledModuleMut) -> Result<()> { self.common.serialize_common(binary, module)?; self.serialize_struct_definitions(binary, &module.struct_defs)?; self.serialize_field_definitions(binary, &module.field_defs)?; diff --git a/language/vm/tests/serializer_tests.rs b/language/vm/tests/serializer_tests.rs index 1059c54d2a45..a93807460a1c 100644 --- a/language/vm/tests/serializer_tests.rs +++ b/language/vm/tests/serializer_tests.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use proptest::prelude::*; -use vm::file_format::CompiledModule; +use vm::file_format::{CompiledModule, CompiledModuleMut}; proptest! { #[test] @@ -23,11 +23,11 @@ proptest! { /// Make sure that garbage inputs don't crash the serializer and deserializer. #[test] - fn garbage_inputs(module in any_with::(16)) { + fn garbage_inputs(module in any_with::(16)) { let mut serialized = Vec::with_capacity(65536); module.serialize(&mut serialized).expect("serialization should work"); - let deserialized_module = CompiledModule::deserialize_no_check_bounds(&serialized) + let deserialized_module = CompiledModuleMut::deserialize_no_check_bounds(&serialized) .expect("deserialization should work"); prop_assert_eq!(module, deserialized_module); } diff --git a/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs b/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs index cee97b790adf..773093240cf7 100644 --- a/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs +++ b/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs @@ -4,7 +4,10 @@ use crate::identifier::resource_storage_key; use canonical_serialization::{SimpleDeserializer, SimpleSerializer}; use proptest::prelude::*; -use vm::file_format::{CompiledModule, StructDefinitionIndex, TableIndex}; +use vm::{ + access::ModuleAccess, + file_format::{CompiledModule, StructDefinitionIndex, TableIndex}, +}; proptest! { #[test] @@ -16,7 +19,7 @@ proptest! { }; prop_assert_eq!(code_key, deserialized_code_key); - for i in 0..module.struct_defs.len() { + for i in 0..module.struct_defs().len() { let struct_key = resource_storage_key(&module, StructDefinitionIndex::new(i as TableIndex)); let deserialized_struct_key = { let serialized_key = SimpleSerializer::>::serialize(&struct_key).unwrap(); diff --git a/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs b/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs index 07a7a60b650b..3cf96de1f3ed 100644 --- a/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs +++ b/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs @@ -13,7 +13,7 @@ use vm::file_format::*; use vm_cache_map::Arena; fn test_module(name: String) -> CompiledModule { - CompiledModule { + CompiledModuleMut { module_handles: vec![ModuleHandle { name: StringPoolIndex::new(0), address: AddressPoolIndex::new(0), @@ -70,6 +70,8 @@ fn test_module(name: String) -> CompiledModule { byte_array_pool: vec![], address_pool: vec![AccountAddress::default()], } + .freeze() + .expect("test module should satisfy bounds checker") } fn test_script() -> CompiledScript { diff --git a/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs b/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs index 848a6497075e..056afbb188dd 100644 --- a/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs +++ b/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs @@ -9,10 +9,10 @@ use std::collections::HashMap; use types::{access_path::AccessPath, account_address::AccountAddress, byte_array::ByteArray}; use vm::{ file_format::{ - AddressPoolIndex, Bytecode, CodeUnit, CompiledModule, CompiledScript, CompiledScriptMut, - FunctionDefinition, FunctionHandle, FunctionHandleIndex, FunctionSignature, - FunctionSignatureIndex, LocalsSignature, LocalsSignatureIndex, ModuleHandle, - ModuleHandleIndex, SignatureToken, StringPoolIndex, + AddressPoolIndex, Bytecode, CodeUnit, CompiledModule, CompiledModuleMut, CompiledScript, + CompiledScriptMut, FunctionDefinition, FunctionHandle, FunctionHandleIndex, + FunctionSignature, FunctionSignatureIndex, LocalsSignature, LocalsSignatureIndex, + ModuleHandle, ModuleHandleIndex, SignatureToken, StringPoolIndex, }, transaction_metadata::TransactionMetadata, }; @@ -567,7 +567,7 @@ fn fake_module_with_calls(sigs: Vec<(Vec, FunctionSignature)>) - }) .collect(); let (local_sigs, function_sigs): (Vec<_>, Vec<_>) = sigs.into_iter().unzip(); - CompiledModule { + CompiledModuleMut { function_defs, field_defs: vec![], struct_defs: vec![], @@ -585,6 +585,8 @@ fn fake_module_with_calls(sigs: Vec<(Vec, FunctionSignature)>) - byte_array_pool: vec![], address_pool: vec![AccountAddress::default()], } + .freeze() + .expect("test module should satisfy the bounds checker") } #[test] diff --git a/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs b/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs index 04bc00ecb927..059417655b8a 100644 --- a/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs +++ b/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs @@ -3,7 +3,7 @@ use crate::{fuzz_targets::new_value, FuzzTargetImpl}; use proptest::{prelude::*, test_runner::TestRunner}; -use vm::file_format::CompiledModule; +use vm::file_format::{CompiledModule, CompiledModuleMut}; #[derive(Clone, Debug, Default)] pub struct CompiledModuleTarget; @@ -18,7 +18,7 @@ impl FuzzTargetImpl for CompiledModuleTarget { } fn generate(&self, runner: &mut TestRunner) -> Vec { - let value = new_value(runner, any_with::(16)); + let value = new_value(runner, any_with::(16)); let mut out = vec![]; value .serialize(&mut out) @@ -27,7 +27,8 @@ impl FuzzTargetImpl for CompiledModuleTarget { } fn fuzz(&self, data: &[u8]) { - // Errors are OK -- the fuzzer cares about panics and OOMs. + // Errors are OK -- the fuzzer cares about panics and OOMs. Note that + // `CompiledModule::deserialize` also runs the bounds checker, which is desirable here. let _ = CompiledModule::deserialize(data); } }