Skip to main content

ab_riscv_interpreter/
basic.rs

1//! Basic implementations of various interpreter traits
2
3#[cfg(test)]
4mod tests;
5
6use crate::{
7    Address, CustomErrorPlaceholder, ExecutionError, FetchInstructionResult, InstructionFetcher,
8    ProgramCounter, ProgramCounterError, RegisterFile, VirtualMemory,
9};
10use ab_riscv_primitives::prelude::*;
11use core::marker::PhantomData;
12use core::ops::ControlFlow;
13
14/// Basic general purpose register to be used with [`BasicRegisters`]
15///
16/// # Safety
17/// `Self::offset()` must return values in `0..Self::N` range. `Self::from_bits()` must return
18/// `Some()` for `0..=31` if `Self::RVE = false` and `0..=15` if `Self::RVE = true`.
19pub const unsafe trait BasicRegister
20where
21    Self: [const] Register,
22{
23    /// The number of general purpose registers.
24    ///
25    /// Canonically 32 unless E extension is used, in which case 16.
26    const N: usize;
27
28    /// Offset in a set of registers
29    fn offset(self) -> u8;
30}
31
32// SAFETY: `Self::offset()` returns values within `0..Self::N` range
33unsafe impl<Type> const BasicRegister for EReg<Type>
34where
35    Self: [const] Register,
36{
37    const N: usize = 16;
38
39    #[inline(always)]
40    fn offset(self) -> u8 {
41        // SAFETY: Enum is `#[repr(u8)]` and doesn't have any fields
42        unsafe { core::mem::transmute::<Self, u8>(self) }
43    }
44}
45
46// SAFETY: `Self::offset()` returns values within `0..Self::N` range
47unsafe impl<Type> const BasicRegister for Reg<Type>
48where
49    Self: [const] Register,
50{
51    const N: usize = 32;
52
53    #[inline(always)]
54    fn offset(self) -> u8 {
55        // SAFETY: Enum is `#[repr(u8)]` and doesn't have any fields
56        unsafe { core::mem::transmute::<Self, u8>(self) }
57    }
58}
59
60/// A basic set of RISC-V GPRs (General Purpose Registers)
61#[derive(Debug, Clone, Copy)]
62#[repr(align(16))]
63pub struct BasicRegisters<Reg>
64where
65    Reg: BasicRegister,
66    [(); Reg::N]:,
67{
68    regs: [Reg::Type; Reg::N],
69}
70
71impl<Reg> Default for BasicRegisters<Reg>
72where
73    Reg: BasicRegister,
74    [(); Reg::N]:,
75{
76    #[inline(always)]
77    fn default() -> Self {
78        Self {
79            regs: [Reg::Type::default(); Reg::N],
80        }
81    }
82}
83
84impl<Reg> const RegisterFile<Reg> for BasicRegisters<Reg>
85where
86    Reg: [const] BasicRegister,
87    [(); Reg::N]:,
88{
89    #[inline(always)]
90    fn read(&self, reg: Reg) -> Reg::Type {
91        if reg == Reg::ZERO {
92            // Always zero
93            return Reg::Type::default();
94        }
95
96        // SAFETY: register offset is always within bounds
97        *unsafe { self.regs.get_unchecked(usize::from(reg.offset())) }
98    }
99
100    #[inline(always)]
101    fn write(&mut self, reg: Reg, value: Reg::Type) {
102        if reg == Reg::ZERO {
103            // Writes are ignored
104            return;
105        }
106
107        // SAFETY: register offset is always within bounds
108        *unsafe { self.regs.get_unchecked_mut(usize::from(reg.offset())) } = value;
109    }
110}
111
/// Basic interpreter state, bundling everything needed to execute instructions
#[derive(Debug)]
pub struct BasicInterpreterState<Regs, ExtState, Memory, IF, InstructionHandler> {
    /// General purpose register file
    pub regs: Regs,
    /// Extended state.
    ///
    /// Extensions may place additional bounds on `ExtState` when they need extra registers or
    /// other resources. When no such extension is in use, `()` works as a placeholder.
    pub ext_state: ExtState,
    /// Memory accessed by the program
    pub memory: Memory,
    /// Source of instructions to execute
    pub instruction_fetcher: IF,
    /// Handler for system instructions
    pub system_instruction_handler: InstructionHandler,
}
130
131/// Basic instruction fetcher implementation.
132///
133/// Note that it loads instructions from anywhere in memory. This works, but it is likely that you
134/// want to restrict this to a specific executable region of memory.
135#[derive(Debug, Copy, Clone)]
136pub struct BasicInstructionFetcher<I, CustomError = CustomErrorPlaceholder>
137where
138    I: Instruction,
139{
140    return_trap_address: Address<I>,
141    pc: Address<I>,
142    _phantom: PhantomData<CustomError>,
143}
144
145impl<I, Memory, CustomError> ProgramCounter<Address<I>, Memory, CustomError>
146    for BasicInstructionFetcher<I, CustomError>
147where
148    I: Instruction,
149    Memory: VirtualMemory,
150{
151    #[inline(always)]
152    fn get_pc(&self) -> Address<I> {
153        self.pc
154    }
155
156    #[inline]
157    fn set_pc(
158        &mut self,
159        _memory: &Memory,
160        pc: Address<I>,
161    ) -> Result<ControlFlow<()>, ProgramCounterError<Address<I>, CustomError>> {
162        if pc == self.return_trap_address {
163            return Ok(ControlFlow::Break(()));
164        }
165
166        if !pc.as_u64().is_multiple_of(u64::from(I::alignment())) {
167            return Err(ProgramCounterError::UnalignedInstruction { address: pc });
168        }
169
170        self.pc = pc;
171
172        Ok(ControlFlow::Continue(()))
173    }
174}
175
176impl<I, Memory, CustomError> InstructionFetcher<I, Memory, CustomError>
177    for BasicInstructionFetcher<I, CustomError>
178where
179    I: Instruction,
180    Memory: VirtualMemory,
181{
182    #[inline]
183    fn fetch_instruction(
184        &mut self,
185        memory: &Memory,
186    ) -> Result<FetchInstructionResult<I>, ExecutionError<Address<I>, CustomError>> {
187        let instruction = memory.read(self.pc.as_u64()).or_else(|error| {
188            // Attempt to read a 16-bit compressed instruction
189            if let Ok(instruction) = memory.read::<u16>(self.pc.as_u64())
190                && (instruction & 0b11) != 0b11
191            {
192                return Ok(u32::from(instruction));
193            }
194            Err(error)
195        })?;
196
197        let instruction = I::try_decode(instruction)
198            .ok_or(ExecutionError::IllegalInstruction { address: self.pc })?;
199        self.pc += instruction.size().into();
200
201        Ok(FetchInstructionResult::Instruction(instruction))
202    }
203}
204
205impl<I, CustomError> BasicInstructionFetcher<I, CustomError>
206where
207    I: Instruction,
208{
209    /// Create a new instance.
210    ///
211    /// `return_trap_address` is the address at which the interpreter will stop execution
212    /// (gracefully).
213    #[inline(always)]
214    pub fn new(return_trap_address: Address<I>, pc: Address<I>) -> Self {
215        Self {
216            return_trap_address,
217            pc,
218            _phantom: PhantomData,
219        }
220    }
221}