ab_riscv_interpreter/
rv64.rs

//! Part of the interpreter responsible for the RISC-V RV64 base instruction set
2
3#[cfg(test)]
4mod tests;
5
6use crate::{ExecutionError, ProgramCounter, VirtualMemory};
7use ab_riscv_primitives::instruction::Instruction;
8use ab_riscv_primitives::instruction::rv64::Rv64Instruction;
9use ab_riscv_primitives::registers::{Register, Registers};
10use core::fmt;
11use core::marker::PhantomData;
12use core::ops::ControlFlow;
13
14/// Custom handler for system instructions `ecall` and `ebreak`
15pub trait Rv64SystemInstructionHandler<Reg, Memory, PC, CustomError>
16where
17    Reg: Register<Type = u64>,
18    [(); Reg::N]:,
19    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
20    CustomError: fmt::Display,
21{
22    /// Handle an `ecall` instruction.
23    ///
24    /// NOTE: the program counter here is the current value, meaning it is already incremented past
25    /// the instruction itself.
26    fn handle_ecall(
27        &mut self,
28        regs: &mut Registers<Reg>,
29        memory: &mut Memory,
30        program_counter: &mut PC,
31    ) -> Result<ControlFlow<()>, ExecutionError<Rv64Instruction<Reg>, CustomError>>;
32
33    /// Handle an `ebreak` instruction.
34    ///
35    /// NOTE: the program counter here is the current value, meaning it is already incremented past
36    /// the instruction itself.
37    #[inline(always)]
38    fn handle_ebreak(
39        &mut self,
40        _regs: &mut Registers<Reg>,
41        _memory: &mut Memory,
42        _pc: Reg::Type,
43        _instruction: Rv64Instruction<Reg>,
44    ) {
45        // NOP by default
46    }
47}
48
/// Basic system instruction handler that does nothing on `ebreak` and returns an error on `ecall`.
///
/// The handler holds no state. Note that the system instruction handler trait is implemented for
/// `BasicRv64SystemInstructionHandler<Rv64Instruction<Reg>>`, i.e. the type parameter is meant to
/// be the instruction type rather than the register type.
pub struct BasicRv64SystemInstructionHandler<Reg> {
    _phantom: PhantomData<Reg>,
}

// `Debug`, `Clone` and `Copy` are implemented by hand (like `Default` below) because `#[derive]`
// would add `Reg: Debug` / `Reg: Clone` / `Reg: Copy` bounds, which a type that only stores
// `PhantomData<Reg>` does not need.
impl<Reg> fmt::Debug for BasicRv64SystemInstructionHandler<Reg> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Produces the same output as the previously derived implementation
        f.debug_struct("BasicRv64SystemInstructionHandler")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<Reg> Clone for BasicRv64SystemInstructionHandler<Reg> {
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}

impl<Reg> Copy for BasicRv64SystemInstructionHandler<Reg> {}

impl<Reg> Default for BasicRv64SystemInstructionHandler<Reg> {
    #[inline(always)]
    fn default() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}
