diff --git a/o1vm/src/interpreters/riscv32im/interpreter.rs b/o1vm/src/interpreters/riscv32im/interpreter.rs index 0b0adab521..c08bd96f08 100644 --- a/o1vm/src/interpreters/riscv32im/interpreter.rs +++ b/o1vm/src/interpreters/riscv32im/interpreter.rs @@ -2133,8 +2133,7 @@ pub fn interpret_stype(env: &mut Env, instr: SInstruction) /// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf) pub fn interpret_sbtype(env: &mut Env, instr: SBInstruction) { let instruction_pointer = env.get_instruction_pointer(); - let _next_instruction_pointer = env.get_next_instruction_pointer(); - /* read instruction from ip address */ + let next_instruction_pointer = env.get_next_instruction_pointer(); let instruction = { let v0 = env.read_memory(&instruction_pointer); let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1))); @@ -2203,8 +2202,8 @@ pub fn interpret_sbtype(env: &mut Env, instr: SBInstruction - (imm11.clone() * Env::constant(1 << 7)) // imm11 at bits 8 - (imm1_4.clone() * Env::constant(1 << 8)) // imm1_4 at bits 9-11 - (funct3 * Env::constant(1 << 11)) // funct3 at bits 11-14 - - (rs1 * Env::constant(1 << 14)) // rs1 at bits 15-20 - - (rs2 * Env::constant(1 << 19)) // rs2 at bits 20-24 + - (rs1.clone() * Env::constant(1 << 14)) // rs1 at bits 15-20 + - (rs2.clone() * Env::constant(1 << 19)) // rs2 at bits 20-24 - (imm5_10.clone() * Env::constant(1 << 24)) // imm5_10 at bits 25-30 - (imm12.clone() * Env::constant(1 << 31)), // imm12 at bits 31 ); @@ -2215,14 +2214,35 @@ pub fn interpret_sbtype(env: &mut Env, instr: SBInstruction + (imm1_4 * Env::constant(1 << 1)) }; // extra bit is because the 0th bit in the immediate is always 0 i.e you cannot jump to an odd address - let _imm0_12 = env.sign_extend(&imm0_12, 13); + let imm0_12 = env.sign_extend(&imm0_12, 13); match instr { SBInstruction::BranchEq => { unimplemented!("BranchEq") } SBInstruction::BranchNeq => { - unimplemented!("BranchNeq") + // bne: if (x[rs1] != x[rs2]) pc += sext(offset) + let local_rs1 = env.read_register(&rs1); + let local_rs2 = env.read_register(&rs2); + + let equals = env.equal(&local_rs1, &local_rs2); + let offset = equals.clone() * Env::constant(4) + (Env::constant(1) - equals) * imm0_12; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + env.set_next_instruction_pointer(addr); } SBInstruction::BranchLessThan => { unimplemented!("BranchLessThan")