Skip to main content

ab_riscv_interpreter/
rv64.rs

//! Part of the interpreter responsible for the RISC-V RV64 base instruction set
2
3pub mod b;
4pub mod m;
5#[cfg(test)]
6mod test_utils;
7#[cfg(test)]
8mod tests;
9pub mod zk;
10
11use crate::{
12    ExecutableInstruction, ExecutionError, ProgramCounter, ProgramCounterError, VirtualMemory,
13};
14use ab_riscv_macros::instruction_execution;
15use ab_riscv_primitives::instruction::Instruction;
16use ab_riscv_primitives::instruction::rv64::Rv64Instruction;
17use ab_riscv_primitives::registers::{Register, Registers};
18use core::marker::PhantomData;
19use core::ops::ControlFlow;
20
21/// Custom handler for system instructions `ecall` and `ebreak`
22pub trait Rv64SystemInstructionHandler<Reg, Memory, PC, CustomError>
23where
24    Reg: Register<Type = u64>,
25    [(); Reg::N]:,
26{
27    /// Handle an `ecall` instruction.
28    ///
29    /// NOTE: the program counter here is the current value, meaning it is already incremented past
30    /// the instruction itself.
31    fn handle_ecall(
32        &mut self,
33        regs: &mut Registers<Reg>,
34        memory: &mut Memory,
35        program_counter: &mut PC,
36    ) -> Result<ControlFlow<()>, ExecutionError<Reg::Type, Rv64Instruction<Reg>, CustomError>>;
37
38    /// Handle an `ebreak` instruction.
39    ///
40    /// NOTE: the program counter here is the current value, meaning it is already incremented past
41    /// the instruction itself.
42    #[inline(always)]
43    fn handle_ebreak(
44        &mut self,
45        _regs: &mut Registers<Reg>,
46        _memory: &mut Memory,
47        _pc: Reg::Type,
48        _instruction: Rv64Instruction<Reg>,
49    ) {
50        // NOP by default
51    }
52}
53
/// Minimal system instruction handler: `ebreak` is ignored and `ecall` is reported as an
/// unsupported instruction.
///
/// The trait implementation in this module exists for
/// `BasicRv64SystemInstructionHandler<Rv64Instruction<Reg>>`, so the type parameter is expected
/// to be the instruction type.
#[derive(Debug, Clone, Copy)]
pub struct BasicRv64SystemInstructionHandler<Reg> {
    _phantom: PhantomData<Reg>,
}

impl<Reg> Default for BasicRv64SystemInstructionHandler<Reg> {
    /// Creates the handler; it holds no runtime state.
    #[inline(always)]
    fn default() -> Self {
        BasicRv64SystemInstructionHandler {
            _phantom: PhantomData,
        }
    }
}
68
69impl<Reg, Memory, PC, CustomError> Rv64SystemInstructionHandler<Reg, Memory, PC, CustomError>
70    for BasicRv64SystemInstructionHandler<Rv64Instruction<Reg>>
71where
72    Reg: Register<Type = u64>,
73    [(); Reg::N]:,
74    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
75{
76    #[inline(always)]
77    fn handle_ecall(
78        &mut self,
79        _regs: &mut Registers<Reg>,
80        _memory: &mut Memory,
81        program_counter: &mut PC,
82    ) -> Result<ControlFlow<()>, ExecutionError<Reg::Type, Rv64Instruction<Reg>, CustomError>> {
83        let instruction = Rv64Instruction::Ecall;
84        Err(ExecutionError::UnsupportedInstruction {
85            address: program_counter.get_pc() - Reg::Type::from(instruction.size()),
86            instruction,
87        })
88    }
89}
90
91/// RV64 interpreter state
92#[derive(Debug)]
93// 16-byte alignment seems faster than 64 (cache line) for some reason, reconsider in the future
94#[repr(align(16))]
95pub struct Rv64InterpreterState<Reg, Memory, IF, InstructionHandler, CustomError>
96where
97    Reg: Register<Type = u64>,
98    [(); Reg::N]:,
99{
100    /// Registers
101    pub regs: Registers<Reg>,
102    /// Memory
103    pub memory: Memory,
104    /// Instruction fetcher
105    pub instruction_fetcher: IF,
106    /// System instruction handler
107    pub system_instruction_handler: InstructionHandler,
108    #[doc(hidden)]
109    pub _phantom: PhantomData<CustomError>,
110}
111
112impl<Reg, Memory, IF, InstructionHandler, CustomError>
113    Rv64InterpreterState<Reg, Memory, IF, InstructionHandler, CustomError>
114where
115    Reg: Register<Type = u64>,
116    [(); Reg::N]:,
117    IF: ProgramCounter<Reg::Type, Memory, CustomError>,
118{
119    /// Set program counter
120    pub fn set_pc(
121        &mut self,
122        pc: Reg::Type,
123    ) -> Result<ControlFlow<()>, ProgramCounterError<Reg::Type, CustomError>> {
124        self.instruction_fetcher.set_pc(&mut self.memory, pc)
125    }
126}
127
// Execution semantics of the RV64 base instructions.
//
// NOTE(review): `#[instruction_execution]` is a project proc macro whose expansion is not visible
// here. The comment in the `Ecall` arm mentions "inheritance", which suggests these match arms are
// reused for other instruction types. Keep arms self-contained and prefer plain `//` comments,
// which never reach the macro's token stream.
#[instruction_execution]
impl<Reg, Memory, PC, InstructionHandler, CustomError>
    ExecutableInstruction<
        Rv64InterpreterState<Reg, Memory, PC, InstructionHandler, CustomError>,
        CustomError,
    > for Rv64Instruction<Reg>
