Skip to content

Commit

Permalink
switch implementation of WritePreimageKey syscall, use one register p…
Browse files Browse the repository at this point in the history
…er byte for the preimage key.
  • Loading branch information
martyall committed Jan 10, 2025
1 parent f7caf12 commit 0800a90
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 152 deletions.
2 changes: 1 addition & 1 deletion o1vm/src/interpreters/mips/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use strum::EnumCount;

use super::{ITypeInstruction, JTypeInstruction, RTypeInstruction};

pub(crate) const SCRATCH_SIZE_WITHOUT_KECCAK: usize = 50;
pub(crate) const SCRATCH_SIZE_WITHOUT_KECCAK: usize = 95;
/// The number of hashes performed so far in the block
pub(crate) const MIPS_HASH_COUNTER_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK;
/// The number of bytes of the preimage that have been read so far in this hash
Expand Down
192 changes: 70 additions & 122 deletions o1vm/src/interpreters/mips/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use crate::{
cannon::PAGE_ADDRESS_SIZE,
interpreters::mips::registers::{
REGISTER_CURRENT_IP, REGISTER_HEAP_POINTER, REGISTER_HI, REGISTER_LO, REGISTER_NEXT_IP,
REGISTER_PREIMAGE_KEY_END, REGISTER_PREIMAGE_OFFSET,
REGISTER_PREIMAGE_KEY_START, REGISTER_PREIMAGE_KEY_WRITE_OFFSET, REGISTER_PREIMAGE_OFFSET,
},
lookups::{Lookup, LookupTableIDs},
};
use ark_ff::{One, Zero};
use itertools::enumerate;
use strum::{EnumCount, IntoEnumIterator};
use strum_macros::{EnumCount, EnumIter};

