Skip to main content

ab_riscv_interpreter/
basic.rs

1//! Basic implementations of various interpreter traits
2
3#[cfg(test)]
4mod tests;
5
6use crate::{
7    Address, BasicInt, CustomErrorPlaceholder, ExecutableInstruction, ExecutionError,
8    FetchInstructionResult, InstructionFetcher, ProgramCounter, ProgramCounterError, RegisterFile,
9    Rs1Rs2OperandValues, Rs1Rs2Operands, SystemInstructionHandler, VirtualMemory,
10    VirtualMemoryError,
11};
12use ab_riscv_primitives::prelude::*;
13use core::hint::cold_path;
14use core::marker::PhantomData;
15use core::ops::ControlFlow;
16use replace_with::replace_with_or_abort_and_return;
17
18/// Basic general purpose register to be used with [`BasicRegisters`]
19///
20/// # Safety
21/// `Self::offset()` must return values in `0..Self::N` range. `Self::from_bits()` must return
22/// `Some()` for `0..=31` if `Self::RVE = false` and `0..=15` if `Self::RVE = true`.
23pub const unsafe trait BasicRegister
24where
25    Self: [const] Register,
26{
27    /// The number of general purpose registers.
28    ///
29    /// Canonically 32 unless E extension is used, in which case 16.
30    const N: usize;
31
32    /// Offset in a set of registers
33    fn offset(self) -> u8;
34}
35
36// SAFETY: `Self::offset()` returns values within `0..Self::N` range
37unsafe impl<Type> const BasicRegister for EReg<Type>
38where
39    Self: [const] Register,
40{
41    const N: usize = 16;
42
43    #[inline(always)]
44    fn offset(self) -> u8 {
45        // SAFETY: Enum is `#[repr(u8)]` and doesn't have any fields
46        unsafe { core::mem::transmute::<Self, u8>(self) }
47    }
48}
49
50// SAFETY: `Self::offset()` returns values within `0..Self::N` range
51unsafe impl<Type> const BasicRegister for Reg<Type>
52where
53    Self: [const] Register,
54{
55    const N: usize = 32;
56
57    #[inline(always)]
58    fn offset(self) -> u8 {
59        // SAFETY: Enum is `#[repr(u8)]` and doesn't have any fields
60        unsafe { core::mem::transmute::<Self, u8>(self) }
61    }
62}
63
64/// A basic set of RISC-V GPRs (General Purpose Registers)
65#[derive(Debug, Clone, Copy)]
66#[repr(align(16))]
67pub struct BasicRegisters<Reg>
68where
69    Reg: BasicRegister,
70    [(); Reg::N]:,
71{
72    regs: [Reg::Type; Reg::N],
73}
74
75impl<Reg> Default for BasicRegisters<Reg>
76where
77    Reg: BasicRegister,
78    [(); Reg::N]:,
79{
80    #[inline(always)]
81    fn default() -> Self {
82        Self {
83            regs: [Reg::Type::default(); Reg::N],
84        }
85    }
86}
87
88impl<Reg> const RegisterFile<Reg> for BasicRegisters<Reg>
89where
90    Reg: [const] BasicRegister,
91    [(); Reg::N]:,
92{
93    #[inline(always)]
94    fn read(&self, reg: Reg) -> Reg::Type {
95        if reg == Reg::ZERO {
96            // Always zero
97            return Reg::Type::default();
98        }
99
100        // SAFETY: register offset is always within bounds
101        *unsafe { self.regs.get_unchecked(usize::from(reg.offset())) }
102    }
103
104    #[inline(always)]
105    fn write(&mut self, reg: Reg, value: Reg::Type) {
106        // SAFETY: register offset is always within bounds
107        *unsafe { self.regs.get_unchecked_mut(usize::from(reg.offset())) } = value;
108    }
109}
110
/// Basic interpreter state.
///
/// A simple container that bundles everything the interpreter needs in one place. Using it is
/// optional, but it is convenient for keeping the whole interpreter-related state together.
#[derive(Debug)]
pub struct BasicInterpreterState<Regs, ExtState, Memory, IF, InstructionHandler> {
    /// General purpose registers
    pub regs: Regs,
    /// Extended state.
    ///
    /// Extensions may place additional constraints on `ExtState` in order to require extra
    /// registers or other resources. When no such extension is used, `()` works as a placeholder.
    pub ext_state: ExtState,
    /// Memory
    pub memory: Memory,
    /// Instruction fetcher
    pub instruction_fetcher: IF,
    /// System instruction handler
    pub system_instruction_handler: InstructionHandler,
}
132
133impl<Regs, ExtState, Memory, IF, InstructionHandler>
134    BasicInterpreterState<Regs, ExtState, Memory, IF, InstructionHandler>
135{
136    /// Execute the program with a given basic interpreter state.
137    ///
138    /// The implementation is designed to be efficient with little left to optimize further. Though
139    /// it is still possible to improve performance by applying additional constraints on the
140    /// program.
141    pub fn execute<I>(&mut self) -> Result<(), ExecutionError<Address<I>>>
142    where
143        Regs: RegisterFile<<I as Instruction>::Reg>,
144        I: ExecutableInstruction<Regs, ExtState, Memory, IF, InstructionHandler>,
145        Memory: VirtualMemory,
146        IF: InstructionFetcher<I, Memory> + ProgramCounter<Address<I>, Memory>,
147    {
148        replace_with_or_abort_and_return(
149            &mut self.instruction_fetcher,
150            #[inline(always)]
151            |mut instruction_fetcher| {
152                loop {
153                    let instruction = match instruction_fetcher.fetch_instruction(&self.memory) {
154                        Ok(FetchInstructionResult::Instruction(instruction)) => instruction,
155                        Ok(FetchInstructionResult::ControlFlow(ControlFlow::Continue(()))) => {
156                            cold_path();
157                            continue;
158                        }
159                        Ok(FetchInstructionResult::ControlFlow(ControlFlow::Break(()))) => {
160                            cold_path();
161                            break;
162                        }
163                        Err(error) => {
164                            cold_path();
165                            return (Err(error), instruction_fetcher);
166                        }
167                    };
168
169                    let Rs1Rs2Operands { rs1, rs2 } = instruction.get_rs1_rs2_operands();
170                    let rs1rs2_values = Rs1Rs2OperandValues {
171                        rs1_value: self.regs.read(rs1),
172                        rs2_value: self.regs.read(rs2),
173                    };
174
175                    match instruction.execute(
176                        rs1rs2_values,
177                        &mut self.regs,
178                        &mut self.ext_state,
179                        &mut self.memory,
180                        &mut instruction_fetcher,
181                        &mut self.system_instruction_handler,
182                    ) {
183                        Ok(ControlFlow::Continue((rd, rd_value))) => {
184                            self.regs.write(rd, rd_value);
185                        }
186                        Ok(ControlFlow::Break(())) => {
187                            cold_path();
188                            break;
189                        }
190                        Err(error) => {
191                            cold_path();
192                            return (Err(error), instruction_fetcher);
193                        }
194                    }
195                }
196
197                (Ok(()), instruction_fetcher)
198            },
199        )
200    }
201}
202
/// Basic memory implementation.
///
/// A flat byte array of `SIZE` bytes whose first byte corresponds to address `BASE_ADDR`. There
/// are no rwx protections and no alignment requirements. The contents are stored inline, so a
/// larger allocation needs to be boxed (zero-initialized is fine) or a custom implementation
/// should be used instead.
///
/// This implementation is intentionally basic and correct, but not the most performant. A more
/// efficient implementation can skip certain checks by placing additional constraints on the
/// program.
///
/// This works for simpler cases. A more sophisticated implementation might prevent certain memory
/// from being writable, support actual virtual memory with dynamically allocated memory pages,
/// etc.
#[repr(align(16))]
#[derive(Clone, Copy, Debug)]
pub struct BasicMemory<const BASE_ADDR: u64, const SIZE: usize> {
    /// Memory contents, index `0` corresponds to address `BASE_ADDR`
    data: [u8; SIZE],
}
221
222impl<const BASE_ADDR: u64, const SIZE: usize> VirtualMemory for BasicMemory<BASE_ADDR, SIZE> {
223    #[inline(always)]
224    fn read<T>(&self, address: u64) -> Result<T, VirtualMemoryError>
225    where
226        T: BasicInt,
227    {
228        let Some(offset) = address.checked_sub(BASE_ADDR) else {
229            cold_path();
230            return Err(VirtualMemoryError::OutOfBoundsRead { address });
231        };
232
233        if offset.saturating_add(size_of::<T>() as u64) > self.data.len() as u64 {
234            cold_path();
235            return Err(VirtualMemoryError::OutOfBoundsRead { address });
236        }
237
238        // SAFETY: Only reading basic integers from initialized memory
239        unsafe {
240            Ok(self
241                .data
242                .as_ptr()
243                .cast::<T>()
244                .byte_add(offset as usize)
245                .read_unaligned())
246        }
247    }
248
249    #[inline(always)]
250    unsafe fn read_unchecked<T>(&self, address: u64) -> T
251    where
252        T: BasicInt,
253    {
254        // SAFETY: Guaranteed by function contract
255        unsafe {
256            let offset = address.unchecked_sub(BASE_ADDR) as usize;
257            self.data
258                .as_ptr()
259                .cast::<T>()
260                .byte_add(offset)
261                .read_unaligned()
262        }
263    }
264
265    fn read_slice(&self, address: u64, len: u32) -> Result<&[u8], VirtualMemoryError> {
266        let Some(offset) = address.checked_sub(BASE_ADDR) else {
267            cold_path();
268            return Err(VirtualMemoryError::OutOfBoundsRead { address });
269        };
270
271        if offset > self.data.len() as u64 {
272            cold_path();
273            return Err(VirtualMemoryError::OutOfBoundsRead { address });
274        }
275
276        self.data
277            .get(offset as usize..)
278            .and_then(|data| data.get(..len as usize))
279            .ok_or(VirtualMemoryError::OutOfBoundsRead { address })
280    }
281
282    fn read_slice_up_to(&self, address: u64, len: u32) -> &[u8] {
283        let Some(offset) = address.checked_sub(BASE_ADDR) else {
284            cold_path();
285            return &[];
286        };
287
288        if offset > self.data.len() as u64 {
289            cold_path();
290            return &[];
291        }
292
293        let remaining = self.data.get(offset as usize..).unwrap_or_default();
294        remaining.get(..len as usize).unwrap_or(remaining)
295    }
296
297    #[inline(always)]
298    fn write<T>(&mut self, address: u64, value: T) -> Result<(), VirtualMemoryError>
299    where
300        T: BasicInt,
301    {
302        let Some(offset) = address.checked_sub(BASE_ADDR) else {
303            cold_path();
304            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
305        };
306
307        if offset.saturating_add(size_of::<T>() as u64) > self.data.len() as u64 {
308            cold_path();
309            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
310        }
311
312        // SAFETY: Only writing basic integers to initialized memory
313        unsafe {
314            self.data
315                .as_mut_ptr()
316                .cast::<T>()
317                .byte_add(offset as usize)
318                .write_unaligned(value);
319        }
320
321        Ok(())
322    }
323
324    fn write_slice(&mut self, address: u64, data: &[u8]) -> Result<(), VirtualMemoryError> {
325        let Some(offset) = address.checked_sub(BASE_ADDR) else {
326            cold_path();
327            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
328        };
329
330        if offset > self.data.len() as u64 {
331            cold_path();
332            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
333        }
334
335        let len = data.len();
336        let Some(target_data) = self
337            .data
338            .get_mut(offset as usize..)
339            .and_then(|data| data.get_mut(..len))
340        else {
341            cold_path();
342            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
343        };
344
345        target_data.copy_from_slice(data);
346
347        Ok(())
348    }
349}
350
351impl<const BASE_ADDR: u64, const SIZE: usize> Default for BasicMemory<BASE_ADDR, SIZE> {
352    #[inline(always)]
353    fn default() -> Self {
354        Self { data: [0; _] }
355    }
356}
357
358impl<const BASE_ADDR: u64, const SIZE: usize> BasicMemory<BASE_ADDR, SIZE> {
359    /// Get a mutable slice of memory.
360    ///
361    /// This is primarily useful for setting up the program and should not be used beyond that.
362    pub fn get_mut_bytes(
363        &mut self,
364        address: u64,
365        size: usize,
366    ) -> Result<&mut [u8], VirtualMemoryError> {
367        let Some(offset) = address.checked_sub(BASE_ADDR) else {
368            cold_path();
369            return Err(VirtualMemoryError::OutOfBoundsRead { address });
370        };
371        let offset = offset as usize;
372
373        let Some(slice) = self
374            .data
375            .get_mut(offset..)
376            .and_then(|data| data.get_mut(..size))
377        else {
378            cold_path();
379            return Err(VirtualMemoryError::OutOfBoundsRead { address });
380        };
381
382        Ok(slice)
383    }
384}
385
386/// Basic instruction fetcher implementation.
387///
388/// This implementation is intentionally basic and correct, but not the most performant. It is
389/// possible to have a more efficient implementation that skips certain checks by placing additional
390/// constraints on the constructor.
391///
392/// Note that it loads instructions from anywhere in memory. This works, but it is likely that you
393/// want to restrict this to a specific executable region of memory.
394#[derive(Debug, Copy, Clone)]
395pub struct BasicInstructionFetcher<I, CustomError = CustomErrorPlaceholder>
396where
397    I: Instruction,
398{
399    return_trap_address: Address<I>,
400    pc: Address<I>,
401    _phantom: PhantomData<CustomError>,
402}
403
404impl<I, Memory, CustomError> ProgramCounter<Address<I>, Memory, CustomError>
405    for BasicInstructionFetcher<I, CustomError>
406where
407    I: Instruction,
408    Memory: VirtualMemory,
409{
410    #[inline(always)]
411    fn get_pc(&self) -> Address<I> {
412        self.pc
413    }
414
415    #[inline]
416    fn set_pc(
417        &mut self,
418        _memory: &Memory,
419        pc: Address<I>,
420    ) -> Result<ControlFlow<()>, ProgramCounterError<Address<I>, CustomError>> {
421        if pc == self.return_trap_address {
422            cold_path();
423            return Ok(ControlFlow::Break(()));
424        }
425
426        if !pc.as_u64().is_multiple_of(u64::from(I::alignment())) {
427            cold_path();
428            return Err(ProgramCounterError::UnalignedInstruction { address: pc });
429        }
430
431        self.pc = pc;
432
433        Ok(ControlFlow::Continue(()))
434    }
435}
436
437impl<I, Memory, CustomError> InstructionFetcher<I, Memory, CustomError>
438    for BasicInstructionFetcher<I, CustomError>
439where
440    I: Instruction,
441    Memory: VirtualMemory,
442{
443    #[inline]
444    fn fetch_instruction(
445        &mut self,
446        memory: &Memory,
447    ) -> Result<FetchInstructionResult<I>, ExecutionError<Address<I>, CustomError>> {
448        let instruction = match memory.read(self.pc.as_u64()).or_else(|error| {
449            cold_path();
450            // Attempt to read a 16-bit compressed instruction
451            if let Ok(instruction) = memory.read::<u16>(self.pc.as_u64())
452                && (instruction & 0b11) != 0b11
453            {
454                return Ok(u32::from(instruction));
455            }
456            Err(error)
457        }) {
458            Ok(instruction) => instruction,
459            Err(error) => {
460                cold_path();
461                return Err(ExecutionError::MemoryAccess(error));
462            }
463        };
464
465        let Some(instruction) = I::try_decode(instruction) else {
466            cold_path();
467            return Err(ExecutionError::IllegalInstruction { address: self.pc });
468        };
469        self.pc += instruction.size().into();
470
471        Ok(FetchInstructionResult::Instruction(instruction))
472    }
473}
474
475impl<I, CustomError> BasicInstructionFetcher<I, CustomError>
476where
477    I: Instruction,
478{
479    /// Create a new instance.
480    ///
481    /// `return_trap_address` is the address at which the interpreter will stop execution
482    /// (gracefully).
483    #[inline(always)]
484    pub fn new(return_trap_address: Address<I>, pc: Address<I>) -> Self {
485        Self {
486            return_trap_address,
487            pc,
488            _phantom: PhantomData,
489        }
490    }
491}
492
/// System instruction handler that treats every system call as an illegal instruction and does
/// nothing for other system instructions
#[derive(Clone, Copy, Debug, Default)]
pub struct IllegalEcallSystemInstructionHandler;
497
498impl<Reg, Regs, Memory, PC, CustomError>
499    SystemInstructionHandler<Reg, Regs, Memory, PC, CustomError>
500    for IllegalEcallSystemInstructionHandler
501where
502    Reg: Register,
503    Regs: RegisterFile<Reg>,
504    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
505{
506    fn handle_ecall(
507        &mut self,
508        _regs: &mut Regs,
509        _memory: &mut Memory,
510        program_counter: &mut PC,
511    ) -> Result<ControlFlow<()>, ExecutionError<Reg::Type, CustomError>> {
512        Err(ExecutionError::IllegalInstruction {
513            address: program_counter.old_pc(size_of::<u32>() as u8),
514        })
515    }
516}