where
    Reg: Register<Type = u64>,
    [(); Reg::N]:,
    Memory: VirtualMemory,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
    InstructionHandler: Rv64SystemInstructionHandler<Reg, Memory, PC, CustomError>,
{
    // Executes this instruction against `state`.
    //
    // When this runs, `state.instruction_fetcher.get_pc()` already points past the instruction.
    // Arms that need the instruction's own address therefore subtract `self.size()` first
    // (`old_pc`).
    //
    // Results:
    // - Arms that do not transfer control fall through to `Ok(ControlFlow::Continue(()))`.
    // - Jumps and taken branches return the result of the fetcher's `set_pc`.
    // - `ecall` returns the system instruction handler's result.
    // - Memory access errors are propagated with `?`.
    #[inline(always)]
    fn execute(
        self,
        state: &mut Rv64InterpreterState<Reg, Memory, PC, InstructionHandler, CustomError>,
    ) -> Result<ControlFlow<()>, ExecutionError<Reg::Type, Self, CustomError>> {
        match self {
            // Register-register operations; add/sub wrap on overflow
            Self::Add { rd, rs1, rs2 } => {
                let value = state.regs.read(rs1).wrapping_add(state.regs.read(rs2));
                state.regs.write(rd, value);
            }
            Self::Sub { rd, rs1, rs2 } => {
                let value = state.regs.read(rs1).wrapping_sub(state.regs.read(rs2));
                state.regs.write(rd, value);
            }
            Self::Sll { rd, rs1, rs2 } => {
                // Only the low 6 bits of rs2 are used as the shift amount
                let shamt = state.regs.read(rs2) & 0x3f;
                let value = state.regs.read(rs1) << shamt;
                state.regs.write(rd, value);
            }
            Self::Slt { rd, rs1, rs2 } => {
                // Signed comparison; rd receives 1 if rs1 < rs2, otherwise 0
                let value = state.regs.read(rs1).cast_signed() < state.regs.read(rs2).cast_signed();
                state.regs.write(rd, value as u64);
            }
            Self::Sltu { rd, rs1, rs2 } => {
                // Unsigned comparison; rd receives 1 or 0
                let value = state.regs.read(rs1) < state.regs.read(rs2);
                state.regs.write(rd, value as u64);
            }
            Self::Xor { rd, rs1, rs2 } => {
                let value = state.regs.read(rs1) ^ state.regs.read(rs2);
                state.regs.write(rd, value);
            }
            Self::Srl { rd, rs1, rs2 } => {
                // Logical (zero-filling) shift by the low 6 bits of rs2
                let shamt = state.regs.read(rs2) & 0x3f;
                let value = state.regs.read(rs1) >> shamt;
                state.regs.write(rd, value);
            }
            Self::Sra { rd, rs1, rs2 } => {
                // Arithmetic (sign-filling) shift: performed on the signed reinterpretation
                let shamt = state.regs.read(rs2) & 0x3f;
                let value = state.regs.read(rs1).cast_signed() >> shamt;
                state.regs.write(rd, value.cast_unsigned());
            }
            Self::Or { rd, rs1, rs2 } => {
                let value = state.regs.read(rs1) | state.regs.read(rs2);
                state.regs.write(rd, value);
            }
            Self::And { rd, rs1, rs2 } => {
                let value = state.regs.read(rs1) & state.regs.read(rs2);
                state.regs.write(rd, value);
            }

            // *W register-register operations: operate on the low 32 bits of the operands (the
            // `as i32` / `as u32` casts truncate) and sign-extend the 32-bit result to 64 bits
            Self::Addw { rd, rs1, rs2 } => {
                let sum = (state.regs.read(rs1) as i32).wrapping_add(state.regs.read(rs2) as i32);
                state.regs.write(rd, i64::from(sum).cast_unsigned());
            }
            Self::Subw { rd, rs1, rs2 } => {
                let diff = (state.regs.read(rs1) as i32).wrapping_sub(state.regs.read(rs2) as i32);
                state.regs.write(rd, i64::from(diff).cast_unsigned());
            }
            Self::Sllw { rd, rs1, rs2 } => {
                // 32-bit shifts use only the low 5 bits of rs2
                let shamt = state.regs.read(rs2) & 0x1f;
                let shifted = (state.regs.read(rs1) as u32) << shamt;
                state
                    .regs
                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
            }
            Self::Srlw { rd, rs1, rs2 } => {
                // Logical shift on 32 bits, but the 32-bit result is still sign-extended from
                // bit 31 (relevant when the shift amount is 0)
                let shamt = state.regs.read(rs2) & 0x1f;
                let shifted = (state.regs.read(rs1) as u32) >> shamt;
                state
                    .regs
                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
            }
            Self::Sraw { rd, rs1, rs2 } => {
                let shamt = state.regs.read(rs2) & 0x1f;
                let shifted = (state.regs.read(rs1) as i32) >> shamt;
                state.regs.write(rd, i64::from(shifted).cast_unsigned());
            }

            // Register-immediate operations. `i64::from(imm).cast_unsigned()` widens the
            // immediate to a 64-bit two's-complement value. NOTE(review): this sign-extends
            // presumably because `imm` is a signed type in `Rv64Instruction` — confirm there.
            Self::Addi { rd, rs1, imm } => {
                let value = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                state.regs.write(rd, value);
            }
            Self::Slti { rd, rs1, imm } => {
                let value = state.regs.read(rs1).cast_signed() < i64::from(imm);
                state.regs.write(rd, value as u64);
            }
            Self::Sltiu { rd, rs1, imm } => {
                // The immediate is widened first, then compared as an unsigned 64-bit value
                let value = state.regs.read(rs1) < i64::from(imm).cast_unsigned();
                state.regs.write(rd, value as u64);
            }
            Self::Xori { rd, rs1, imm } => {
                let value = state.regs.read(rs1) ^ i64::from(imm).cast_unsigned();
                state.regs.write(rd, value);
            }
            Self::Ori { rd, rs1, imm } => {
                let value = state.regs.read(rs1) | i64::from(imm).cast_unsigned();
                state.regs.write(rd, value);
            }
            Self::Andi { rd, rs1, imm } => {
                let value = state.regs.read(rs1) & i64::from(imm).cast_unsigned();
                state.regs.write(rd, value);
            }
            // NOTE(review): `shamt` comes straight from the decoded instruction and is not
            // masked here. Presumably the decoder keeps it < 64 (< 32 for the *W forms); an
            // out-of-range value would overflow the shift.
            Self::Slli { rd, rs1, shamt } => {
                let value = state.regs.read(rs1) << shamt;
                state.regs.write(rd, value);
            }
            Self::Srli { rd, rs1, shamt } => {
                let value = state.regs.read(rs1) >> shamt;
                state.regs.write(rd, value);
            }
            Self::Srai { rd, rs1, shamt } => {
                let value = state.regs.read(rs1).cast_signed() >> shamt;
                state.regs.write(rd, value.cast_unsigned());
            }

            // *W register-immediate operations: 32-bit result, sign-extended to 64 bits
            Self::Addiw { rd, rs1, imm } => {
                let sum = (state.regs.read(rs1) as i32).wrapping_add(i32::from(imm));
                state.regs.write(rd, i64::from(sum).cast_unsigned());
            }
            Self::Slliw { rd, rs1, shamt } => {
                let shifted = (state.regs.read(rs1) as u32) << shamt;
                state
                    .regs
                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
            }
            Self::Srliw { rd, rs1, shamt } => {
                let shifted = (state.regs.read(rs1) as u32) >> shamt;
                state
                    .regs
                    .write(rd, i64::from(shifted.cast_signed()).cast_unsigned());
            }
            Self::Sraiw { rd, rs1, shamt } => {
                let shifted = (state.regs.read(rs1) as i32) >> shamt;
                state.regs.write(rd, i64::from(shifted).cast_unsigned());
            }

            // Loads. The effective address is rs1 + widened immediate (wrapping).
            // - Lb/Lh/Lw read a signed value, so `i64::from` sign-extends it.
            // - Lbu/Lhu/Lwu read an unsigned value, so `as u64` zero-extends it.
            // Memory errors are propagated with `?`.
            Self::Lb { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = i64::from(state.memory.read::<i8>(addr)?);
                state.regs.write(rd, value.cast_unsigned());
            }
            Self::Lh { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = i64::from(state.memory.read::<i16>(addr)?);
                state.regs.write(rd, value.cast_unsigned());
            }
            Self::Lw { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = i64::from(state.memory.read::<i32>(addr)?);
                state.regs.write(rd, value.cast_unsigned());
            }
            Self::Ld { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = state.memory.read::<u64>(addr)?;
                state.regs.write(rd, value);
            }
            Self::Lbu { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = state.memory.read::<u8>(addr)?;
                state.regs.write(rd, value as u64);
            }
            Self::Lhu { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = state.memory.read::<u16>(addr)?;
                state.regs.write(rd, value as u64);
            }
            Self::Lwu { rd, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                let value = state.memory.read::<u32>(addr)?;
                state.regs.write(rd, value as u64);
            }

            Self::Jalr { rd, rs1, imm } => {
                // The target is computed (with its lowest bit cleared) before rd is written, so
                // rd == rs1 still jumps relative to the old rs1 value
                let target = (state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned()))
                    & !1u64;
                // The current program counter already points past this instruction, i.e. it is
                // the return address
                state.regs.write(rd, state.instruction_fetcher.get_pc());
                return state
                    .instruction_fetcher
                    .set_pc(&mut state.memory, target)
                    .map_err(ExecutionError::from);
            }

            // Stores. The effective address is rs1 + widened immediate (wrapping); the `as`
            // casts keep only the low 8/16/32 bits of rs2
            Self::Sb { rs2, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                state.memory.write(addr, state.regs.read(rs2) as u8)?;
            }
            Self::Sh { rs2, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                state.memory.write(addr, state.regs.read(rs2) as u16)?;
            }
            Self::Sw { rs2, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                state.memory.write(addr, state.regs.read(rs2) as u32)?;
            }
            Self::Sd { rs2, rs1, imm } => {
                let addr = state
                    .regs
                    .read(rs1)
                    .wrapping_add(i64::from(imm).cast_unsigned());
                state.memory.write(addr, state.regs.read(rs2))?;
            }

            // Conditional branches. When taken, the target is relative to the branch
            // instruction's own address (`old_pc`). When not taken, execution falls through to
            // the already-advanced program counter.
            Self::Beq { rs1, rs2, imm } => {
                if state.regs.read(rs1) == state.regs.read(rs2) {
                    let old_pc = state
                        .instruction_fetcher
                        .get_pc()
                        .wrapping_sub(self.size().into());
                    return state
                        .instruction_fetcher
                        .set_pc(
                            &mut state.memory,
                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                        )
                        .map_err(ExecutionError::from);
                }
            }
            Self::Bne { rs1, rs2, imm } => {
                if state.regs.read(rs1) != state.regs.read(rs2) {
                    let old_pc = state
                        .instruction_fetcher
                        .get_pc()
                        .wrapping_sub(self.size().into());
                    return state
                        .instruction_fetcher
                        .set_pc(
                            &mut state.memory,
                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                        )
                        .map_err(ExecutionError::from);
                }
            }
            Self::Blt { rs1, rs2, imm } => {
                if state.regs.read(rs1).cast_signed() < state.regs.read(rs2).cast_signed() {
                    let old_pc = state
                        .instruction_fetcher
                        .get_pc()
                        .wrapping_sub(self.size().into());
                    return state
                        .instruction_fetcher
                        .set_pc(
                            &mut state.memory,
                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                        )
                        .map_err(ExecutionError::from);
                }
            }
            Self::Bge { rs1, rs2, imm } => {
                if state.regs.read(rs1).cast_signed() >= state.regs.read(rs2).cast_signed() {
                    let old_pc = state
                        .instruction_fetcher
                        .get_pc()
                        .wrapping_sub(self.size().into());
                    return state
                        .instruction_fetcher
                        .set_pc(
                            &mut state.memory,
                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                        )
                        .map_err(ExecutionError::from);
                }
            }
            Self::Bltu { rs1, rs2, imm } => {
                if state.regs.read(rs1) < state.regs.read(rs2) {
                    let old_pc = state
                        .instruction_fetcher
                        .get_pc()
                        .wrapping_sub(self.size().into());
                    return state
                        .instruction_fetcher
                        .set_pc(
                            &mut state.memory,
                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                        )
                        .map_err(ExecutionError::from);
                }
            }
            Self::Bgeu { rs1, rs2, imm } => {
                if state.regs.read(rs1) >= state.regs.read(rs2) {
                    let old_pc = state
                        .instruction_fetcher
                        .get_pc()
                        .wrapping_sub(self.size().into());
                    return state
                        .instruction_fetcher
                        .set_pc(
                            &mut state.memory,
                            old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                        )
                        .map_err(ExecutionError::from);
                }
            }

            // NOTE(review): `imm` is written as-is, so it presumably already holds the upper
            // immediate shifted left by 12 — confirm in the decoder
            Self::Lui { rd, imm } => {
                state.regs.write(rd, i64::from(imm).cast_unsigned());
            }

            // rd receives this instruction's own address plus the widened immediate (same
            // NOTE(review) on `imm` as for `Lui`)
            Self::Auipc { rd, imm } => {
                let old_pc = state
                    .instruction_fetcher
                    .get_pc()
                    .wrapping_sub(self.size().into());
                state
                    .regs
                    .write(rd, old_pc.wrapping_add(i64::from(imm).cast_unsigned()));
            }

            // rd receives the already-advanced program counter (the return address). The jump
            // target is relative to this instruction's own address.
            Self::Jal { rd, imm } => {
                let pc = state.instruction_fetcher.get_pc();
                let old_pc = pc.wrapping_sub(self.size().into());
                state.regs.write(rd, pc);
                return state
                    .instruction_fetcher
                    .set_pc(
                        &mut state.memory,
                        old_pc.wrapping_add(i64::from(imm).cast_unsigned()),
                    )
                    .map_err(ExecutionError::from);
            }

            Self::Fence { .. } => {
                // NOP for single-threaded
            }

            // Fully delegated to the system instruction handler. Its result (including
            // `ControlFlow::Break`) is returned as-is, except that any error's instruction is
            // remapped to `Self`.
            Self::Ecall => {
                return state
                    .system_instruction_handler
                    .handle_ecall(
                        &mut state.regs,
                        &mut state.memory,
                        &mut state.instruction_fetcher,
                    )
                    .map_err(|error| {
                        error.map_instruction(|_instruction| {
                            // This mapping helps with instruction type during inheritance
                            Self::Ecall
                        })
                    });
            }
            // The handler is notified with the current (already-advanced) program counter, and
            // execution then continues. The instruction is spelled `Rv64Instruction::<Reg>`
            // rather than `Self` because that is the handler's parameter type.
            Self::Ebreak => {
                state.system_instruction_handler.handle_ebreak(
                    &mut state.regs,
                    &mut state.memory,
                    state.instruction_fetcher.get_pc(),
                    Rv64Instruction::<Reg>::Ebreak,
                );
            }

            // Always an error, reported at the address of the `unimp` instruction itself
            Self::Unimp => {
                let old_pc = state
                    .instruction_fetcher
                    .get_pc()
                    .wrapping_sub(self.size().into());
                return Err(ExecutionError::UnimpInstruction { address: old_pc });
            }
        }

        Ok(ControlFlow::Continue(()))
    }
}