Expand Down Expand Up @@ -1211,149 +1212,96 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RTypeInstructi
env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
}
RTypeInstruction::SyscallWritePreimage => {
let addr = env.read_register(&Env::constant(5));
let write_length = env.read_register(&Env::constant(6));
let requested_addr = env.read_register(&Env::constant(5));
let requested_write_length = env.read_register(&Env::constant(6));

// Cannon assumes that the remaining `byte_length` represents how much remains to be
// read (i.e. all write calls send the full data in one syscall, and attempt to retry
// with the rest until there is a success). This also simplifies the implementation
// here, so we will follow suit.
let bytes_to_preserve_in_register = {
// Offset within the aligned word, e.g. alignment = 1 means we can only write 3 bytes.
let alignment = {
let pos = env.alloc_scratch();
unsafe { env.bitmask(&write_length, 2, 0, pos) }
unsafe { env.and_witness(&requested_addr, &Env::constant(3), pos) }
};
env.range_check2(&bytes_to_preserve_in_register);
let register_idx = {
let registers_left_to_write_after_this = {
let pos = env.alloc_scratch();
// The virtual register is 32 bits wide, so we can just read 6 bytes. If the
// register has an incorrect value, it will be unprovable and we'll fault.
unsafe { env.bitmask(&write_length, 6, 2, pos) }
};
env.range_check8(&registers_left_to_write_after_this, 4);
Env::constant(REGISTER_PREIMAGE_KEY_END as u32) - registers_left_to_write_after_this
};

let [r0, r1, r2, r3] = {
let register_value = {
let initial_register_value = env.read_register(&register_idx);

// We should clear the register if our offset into the read will replace all of its
// bytes.
let should_clear_register = env.is_zero(&bytes_to_preserve_in_register);
// Align the read address to the word boundary
let read_address = {
let pos = env.alloc_scratch();
unsafe { env.and_witness(&requested_addr, &Env::constant(0xFFFFFFFC), pos) }
};

// Our actual write length given the word boundary restriction
let write_length = {
// number of bytes we are allowed to write till we reach the end of the word boundary.
let num_available_bytes = Env::constant(4) - alignment.clone();
let requested_too_much = {
let pos = env.alloc_scratch();
env.copy(
&((Env::constant(1) - should_clear_register) * initial_register_value),
pos,
)
unsafe {
env.test_less_than(&num_available_bytes, &requested_write_length, pos)
}
};
[
{
let pos = env.alloc_scratch();
unsafe { env.bitmask(&register_value, 32, 24, pos) }
},
{
let pos = env.alloc_scratch();
unsafe { env.bitmask(&register_value, 24, 16, pos) }
},
{
let pos = env.alloc_scratch();
unsafe { env.bitmask(&register_value, 16, 8, pos) }
},
{
let pos = env.alloc_scratch();
unsafe { env.bitmask(&register_value, 8, 0, pos) }
},
]

(Env::constant(1) - requested_too_much.clone()) * requested_write_length
+ requested_too_much * num_available_bytes
};
env.lookup_8bits(&r0);
env.lookup_8bits(&r1);
env.lookup_8bits(&r2);
env.lookup_8bits(&r3);

// We choose our read address so that the bytes we read come aligned with the target
// bytes in the register, to avoid an expensive bitshift.
let read_address = addr.clone() - bytes_to_preserve_in_register.clone();

let m0 = env.read_memory(&read_address);
let m1 = env.read_memory(&(read_address.clone() + Env::constant(1)));
let m2 = env.read_memory(&(read_address.clone() + Env::constant(2)));
let m3 = env.read_memory(&(read_address.clone() + Env::constant(3)));

// Now, for some complexity. From the perspective of the write operation, we should be
// reading the `4 - bytes_to_preserve_in_register`. However, to match cannon 1:1, we
// only want to read the bytes up to the end of the current word.
let [overwrite_0, overwrite_1, overwrite_2, overwrite_3] = {
let next_word_addr = {
let byte_subaddr = {
// FIXME: Requires a range check
let write_flags = {
let cap = alignment.clone() + write_length.clone();
// i \elem [alignment, alignment + write_length)
let mut make_condition = |i: u32| {
let c1 = {
let pos = env.alloc_scratch();
unsafe { env.bitmask(&addr, 2, 0, pos) }
Env::constant(1)
- unsafe { env.test_less_than(&Env::constant(i), &alignment, pos) }
};
env.range_check2(&byte_subaddr);
addr.clone() + Env::constant(4) - byte_subaddr
};
let overwrite_0 = {
// We always write the first byte if we're not preserving it, since it will
// have been read from `addr`.
env.equal(&bytes_to_preserve_in_register, &Env::constant(0))
};
let overwrite_1 = {
// We write the second byte if:
// we wrote the first byte
overwrite_0.clone()
// and this isn't the start of the next word (which implies `overwrite_0`),
- env.equal(&(read_address.clone() + Env::constant(1)), &next_word_addr)
// or this byte was read from `addr`
+ env.equal(&bytes_to_preserve_in_register, &Env::constant(1))
let c2 = {
let pos = env.alloc_scratch();
unsafe { env.test_less_than(&Env::constant(i), &cap, pos) }
};
let pos = env.alloc_scratch();
unsafe { env.and_witness(&c1, &c2, pos) }
};
let overwrite_2 = {
// We write the third byte if:
// we wrote the second byte
overwrite_1.clone()
// and this isn't the start of the next word (which implies `overwrite_1`),
- env.equal(&(read_address.clone() + Env::constant(2)), &next_word_addr)
// or this byte was read from `addr`
+ env.equal(&bytes_to_preserve_in_register, &Env::constant(2))
let write_flag_0 = make_condition(0);
let write_flag_1 = make_condition(1);
let write_flag_2 = make_condition(2);
let write_flag_3 = make_condition(3);
[write_flag_0, write_flag_1, write_flag_2, write_flag_3]
};

for (index, write_flag) in enumerate(write_flags) {
let offset =
env.read_register(&Env::constant(REGISTER_PREIMAGE_KEY_WRITE_OFFSET as u32));

let current_write_register =
Env::constant(REGISTER_PREIMAGE_KEY_START as u32) + offset.clone();

let byte_to_write = {
let curr = env.read_register(&current_write_register);
let m = env.read_memory(&(read_address.clone() + Env::constant(index as u32)));
write_flag.clone() * m + (Env::constant(1) - write_flag.clone()) * curr
};
let overwrite_3 = {
// We write the fourth byte if:
// we wrote the third byte
overwrite_2.clone()
// and this isn't the start of the next word (which implies `overwrite_2`),
- env.equal(&(read_address.clone() + Env::constant(3)), &next_word_addr)
// or this byte was read from `addr`
+ env.equal(&bytes_to_preserve_in_register, &Env::constant(3))

env.write_register(&current_write_register, byte_to_write);

let new_offset = {
let quot = env.alloc_scratch();
let rem = env.alloc_scratch();
let (_, r) = unsafe {
env.divmod(&(offset + write_flag), &Env::constant(32), quot, rem)
};
r
};
[overwrite_0, overwrite_1, overwrite_2, overwrite_3]
};

let value = {
let value = ((overwrite_0.clone() * m0
+ (Env::constant(1) - overwrite_0.clone()) * r0)
* Env::constant(1 << 24))
+ ((overwrite_1.clone() * m1 + (Env::constant(1) - overwrite_1.clone()) * r1)
* Env::constant(1 << 16))
+ ((overwrite_2.clone() * m2 + (Env::constant(1) - overwrite_2.clone()) * r2)
* Env::constant(1 << 8))
+ (overwrite_3.clone() * m3 + (Env::constant(1) - overwrite_3.clone()) * r3);
let pos = env.alloc_scratch();
env.copy(&value, pos)
};
env.write_register(
&Env::constant(REGISTER_PREIMAGE_KEY_WRITE_OFFSET as u32),
new_offset,
);
}

// Update the preimage key.
env.write_register(&register_idx, value);
// Reset the preimage offset.
env.write_register(
&Env::constant(REGISTER_PREIMAGE_OFFSET as u32),
Env::constant(0u32),
);
// Return the number of bytes read.
env.write_register(
&Env::constant(2),
overwrite_0 + overwrite_1 + overwrite_2 + overwrite_3,
);
env.write_register(&Env::constant(2), write_length);
// Set the error register to 0.
env.write_register(&Env::constant(7), Env::constant(0u32));

Expand Down
17 changes: 12 additions & 5 deletions o1vm/src/interpreters/mips/registers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ pub const REGISTER_LO: usize = 33;
pub const REGISTER_CURRENT_IP: usize = 34;
pub const REGISTER_NEXT_IP: usize = 35;
pub const REGISTER_HEAP_POINTER: usize = 36;
pub const REGISTER_PREIMAGE_KEY_START: usize = 37;
pub const REGISTER_PREIMAGE_KEY_END: usize = REGISTER_PREIMAGE_KEY_START + 8 /* 37 + 8 = 45 */;
pub const REGISTER_PREIMAGE_OFFSET: usize = 45;
pub const REGISTER_PREIMAGE_KEY_WRITE_OFFSET: usize = 37;
pub const REGISTER_PREIMAGE_KEY_START: usize = 38;
pub const REGISTER_PREIMAGE_KEY_END: usize = REGISTER_PREIMAGE_KEY_START + 32 /* 38 + 32 = 70 */;
pub const REGISTER_PREIMAGE_OFFSET: usize = 70;

pub const NUM_REGISTERS: usize = 46;
pub const NUM_REGISTERS: usize = 71;

/// This represents the internal state of the virtual machine.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
Expand All @@ -21,7 +22,8 @@ pub struct Registers<T> {
pub current_instruction_pointer: T,
pub next_instruction_pointer: T,
pub heap_pointer: T,
pub preimage_key: [T; 8],
pub preimage_key_write_offset: T,
pub preimage_key: [T; 32],
pub preimage_offset: T,
}

Expand All @@ -35,6 +37,7 @@ impl<T> Registers<T> {
&self.current_instruction_pointer,
&self.next_instruction_pointer,
&self.heap_pointer,
&self.preimage_key_write_offset,
])
.chain(self.preimage_key.iter())
.chain([&self.preimage_offset])
Expand All @@ -57,6 +60,8 @@ impl<T: Clone> Index<usize> for Registers<T> {
&self.next_instruction_pointer
} else if index == REGISTER_HEAP_POINTER {
&self.heap_pointer
} else if index == REGISTER_PREIMAGE_KEY_WRITE_OFFSET {
&self.preimage_key_write_offset
} else if (REGISTER_PREIMAGE_KEY_START..REGISTER_PREIMAGE_KEY_END).contains(&index) {
&self.preimage_key[index - REGISTER_PREIMAGE_KEY_START]
} else if index == REGISTER_PREIMAGE_OFFSET {
Expand All @@ -81,6 +86,8 @@ impl<T: Clone> IndexMut<usize> for Registers<T> {
&mut self.next_instruction_pointer
} else if index == REGISTER_HEAP_POINTER {
&mut self.heap_pointer
} else if index == REGISTER_PREIMAGE_KEY_WRITE_OFFSET {
&mut self.preimage_key_write_offset
} else if (REGISTER_PREIMAGE_KEY_START..REGISTER_PREIMAGE_KEY_END).contains(&index) {
&mut self.preimage_key[index - REGISTER_PREIMAGE_KEY_START]
} else if index == REGISTER_PREIMAGE_OFFSET {
Expand Down
16 changes: 5 additions & 11 deletions o1vm/src/interpreters/mips/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,13 @@ mod rtype {
0x05, 0x67, 0xbd, 0xa4, 0x08, 0x77, 0xa7, 0xe8, 0x5d, 0xce, 0xb6, 0xff, 0x1f, 0x37,
0x48, 0x0f, 0xef, 0x3d,
];
let chunks = preimage_key
.chunks(4)
.map(|chunk| {
((chunk[0] as u32) << 24)
+ ((chunk[1] as u32) << 16)
+ ((chunk[2] as u32) << 8)
+ (chunk[3] as u32)
})
.collect::<Vec<_>>();
dummy_env.registers.preimage_key = std::array::from_fn(|i| chunks[i]);
dummy_env.registers.preimage_key = preimage_key;

// The whole preimage
let preimage = dummy_env.preimage_oracle.get_preimage(preimage_key).get();
let preimage = dummy_env
.preimage_oracle
.get_preimage(preimage_key.map(|x| x as u8))
.get();

// Total number of bytes that need to be processed (includes length)
let total_length = 8 + preimage.len() as u32;
Expand Down
17 changes: 5 additions & 12 deletions o1vm/src/interpreters/mips/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,7 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
) -> Self::Variable {
// The beginning of the syscall
if self.registers.preimage_offset == 0 {
let mut preimage_key = [0u8; 32];
for i in 0..8 {
let bytes = u32::to_be_bytes(self.registers.preimage_key[i]);
for j in 0..4 {
preimage_key[4 * i + j] = bytes[j]
}
}
let preimage_key = self.registers.preimage_key.map(|x| x as u8);
let preimage = self.preimage_oracle.get_preimage(preimage_key).get();
self.preimage = Some(preimage.clone());
self.preimage_key = Some(preimage_key);
Expand Down Expand Up @@ -846,12 +840,10 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
.collect::<Vec<_>>();

let initial_registers = {
let preimage_key = {
let mut preimage_key = [0u32; 8];
let preimage_key: [u32; 32] = {
let mut preimage_key = [0u32; 32];
for (i, preimage_key_word) in preimage_key.iter_mut().enumerate() {
*preimage_key_word = u32::from_be_bytes(
state.preimage_key[i * 4..(i + 1) * 4].try_into().unwrap(),
)
*preimage_key_word = state.preimage_key[i] as u32;
}
preimage_key
};
Expand All @@ -862,6 +854,7 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
current_instruction_pointer: initial_instruction_pointer,
next_instruction_pointer,
heap_pointer: state.heap,
preimage_key_write_offset: 0,
preimage_key,
preimage_offset: state.preimage_offset,
}
Expand Down
2 changes: 1 addition & 1 deletion o1vm/src/pickles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub const DEGREE_QUOTIENT_POLYNOMIAL: u64 = 7;

/// Total number of constraints for all instructions, including the constraints
/// added for the selectors.
pub const TOTAL_NUMBER_OF_CONSTRAINTS: usize = 464;
pub const TOTAL_NUMBER_OF_CONSTRAINTS: usize = 467;

#[cfg(test)]
mod tests;

0 comments on commit 0800a90

Please sign in to comment.