63
64impl<Reg, Memory, PC, CustomError> Rv64SystemInstructionHandler<Reg, Memory, PC, CustomError>
65    for BasicRv64SystemInstructionHandler<Rv64Instruction<Reg>>
66where
67    Reg: Register<Type = u64>,
68    [(); Reg::N]:,
69    Memory: VirtualMemory,
70    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
71    CustomError: fmt::Display,
72{
73    #[inline(always)]
74    fn handle_ecall(
75        &mut self,
76        _regs: &mut Registers<Reg>,
77        _memory: &mut Memory,
78        program_counter: &mut PC,
79    ) -> Result<ControlFlow<()>, ExecutionError<Rv64Instruction<Reg>, CustomError>> {
80        let instruction = Rv64Instruction::Ecall;
81        Err(ExecutionError::UnsupportedInstruction {
82            address: program_counter.get_pc() - Reg::Type::from(instruction.size()),
83            instruction,
84        })
85    }
86}
87
88/// Execute instructions from a base RV64I/RV64E instruction set
89#[inline(always)]
90pub fn execute_rv64<Reg, Memory, PC, InstructionHandler, CustomError>(
91    regs: &mut Registers<Reg>,
92    memory: &mut Memory,
93    program_counter: &mut PC,
94    system_instruction_handlers: &mut InstructionHandler,
95    instruction: Rv64Instruction<Reg>,
96) -> Result<ControlFlow<()>, ExecutionError<Rv64Instruction<Reg>, CustomError>>
97where
98    Reg: Register<Type = u64>,
99    [(); Reg::N]:,
100    Memory: VirtualMemory,
101    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
102    InstructionHandler: Rv64SystemInstructionHandler<Reg, Memory, PC, CustomError>,
103    CustomError: fmt::Display,
104{
105    match instruction {
106        Rv64Instruction::Add { rd, rs1, rs2 } => {
107            let value = regs.read(rs1).wrapping_add(regs.read(rs2));
108            regs.write(rd, value);
109        }
110        Rv64Instruction::Sub { rd, rs1, rs2 } => {
111            let value = regs.read(rs1).wrapping_sub(regs.read(rs2));
112            regs.write(rd, value);
113        }
114        Rv64Instruction::Sll { rd, rs1, rs2 } => {
115            let shamt = regs.read(rs2) & 0x3f;
116            let value = regs.read(rs1) << shamt;
117            regs.write(rd, value);
118        }
119        Rv64Instruction::Slt { rd, rs1, rs2 } => {
120            let value = regs.read(rs1).cast_signed() < regs.read(rs2).cast_signed();
121            regs.write(rd, value as u64);
122        }
123        Rv64Instruction::Sltu { rd, rs1, rs2 } => {
124            let value = regs.read(rs1) < regs.read(rs2);
125            regs.write(rd, value as u64);
126        }
127        Rv64Instruction::Xor { rd, rs1, rs2 } => {
128            let value = regs.read(rs1) ^ regs.read(rs2);
129            regs.write(rd, value);
130        }
131        Rv64Instruction::Srl { rd, rs1, rs2 } => {
132            let shamt = regs.read(rs2) & 0x3f;
133            let value = regs.read(rs1) >> shamt;
134            regs.write(rd, value);
135        }
136        Rv64Instruction::Sra { rd, rs1, rs2 } => {
137            let shamt = regs.read(rs2) & 0x3f;
138            let value = regs.read(rs1).cast_signed() >> shamt;
139            regs.write(rd, value.cast_unsigned());
140        }
141        Rv64Instruction::Or { rd, rs1, rs2 } => {
142            let value = regs.read(rs1) | regs.read(rs2);
143            regs.write(rd, value);
144        }
145        Rv64Instruction::And { rd, rs1, rs2 } => {
146            let value = regs.read(rs1) & regs.read(rs2);
147            regs.write(rd, value);
148        }
149
150        Rv64Instruction::Addw { rd, rs1, rs2 } => {
151            let sum = (regs.read(rs1) as i32).wrapping_add(regs.read(rs2) as i32);
152            regs.write(rd, (sum as i64).cast_unsigned());
153        }
154        Rv64Instruction::Subw { rd, rs1, rs2 } => {
155            let diff = (regs.read(rs1) as i32).wrapping_sub(regs.read(rs2) as i32);
156            regs.write(rd, (diff as i64).cast_unsigned());
157        }
158        Rv64Instruction::Sllw { rd, rs1, rs2 } => {
159            let shamt = regs.read(rs2) & 0x1f;
160            let shifted = (regs.read(rs1) as u32) << shamt;
161            regs.write(rd, (shifted.cast_signed() as i64).cast_unsigned());
162        }
163        Rv64Instruction::Srlw { rd, rs1, rs2 } => {
164            let shamt = regs.read(rs2) & 0x1f;
165            let shifted = (regs.read(rs1) as u32) >> shamt;
166            regs.write(rd, (shifted.cast_signed() as i64).cast_unsigned());
167        }
168        Rv64Instruction::Sraw { rd, rs1, rs2 } => {
169            let shamt = regs.read(rs2) & 0x1f;
170            let shifted = (regs.read(rs1) as i32) >> shamt;
171            regs.write(rd, (shifted as i64).cast_unsigned());
172        }
173
174        Rv64Instruction::Addi { rd, rs1, imm } => {
175            let value = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
176            regs.write(rd, value);
177        }
178        Rv64Instruction::Slti { rd, rs1, imm } => {
179            let value = regs.read(rs1).cast_signed() < (imm as i64);
180            regs.write(rd, value as u64);
181        }
182        Rv64Instruction::Sltiu { rd, rs1, imm } => {
183            let value = regs.read(rs1) < ((imm as i64).cast_unsigned());
184            regs.write(rd, value as u64);
185        }
186        Rv64Instruction::Xori { rd, rs1, imm } => {
187            let value = regs.read(rs1) ^ ((imm as i64).cast_unsigned());
188            regs.write(rd, value);
189        }
190        Rv64Instruction::Ori { rd, rs1, imm } => {
191            let value = regs.read(rs1) | ((imm as i64).cast_unsigned());
192            regs.write(rd, value);
193        }
194        Rv64Instruction::Andi { rd, rs1, imm } => {
195            let value = regs.read(rs1) & ((imm as i64).cast_unsigned());
196            regs.write(rd, value);
197        }
198        Rv64Instruction::Slli { rd, rs1, shamt } => {
199            let value = regs.read(rs1) << shamt;
200            regs.write(rd, value);
201        }
202        Rv64Instruction::Srli { rd, rs1, shamt } => {
203            let value = regs.read(rs1) >> shamt;
204            regs.write(rd, value);
205        }
206        Rv64Instruction::Srai { rd, rs1, shamt } => {
207            let value = regs.read(rs1).cast_signed() >> shamt;
208            regs.write(rd, value.cast_unsigned());
209        }
210
211        Rv64Instruction::Addiw { rd, rs1, imm } => {
212            let sum = (regs.read(rs1) as i32).wrapping_add(imm);
213            regs.write(rd, (sum as i64).cast_unsigned());
214        }
215        Rv64Instruction::Slliw { rd, rs1, shamt } => {
216            let shifted = (regs.read(rs1) as u32) << shamt;
217            regs.write(rd, (shifted.cast_signed() as i64).cast_unsigned());
218        }
219        Rv64Instruction::Srliw { rd, rs1, shamt } => {
220            let shifted = (regs.read(rs1) as u32) >> shamt;
221            regs.write(rd, (shifted.cast_signed() as i64).cast_unsigned());
222        }
223        Rv64Instruction::Sraiw { rd, rs1, shamt } => {
224            let shifted = (regs.read(rs1) as i32) >> shamt;
225            regs.write(rd, (shifted as i64).cast_unsigned());
226        }
227
228        Rv64Instruction::Lb { rd, rs1, imm } => {
229            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
230            let value = memory.read::<i8>(addr)? as i64;
231            regs.write(rd, value.cast_unsigned());
232        }
233        Rv64Instruction::Lh { rd, rs1, imm } => {
234            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
235            let value = memory.read::<i16>(addr)? as i64;
236            regs.write(rd, value.cast_unsigned());
237        }
238        Rv64Instruction::Lw { rd, rs1, imm } => {
239            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
240            let value = memory.read::<i32>(addr)? as i64;
241            regs.write(rd, value.cast_unsigned());
242        }
243        Rv64Instruction::Ld { rd, rs1, imm } => {
244            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
245            let value = memory.read::<u64>(addr)?;
246            regs.write(rd, value);
247        }
248        Rv64Instruction::Lbu { rd, rs1, imm } => {
249            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
250            let value = memory.read::<u8>(addr)?;
251            regs.write(rd, value as u64);
252        }
253        Rv64Instruction::Lhu { rd, rs1, imm } => {
254            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
255            let value = memory.read::<u16>(addr)?;
256            regs.write(rd, value as u64);
257        }
258        Rv64Instruction::Lwu { rd, rs1, imm } => {
259            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
260            let value = memory.read::<u32>(addr)?;
261            regs.write(rd, value as u64);
262        }
263
264        Rv64Instruction::Jalr { rd, rs1, imm } => {
265            let target = (regs.read(rs1).wrapping_add((imm as i64).cast_unsigned())) & !1u64;
266            regs.write(rd, program_counter.get_pc());
267            return program_counter
268                .set_pc(memory, target)
269                .map_err(ExecutionError::from);
270        }
271
272        Rv64Instruction::Sb { rs2, rs1, imm } => {
273            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
274            memory.write(addr, regs.read(rs2) as u8)?;
275        }
276        Rv64Instruction::Sh { rs2, rs1, imm } => {
277            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
278            memory.write(addr, regs.read(rs2) as u16)?;
279        }
280        Rv64Instruction::Sw { rs2, rs1, imm } => {
281            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
282            memory.write(addr, regs.read(rs2) as u32)?;
283        }
284        Rv64Instruction::Sd { rs2, rs1, imm } => {
285            let addr = regs.read(rs1).wrapping_add((imm as i64).cast_unsigned());
286            memory.write(addr, regs.read(rs2))?;
287        }
288
289        Rv64Instruction::Beq { rs1, rs2, imm } => {
290            if regs.read(rs1) == regs.read(rs2) {
291                let old_pc = program_counter
292                    .get_pc()
293                    .wrapping_sub(instruction.size().into());
294                return program_counter
295                    .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
296                    .map_err(ExecutionError::from);
297            }
298        }
299        Rv64Instruction::Bne { rs1, rs2, imm } => {
300            if regs.read(rs1) != regs.read(rs2) {
301                let old_pc = program_counter
302                    .get_pc()
303                    .wrapping_sub(instruction.size().into());
304                return program_counter
305                    .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
306                    .map_err(ExecutionError::from);
307            }
308        }
309        Rv64Instruction::Blt { rs1, rs2, imm } => {
310            if regs.read(rs1).cast_signed() < regs.read(rs2).cast_signed() {
311                let old_pc = program_counter
312                    .get_pc()
313                    .wrapping_sub(instruction.size().into());
314                return program_counter
315                    .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
316                    .map_err(ExecutionError::from);
317            }
318        }
319        Rv64Instruction::Bge { rs1, rs2, imm } => {
320            if regs.read(rs1).cast_signed() >= regs.read(rs2).cast_signed() {
321                let old_pc = program_counter
322                    .get_pc()
323                    .wrapping_sub(instruction.size().into());
324                return program_counter
325                    .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
326                    .map_err(ExecutionError::from);
327            }
328        }
329        Rv64Instruction::Bltu { rs1, rs2, imm } => {
330            if regs.read(rs1) < regs.read(rs2) {
331                let old_pc = program_counter
332                    .get_pc()
333                    .wrapping_sub(instruction.size().into());
334                return program_counter
335                    .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
336                    .map_err(ExecutionError::from);
337            }
338        }
339        Rv64Instruction::Bgeu { rs1, rs2, imm } => {
340            if regs.read(rs1) >= regs.read(rs2) {
341                let old_pc = program_counter
342                    .get_pc()
343                    .wrapping_sub(instruction.size().into());
344                return program_counter
345                    .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
346                    .map_err(ExecutionError::from);
347            }
348        }
349
350        Rv64Instruction::Lui { rd, imm } => {
351            regs.write(rd, (imm as i64).cast_unsigned());
352        }
353
354        Rv64Instruction::Auipc { rd, imm } => {
355            let old_pc = program_counter
356                .get_pc()
357                .wrapping_sub(instruction.size().into());
358            regs.write(rd, old_pc.wrapping_add((imm as i64).cast_unsigned()));
359        }
360
361        Rv64Instruction::Jal { rd, imm } => {
362            let pc = program_counter.get_pc();
363            let old_pc = pc.wrapping_sub(instruction.size().into());
364            regs.write(rd, pc);
365            return program_counter
366                .set_pc(memory, old_pc.wrapping_add((imm as i64).cast_unsigned()))
367                .map_err(ExecutionError::from);
368        }
369
370        Rv64Instruction::Fence { .. } => {
371            // NOP for single-threaded
372        }
373
374        Rv64Instruction::Ecall => {
375            return system_instruction_handlers.handle_ecall(regs, memory, program_counter);
376        }
377        Rv64Instruction::Ebreak => {
378            system_instruction_handlers.handle_ebreak(
379                regs,
380                memory,
381                program_counter.get_pc(),
382                Rv64Instruction::Ebreak,
383            );
384        }
385
386        Rv64Instruction::Unimp => {
387            let old_pc = program_counter
388                .get_pc()
389                .wrapping_sub(instruction.size().into());
390            return Err(ExecutionError::UnimpInstruction { address: old_pc });
391        }
392
393        Rv64Instruction::Invalid(raw_instruction) => {
394            let old_pc = program_counter
395                .get_pc()
396                .wrapping_sub(instruction.size().into());
397            return Err(ExecutionError::InvalidInstruction {
398                address: old_pc,
399                instruction: raw_instruction,
400            });
401        }
402    }
403
404    Ok(ControlFlow::Continue(()))
405}