Skip to main content

ab_riscv_interpreter/
rv64.rs

1//! Part of the interpreter responsible for RISC-V RV64 base instruction set
2
3pub mod b;
4pub mod m;
5#[cfg(test)]
6pub(crate) mod test_utils;
7#[cfg(test)]
8mod tests;
9pub mod zk;
10
11use crate::{
12    ExecutableInstruction, ExecutionError, InterpreterState, ProgramCounter,
13    SystemInstructionHandler, VirtualMemory,
14};
15use ab_riscv_macros::instruction_execution;
16use ab_riscv_primitives::instructions::Instruction;
17use ab_riscv_primitives::instructions::rv64::Rv64Instruction;
18use ab_riscv_primitives::registers::general_purpose::Register;
19use core::ops::ControlFlow;
20
21#[instruction_execution]
22impl<Reg, ExtState, Memory, PC, InstructionHandler, CustomError>
23    ExecutableInstruction<
24        InterpreterState<Reg, ExtState, Memory, PC, InstructionHandler, CustomError>,
25        CustomError,
26    > for Rv64Instruction<Reg>
27where
28    Reg: Register<Type = u64>,
29    [(); Reg::N]:,
30    Memory: VirtualMemory,
31    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
32    InstructionHandler: SystemInstructionHandler<Reg, Memory, PC, CustomError>,
33{
34    #[inline(always)]
35    fn execute(
36        self,
37        state: &mut InterpreterState<Reg, ExtState, Memory, PC, InstructionHandler, CustomError>,
38    ) -> Result<ControlFlow<()>, ExecutionError<Reg::Type, CustomError>> {
39        match self {
40            Self::Add { rd, rs1, rs2 } => {
41                let value = state.regs.read(rs1).wrapping_add(state.regs.read(rs2));
42                state.regs.write(rd, value);
43            }
44            Self::Sub { rd, rs1, rs2 } => {
45                let value = state.regs.read(rs1).wrapping_sub(state.regs.read(rs2));
46                state.regs.write(rd, value);
47            }
48            Self::Sll { rd, rs1, rs2 } => {
49                let shamt = state.regs.read(rs2) & 0x3f;
50                let value = state.regs.read(rs1) << shamt;
51                state.regs.write(rd, value);
52            }
53            Self::Slt { rd, rs1, rs2 } => {
54                let value = state.regs.read(rs1).cast_signed() < state.regs.read(rs2).cast_signed();
55                state.regs.write(rd, value as u64);
56            }
57            Self::Sltu { rd, rs1, rs2 } => {
58                let value = state.regs.read(rs1) < state.regs.read(rs2);
59                state.regs.write(rd, value as u64);
60            }
61            Self::Xor { rd, rs1, rs2 } => {
62                let value = state.regs.read(rs1) ^ state.regs.read(rs2);
63                state.regs.write(rd, value);
64            }
65            Self::Srl { rd, rs1, rs2 } => {
66                let shamt = state.regs.read(rs2) & 0x3f;
67                let value = state.regs.read(rs1) >> shamt;
68                state.regs.write(rd, value);
69            }
70            Self::Sra { rd, rs1, rs2 } => {
71                let shamt = state.regs.read(rs2) & 0x3f;
72                let value = state.regs.read(rs1).cast_signed() >> shamt;
73                state.regs.write(rd, value.cast_unsigned());
74            }
75            Self::Or { rd, rs1, rs2 } => {
76                let value = state.regs.read(rs1) | state.regs.read(rs2);
77                state.regs.write(rd, value);
78            }
79            Self::And { rd, rs1, rs2 } => {
80                let value = state.regs.read(rs1) & state.regs.read(rs2);
81                state.regs.write(rd, value);
82            }
83
84            Self::Addw { rd, rs1, rs2 } => {
85                let sum = (state.regs.read(rs1) as i32).wrapping_add(state.regs.read(rs2) as i32);
86                state.regs.write(rd, i64::from(sum).cast_unsigned());
87            }
88            Self::Subw { rd, rs1, rs2 } => {
89                let diff = (state.regs.read(rs1) as i32).wrapping_sub(state.regs.read(rs2) as i32);
90                state.regs.write(rd, i64::from(diff).cast_unsigned());
91            }
92            Self::Sllw { rd, rs1, rs2 } => {
93                let shamt = state.regs.read(rs2) & 0x1f;
94                let shifted = (state.regs.read(rs1) as u32) << shamt;
95                state
96                    .regs
97                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
98            }
99            Self::Srlw { rd, rs1, rs2 } => {
100                let shamt = state.regs.read(rs2) & 0x1f;
101                let shifted = (state.regs.read(rs1) as u32) >> shamt;
102                state
103                    .regs
104                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
105            }
106            Self::Sraw { rd, rs1, rs2 } => {
107                let shamt = state.regs.read(rs2) & 0x1f;
108                let shifted = (state.regs.read(rs1) as i32) >> shamt;
109                state.regs.write(rd, i64::from(shifted).cast_unsigned());
110            }
111
112            Self::Addi { rd, rs1, imm } => {
113                let value = state
114                    .regs
115                    .read(rs1)
116                    .wrapping_add(i64::from(imm).cast_unsigned());
117                state.regs.write(rd, value);
118            }
119            Self::Slti { rd, rs1, imm } => {
120                let value = state.regs.read(rs1).cast_signed() < i64::from(imm);
121                state.regs.write(rd, value as u64);
122            }
123            Self::Sltiu { rd, rs1, imm } => {
124                let value = state.regs.read(rs1) < i64::from(imm).cast_unsigned();
125                state.regs.write(rd, value as u64);
126            }
127            Self::Xori { rd, rs1, imm } => {
128                let value = state.regs.read(rs1) ^ i64::from(imm).cast_unsigned();
129                state.regs.write(rd, value);
130            }
131            Self::Ori { rd, rs1, imm } => {
132                let value = state.regs.read(rs1) | i64::from(imm).cast_unsigned();
133                state.regs.write(rd, value);
134            }
135            Self::Andi { rd, rs1, imm } => {
136                let value = state.regs.read(rs1) & i64::from(imm).cast_unsigned();
137                state.regs.write(rd, value);
138            }
139            Self::Slli { rd, rs1, shamt } => {
140                let value = state.regs.read(rs1) << shamt;
141                state.regs.write(rd, value);
142            }
143            Self::Srli { rd, rs1, shamt } => {
144                let value = state.regs.read(rs1) >> shamt;
145                state.regs.write(rd, value);
146            }
147            Self::Srai { rd, rs1, shamt } => {
148                let value = state.regs.read(rs1).cast_signed() >> shamt;
149                state.regs.write(rd, value.cast_unsigned());
150            }
151
152            Self::Addiw { rd, rs1, imm } => {
153                let sum = (state.regs.read(rs1) as i32).wrapping_add(i32::from(imm));
154                state.regs.write(rd, i64::from(sum).cast_unsigned());
155            }
156            Self::Slliw { rd, rs1, shamt } => {
157                let shifted = (state.regs.read(rs1) as u32) << shamt;
158                state
159                    .regs
160                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
161            }
162            Self::Srliw { rd, rs1, shamt } => {
163                let shifted = (state.regs.read(rs1) as u32) >> shamt;
164                state
165                    .regs
166                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
167            }
168            Self::Sraiw { rd, rs1, shamt } => {
169                let shifted = (state.regs.read(rs1) as i32) >> shamt;
170                state.regs.write(rd, i64::from(shifted).cast_unsigned());
171            }
172
173            Self::Lb { rd, rs1, imm } => {
174                let addr = state
175                    .regs
176                    .read(rs1)
177                    .wrapping_add(i64::from(imm).cast_unsigned());
178                let value = i64::from(state.memory.read::<i8>(addr)?);
179                state.regs.write(rd, value.cast_unsigned());
180            }
181            Self::Lh { rd, rs1, imm } => {
182                let addr = state
183                    .regs
184                    .read(rs1)
185                    .wrapping_add(i64::from(imm).cast_unsigned());
186                let value = i64::from(state.memory.read::<i16>(addr)?);
187                state.regs.write(rd, value.cast_unsigned());
188            }
189            Self::Lw { rd, rs1, imm } => {
190                let addr = state
191                    .regs
192                    .read(rs1)
193                    .wrapping_add(i64::from(imm).cast_unsigned());
194                let value = i64::from(state.memory.read::<i32>(addr)?);
195                state.regs.write(rd, value.cast_unsigned());
196            }
197            Self::Ld { rd, rs1, imm } => {
198                let addr = state
199                    .regs
200                    .read(rs1)
201                    .wrapping_add(i64::from(imm).cast_unsigned());
202                let value = state.memory.read::<u64>(addr)?;
203                state.regs.write(rd, value);
204            }
205            Self::Lbu { rd, rs1, imm } => {
206                let addr = state
207                    .regs
208                    .read(rs1)
209                    .wrapping_add(i64::from(imm).cast_unsigned());
210                let value = state.memory.read::<u8>(addr)?;
211                state.regs.write(rd, value as u64);
212            }
213            Self::Lhu { rd, rs1, imm } => {
214                let addr = state
215                    .regs
216                    .read(rs1)
217                    .wrapping_add(i64::from(imm).cast_unsigned());
218                let value = state.memory.read::<u16>(addr)?;
219                state.regs.write(rd, value as u64);
220            }
221            Self::Lwu { rd, rs1, imm } => {
222                let addr = state
223                    .regs
224                    .read(rs1)
225                    .wrapping_add(i64::from(imm).cast_unsigned());
226                let value = state.memory.read::<u32>(addr)?;
227                state.regs.write(rd, value as u64);
228            }
229
230            Self::Jalr { rd, rs1, imm } => {
231                let target = (state
232                    .regs
233                    .read(rs1)
234                    .wrapping_add(i64::from(imm).cast_unsigned()))
235                    & !1u64;
236                state.regs.write(rd, state.instruction_fetcher.get_pc());
237                return state
238                    .instruction_fetcher
239                    .set_pc(&state.memory, target)
240                    .map_err(ExecutionError::from);
241            }
242
243            Self::Sb { rs2, rs1, imm } => {
244                let addr = state
245                    .regs
246                    .read(rs1)
247                    .wrapping_add(i64::from(imm).cast_unsigned());
248                state.memory.write(addr, state.regs.read(rs2) as u8)?;
249            }
250            Self::Sh { rs2, rs1, imm } => {
251                let addr = state
252                    .regs
253                    .read(rs1)
254                    .wrapping_add(i64::from(imm).cast_unsigned());
255                state.memory.write(addr, state.regs.read(rs2) as u16)?;
256            }
257            Self::Sw { rs2, rs1, imm } => {
258                let addr = state
259                    .regs
260                    .read(rs1)
261                    .wrapping_add(i64::from(imm).cast_unsigned());
262                state.memory.write(addr, state.regs.read(rs2) as u32)?;
263            }
264            Self::Sd { rs2, rs1, imm } => {
265                let addr = state
266                    .regs
267                    .read(rs1)
268                    .wrapping_add(i64::from(imm).cast_unsigned());
269                state.memory.write(addr, state.regs.read(rs2))?;
270            }
271
272            Self::Beq { rs1, rs2, imm } => {
273                if state.regs.read(rs1) == state.regs.read(rs2) {
274                    let old_pc = state.instruction_fetcher.old_pc(self.size());
275                    return state
276                        .instruction_fetcher
277                        .set_pc(
278                            &state.memory,
279                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
280                        )
281                        .map_err(ExecutionError::from);
282                }
283            }
284            Self::Bne { rs1, rs2, imm } => {
285                if state.regs.read(rs1) != state.regs.read(rs2) {
286                    let old_pc = state.instruction_fetcher.old_pc(self.size());
287                    return state
288                        .instruction_fetcher
289                        .set_pc(
290                            &state.memory,
291                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
292                        )
293                        .map_err(ExecutionError::from);
294                }
295            }
296            Self::Blt { rs1, rs2, imm } => {
297                if state.regs.read(rs1).cast_signed() < state.regs.read(rs2).cast_signed() {
298                    let old_pc = state.instruction_fetcher.old_pc(self.size());
299                    return state
300                        .instruction_fetcher
301                        .set_pc(
302                            &state.memory,
303                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
304                        )
305                        .map_err(ExecutionError::from);
306                }
307            }
308            Self::Bge { rs1, rs2, imm } => {
309                if state.regs.read(rs1).cast_signed() >= state.regs.read(rs2).cast_signed() {
310                    let old_pc = state.instruction_fetcher.old_pc(self.size());
311                    return state
312                        .instruction_fetcher
313                        .set_pc(
314                            &state.memory,
315                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
316                        )
317                        .map_err(ExecutionError::from);
318                }
319            }
320            Self::Bltu { rs1, rs2, imm } => {
321                if state.regs.read(rs1) < state.regs.read(rs2) {
322                    let old_pc = state.instruction_fetcher.old_pc(self.size());
323                    return state
324                        .instruction_fetcher
325                        .set_pc(
326                            &state.memory,
327                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
328                        )
329                        .map_err(ExecutionError::from);
330                }
331            }
332            Self::Bgeu { rs1, rs2, imm } => {
333                if state.regs.read(rs1) >= state.regs.read(rs2) {
334                    let old_pc = state.instruction_fetcher.old_pc(self.size());
335                    return state
336                        .instruction_fetcher
337                        .set_pc(
338                            &state.memory,
339                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
340                        )
341                        .map_err(ExecutionError::from);
342                }
343            }
344
345            Self::Lui { rd, imm } => {
346                state.regs.write(rd, i64::from(imm).cast_unsigned());
347            }
348
349            Self::Auipc { rd, imm } => {
350                let old_pc = state.instruction_fetcher.old_pc(self.size());
351                state
352                    .regs
353                    .write(rd, old_pc.wrapping_add(i64::from(imm).cast_unsigned()));
354            }
355
356            Self::Jal { rd, imm } => {
357                let pc = state.instruction_fetcher.get_pc();
358                let old_pc = state.instruction_fetcher.old_pc(self.size());
359                state.regs.write(rd, pc);
360                return state
361                    .instruction_fetcher
362                    .set_pc(
363                        &state.memory,
364                        old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
365                    )
366                    .map_err(ExecutionError::from);
367            }
368
369            Self::Fence { pred, succ } => {
370                state.system_instruction_handler.handle_fence(pred, succ);
371            }
372            Self::FenceTso => {
373                state.system_instruction_handler.handle_fence_tso();
374            }
375
376            Self::Ecall => {
377                return state.system_instruction_handler.handle_ecall(
378                    &mut state.regs,
379                    &mut state.memory,
380                    &mut state.instruction_fetcher,
381                );
382            }
383            Self::Ebreak => {
384                state.system_instruction_handler.handle_ebreak(
385                    &mut state.regs,
386                    &mut state.memory,
387                    state.instruction_fetcher.get_pc(),
388                );
389            }
390
391            Self::Unimp => {
392                let old_pc = state.instruction_fetcher.old_pc(self.size());
393                return Err(ExecutionError::IllegalInstruction { address: old_pc });
394            }
395        }
396
397        Ok(ControlFlow::Continue(()))
398    }
399}