From 0800a90bec66d6d106a3d2010735e9bf7e4706c5 Mon Sep 17 00:00:00 2001 From: martyall Date: Fri, 10 Jan 2025 12:56:21 -0800 Subject: [PATCH] switch implementation of WritePreimageKey syscall, use one register per byte for the preimage key. --- o1vm/src/interpreters/mips/column.rs | 2 +- o1vm/src/interpreters/mips/interpreter.rs | 192 ++++++++-------------- o1vm/src/interpreters/mips/registers.rs | 17 +- o1vm/src/interpreters/mips/tests.rs | 16 +- o1vm/src/interpreters/mips/witness.rs | 17 +- o1vm/src/pickles/mod.rs | 2 +- 6 files changed, 94 insertions(+), 152 deletions(-) diff --git a/o1vm/src/interpreters/mips/column.rs b/o1vm/src/interpreters/mips/column.rs index 54743f4dbf..8d46775ee3 100644 --- a/o1vm/src/interpreters/mips/column.rs +++ b/o1vm/src/interpreters/mips/column.rs @@ -8,7 +8,7 @@ use strum::EnumCount; use super::{ITypeInstruction, JTypeInstruction, RTypeInstruction}; -pub(crate) const SCRATCH_SIZE_WITHOUT_KECCAK: usize = 50; +pub(crate) const SCRATCH_SIZE_WITHOUT_KECCAK: usize = 95; /// The number of hashes performed so far in the block pub(crate) const MIPS_HASH_COUNTER_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK; /// The number of bytes of the preimage that have been read so far in this hash diff --git a/o1vm/src/interpreters/mips/interpreter.rs b/o1vm/src/interpreters/mips/interpreter.rs index 5ba96d9885..45d1feba96 100644 --- a/o1vm/src/interpreters/mips/interpreter.rs +++ b/o1vm/src/interpreters/mips/interpreter.rs @@ -2,11 +2,12 @@ use crate::{ cannon::PAGE_ADDRESS_SIZE, interpreters::mips::registers::{ REGISTER_CURRENT_IP, REGISTER_HEAP_POINTER, REGISTER_HI, REGISTER_LO, REGISTER_NEXT_IP, - REGISTER_PREIMAGE_KEY_END, REGISTER_PREIMAGE_OFFSET, + REGISTER_PREIMAGE_KEY_START, REGISTER_PREIMAGE_KEY_WRITE_OFFSET, REGISTER_PREIMAGE_OFFSET, }, lookups::{Lookup, LookupTableIDs}, }; use ark_ff::{One, Zero}; +use itertools::enumerate; use strum::{EnumCount, IntoEnumIterator}; use strum_macros::{EnumCount, EnumIter}; @@ -1211,149 +1212,96 @@ pub fn interpret_rtype(env: &mut Env, instr: RTypeInstructi env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); } RTypeInstruction::SyscallWritePreimage => { - let addr = env.read_register(&Env::constant(5)); - let write_length = env.read_register(&Env::constant(6)); + let requested_addr = env.read_register(&Env::constant(5)); + let requested_write_length = env.read_register(&Env::constant(6)); - // Cannon assumes that the remaining `byte_length` represents how much remains to be - // read (i.e. all write calls send the full data in one syscall, and attempt to retry - // with the rest until there is a success). This also simplifies the implementation - // here, so we will follow suit. - let bytes_to_preserve_in_register = { + // Offset within the aligned word, e.g. alignment = 1 means we can only write 3 bytes. + let alignment = { let pos = env.alloc_scratch(); - unsafe { env.bitmask(&write_length, 2, 0, pos) } + unsafe { env.and_witness(&requested_addr, &Env::constant(3), pos) } }; - env.range_check2(&bytes_to_preserve_in_register); - let register_idx = { - let registers_left_to_write_after_this = { - let pos = env.alloc_scratch(); - // The virtual register is 32 bits wide, so we can just read 6 bytes. If the - // register has an incorrect value, it will be unprovable and we'll fault. - unsafe { env.bitmask(&write_length, 6, 2, pos) } - }; - env.range_check8(®isters_left_to_write_after_this, 4); - Env::constant(REGISTER_PREIMAGE_KEY_END as u32) - registers_left_to_write_after_this - }; - - let [r0, r1, r2, r3] = { - let register_value = { - let initial_register_value = env.read_register(®ister_idx); - // We should clear the register if our offset into the read will replace all of its - // bytes. - let should_clear_register = env.is_zero(&bytes_to_preserve_in_register); + // Align the read address to the word boundary + let read_address = { + let pos = env.alloc_scratch(); + unsafe { env.and_witness(&requested_addr, &Env::constant(0xFFFFFFFC), pos) } + }; + // Our actual write length given the word boundary restriction + let write_length = { + // number of bytes we are allowed to write till we reach the end of the word boundary. + let num_available_bytes = Env::constant(4) - alignment.clone(); + let requested_too_much = { let pos = env.alloc_scratch(); - env.copy( - &((Env::constant(1) - should_clear_register) * initial_register_value), - pos, - ) + unsafe { + env.test_less_than(&num_available_bytes, &requested_write_length, pos) + } }; - [ - { - let pos = env.alloc_scratch(); - unsafe { env.bitmask(®ister_value, 32, 24, pos) } - }, - { - let pos = env.alloc_scratch(); - unsafe { env.bitmask(®ister_value, 24, 16, pos) } - }, - { - let pos = env.alloc_scratch(); - unsafe { env.bitmask(®ister_value, 16, 8, pos) } - }, - { - let pos = env.alloc_scratch(); - unsafe { env.bitmask(®ister_value, 8, 0, pos) } - }, - ] + + (Env::constant(1) - requested_too_much.clone()) * requested_write_length + + requested_too_much * num_available_bytes }; - env.lookup_8bits(&r0); - env.lookup_8bits(&r1); - env.lookup_8bits(&r2); - env.lookup_8bits(&r3); - // We choose our read address so that the bytes we read come aligned with the target - // bytes in the register, to avoid an expensive bitshift. - let read_address = addr.clone() - bytes_to_preserve_in_register.clone(); - - let m0 = env.read_memory(&read_address); - let m1 = env.read_memory(&(read_address.clone() + Env::constant(1))); - let m2 = env.read_memory(&(read_address.clone() + Env::constant(2))); - let m3 = env.read_memory(&(read_address.clone() + Env::constant(3))); - - // Now, for some complexity. From the perspective of the write operation, we should be - // reading the `4 - bytes_to_preserve_in_register`. However, to match cannon 1:1, we - // only want to read the bytes up to the end of the current word. - let [overwrite_0, overwrite_1, overwrite_2, overwrite_3] = { - let next_word_addr = { - let byte_subaddr = { - // FIXME: Requires a range check + let write_flags = { + let cap = alignment.clone() + write_length.clone(); + // i \elem [alignment, alignment + write_length) + let mut make_condition = |i: u32| { + let c1 = { let pos = env.alloc_scratch(); - unsafe { env.bitmask(&addr, 2, 0, pos) } + Env::constant(1) + - unsafe { env.test_less_than(&Env::constant(i), &alignment, pos) } }; - env.range_check2(&byte_subaddr); - addr.clone() + Env::constant(4) - byte_subaddr - }; - let overwrite_0 = { - // We always write the first byte if we're not preserving it, since it will - // have been read from `addr`. - env.equal(&bytes_to_preserve_in_register, &Env::constant(0)) - }; - let overwrite_1 = { - // We write the second byte if: - // we wrote the first byte - overwrite_0.clone() - // and this isn't the start of the next word (which implies `overwrite_0`), - - env.equal(&(read_address.clone() + Env::constant(1)), &next_word_addr) - // or this byte was read from `addr` - + env.equal(&bytes_to_preserve_in_register, &Env::constant(1)) + let c2 = { + let pos = env.alloc_scratch(); + unsafe { env.test_less_than(&Env::constant(i), &cap, pos) } + }; + let pos = env.alloc_scratch(); + unsafe { env.and_witness(&c1, &c2, pos) } }; - let overwrite_2 = { - // We write the third byte if: - // we wrote the second byte - overwrite_1.clone() - // and this isn't the start of the next word (which implies `overwrite_1`), - - env.equal(&(read_address.clone() + Env::constant(2)), &next_word_addr) - // or this byte was read from `addr` - + env.equal(&bytes_to_preserve_in_register, &Env::constant(2)) + let write_flag_0 = make_condition(0); + let write_flag_1 = make_condition(1); + let write_flag_2 = make_condition(2); + let write_flag_3 = make_condition(3); + [write_flag_0, write_flag_1, write_flag_2, write_flag_3] + }; + + for (index, write_flag) in enumerate(write_flags) { + let offset = + env.read_register(&Env::constant(REGISTER_PREIMAGE_KEY_WRITE_OFFSET as u32)); + + let current_write_register = + Env::constant(REGISTER_PREIMAGE_KEY_START as u32) + offset.clone(); + + let byte_to_write = { + let curr = env.read_register(¤t_write_register); + let m = env.read_memory(&(read_address.clone() + Env::constant(index as u32))); + write_flag.clone() * m + (Env::constant(1) - write_flag.clone()) * curr }; - let overwrite_3 = { - // We write the fourth byte if: - // we wrote the third byte - overwrite_2.clone() - // and this isn't the start of the next word (which implies `overwrite_2`), - - env.equal(&(read_address.clone() + Env::constant(3)), &next_word_addr) - // or this byte was read from `addr` - + env.equal(&bytes_to_preserve_in_register, &Env::constant(3)) + + env.write_register(¤t_write_register, byte_to_write); + + let new_offset = { + let quot = env.alloc_scratch(); + let rem = env.alloc_scratch(); + let (_, r) = unsafe { + env.divmod(&(offset + write_flag), &Env::constant(32), quot, rem) + }; + r }; - [overwrite_0, overwrite_1, overwrite_2, overwrite_3] - }; - let value = { - let value = ((overwrite_0.clone() * m0 - + (Env::constant(1) - overwrite_0.clone()) * r0) - * Env::constant(1 << 24)) - + ((overwrite_1.clone() * m1 + (Env::constant(1) - overwrite_1.clone()) * r1) - * Env::constant(1 << 16)) - + ((overwrite_2.clone() * m2 + (Env::constant(1) - overwrite_2.clone()) * r2) - * Env::constant(1 << 8)) - + (overwrite_3.clone() * m3 + (Env::constant(1) - overwrite_3.clone()) * r3); - let pos = env.alloc_scratch(); - env.copy(&value, pos) - }; + env.write_register( + &Env::constant(REGISTER_PREIMAGE_KEY_WRITE_OFFSET as u32), + new_offset, + ); + } - // Update the preimage key. - env.write_register(®ister_idx, value); // Reset the preimage offset. env.write_register( &Env::constant(REGISTER_PREIMAGE_OFFSET as u32), Env::constant(0u32), ); // Return the number of bytes read. - env.write_register( - &Env::constant(2), - overwrite_0 + overwrite_1 + overwrite_2 + overwrite_3, - ); + env.write_register(&Env::constant(2), write_length); // Set the error register to 0. env.write_register(&Env::constant(7), Env::constant(0u32)); diff --git a/o1vm/src/interpreters/mips/registers.rs b/o1vm/src/interpreters/mips/registers.rs index 0dcfe9516b..b388924522 100644 --- a/o1vm/src/interpreters/mips/registers.rs +++ b/o1vm/src/interpreters/mips/registers.rs @@ -6,11 +6,12 @@ pub const REGISTER_LO: usize = 33; pub const REGISTER_CURRENT_IP: usize = 34; pub const REGISTER_NEXT_IP: usize = 35; pub const REGISTER_HEAP_POINTER: usize = 36; -pub const REGISTER_PREIMAGE_KEY_START: usize = 37; -pub const REGISTER_PREIMAGE_KEY_END: usize = REGISTER_PREIMAGE_KEY_START + 8 /* 37 + 8 = 45 */; -pub const REGISTER_PREIMAGE_OFFSET: usize = 45; +pub const REGISTER_PREIMAGE_KEY_WRITE_OFFSET: usize = 37; +pub const REGISTER_PREIMAGE_KEY_START: usize = 38; +pub const REGISTER_PREIMAGE_KEY_END: usize = REGISTER_PREIMAGE_KEY_START + 32 /* 38 + 32 = 70 */; +pub const REGISTER_PREIMAGE_OFFSET: usize = 70; -pub const NUM_REGISTERS: usize = 46; +pub const NUM_REGISTERS: usize = 71; /// This represents the internal state of the virtual machine. #[derive(Clone, Default, Debug, Serialize, Deserialize)] @@ -21,7 +22,8 @@ pub struct Registers { pub current_instruction_pointer: T, pub next_instruction_pointer: T, pub heap_pointer: T, - pub preimage_key: [T; 8], + pub preimage_key_write_offset: T, + pub preimage_key: [T; 32], pub preimage_offset: T, } @@ -35,6 +37,7 @@ impl Registers { &self.current_instruction_pointer, &self.next_instruction_pointer, &self.heap_pointer, + &self.preimage_key_write_offset, ]) .chain(self.preimage_key.iter()) .chain([&self.preimage_offset]) @@ -57,6 +60,8 @@ impl Index for Registers { &self.next_instruction_pointer } else if index == REGISTER_HEAP_POINTER { &self.heap_pointer + } else if index == REGISTER_PREIMAGE_KEY_WRITE_OFFSET { + &self.preimage_key_write_offset } else if (REGISTER_PREIMAGE_KEY_START..REGISTER_PREIMAGE_KEY_END).contains(&index) { &self.preimage_key[index - REGISTER_PREIMAGE_KEY_START] } else if index == REGISTER_PREIMAGE_OFFSET { @@ -81,6 +86,8 @@ impl IndexMut for Registers { &mut self.next_instruction_pointer } else if index == REGISTER_HEAP_POINTER { &mut self.heap_pointer + } else if index == REGISTER_PREIMAGE_KEY_WRITE_OFFSET { + &mut self.preimage_key_write_offset } else if (REGISTER_PREIMAGE_KEY_START..REGISTER_PREIMAGE_KEY_END).contains(&index) { &mut self.preimage_key[index - REGISTER_PREIMAGE_KEY_START] } else if index == REGISTER_PREIMAGE_OFFSET { diff --git a/o1vm/src/interpreters/mips/tests.rs b/o1vm/src/interpreters/mips/tests.rs index 1b212077d1..63f69e507d 100644 --- a/o1vm/src/interpreters/mips/tests.rs +++ b/o1vm/src/interpreters/mips/tests.rs @@ -83,19 +83,13 @@ mod rtype { 0x05, 0x67, 0xbd, 0xa4, 0x08, 0x77, 0xa7, 0xe8, 0x5d, 0xce, 0xb6, 0xff, 0x1f, 0x37, 0x48, 0x0f, 0xef, 0x3d, ]; - let chunks = preimage_key - .chunks(4) - .map(|chunk| { - ((chunk[0] as u32) << 24) - + ((chunk[1] as u32) << 16) - + ((chunk[2] as u32) << 8) - + (chunk[3] as u32) - }) - .collect::>(); - dummy_env.registers.preimage_key = std::array::from_fn(|i| chunks[i]); + dummy_env.registers.preimage_key = preimage_key; // The whole preimage - let preimage = dummy_env.preimage_oracle.get_preimage(preimage_key).get(); + let preimage = dummy_env + .preimage_oracle + .get_preimage(preimage_key.map(|x| x as u8)) + .get(); // Total number of bytes that need to be processed (includes length) let total_length = 8 + preimage.len() as u32; diff --git a/o1vm/src/interpreters/mips/witness.rs b/o1vm/src/interpreters/mips/witness.rs index eb7db1b967..4823d65313 100644 --- a/o1vm/src/interpreters/mips/witness.rs +++ b/o1vm/src/interpreters/mips/witness.rs @@ -614,13 +614,7 @@ impl InterpreterEnv for Env Self::Variable { // The beginning of the syscall if self.registers.preimage_offset == 0 { - let mut preimage_key = [0u8; 32]; - for i in 0..8 { - let bytes = u32::to_be_bytes(self.registers.preimage_key[i]); - for j in 0..4 { - preimage_key[4 * i + j] = bytes[j] - } - } + let preimage_key = self.registers.preimage_key.map(|x| x as u8); let preimage = self.preimage_oracle.get_preimage(preimage_key).get(); self.preimage = Some(preimage.clone()); self.preimage_key = Some(preimage_key); @@ -846,12 +840,10 @@ impl Env { .collect::>(); let initial_registers = { - let preimage_key = { - let mut preimage_key = [0u32; 8]; + let preimage_key: [u32; 32] = { + let mut preimage_key = [0u32; 32]; for (i, preimage_key_word) in preimage_key.iter_mut().enumerate() { - *preimage_key_word = u32::from_be_bytes( - state.preimage_key[i * 4..(i + 1) * 4].try_into().unwrap(), - ) + *preimage_key_word = state.preimage_key[i] as u32; } preimage_key }; @@ -862,6 +854,7 @@ impl Env { current_instruction_pointer: initial_instruction_pointer, next_instruction_pointer, heap_pointer: state.heap, + preimage_key_write_offset: 0, preimage_key, preimage_offset: state.preimage_offset, } diff --git a/o1vm/src/pickles/mod.rs b/o1vm/src/pickles/mod.rs index 193162e90b..dcf28f9c43 100644 --- a/o1vm/src/pickles/mod.rs +++ b/o1vm/src/pickles/mod.rs @@ -31,7 +31,7 @@ pub const DEGREE_QUOTIENT_POLYNOMIAL: u64 = 7; /// Total number of constraints for all instructions, including the constraints /// added for the selectors. -pub const TOTAL_NUMBER_OF_CONSTRAINTS: usize = 464; +pub const TOTAL_NUMBER_OF_CONSTRAINTS: usize = 467; #[cfg(test)] mod tests;