Skip to main content

ab_riscv_benchmarks/
host_utils.rs

1extern crate alloc;
2
3use ab_blake3::{CHUNK_LEN, OUT_LEN};
4use ab_contract_file::instruction::{ContractInstruction, ContractRegister};
5use ab_core_primitives::ed25519::{Ed25519PublicKey, Ed25519Signature};
6use ab_io_type::IoType;
7use ab_io_type::bool::Bool;
8use ab_riscv_interpreter::basic::{BasicInterpreterState, IgnoreEcallSystemInstructionHandler};
9use ab_riscv_interpreter::prelude::*;
10use ab_riscv_primitives::prelude::*;
11use alloc::boxed::Box;
12use alloc::vec::Vec;
13use core::hint::cold_path;
14use core::mem::offset_of;
15use core::ops::ControlFlow;
16
17/// Contract file bytes
18pub const RISCV_CONTRACT_BYTES: &[u8] = cfg_select! {
19    target_env = "abundance" => {
20        &[]
21    }
22    _ => {
23        include_bytes!(env!("CONTRACT_PATH"))
24    }
25};
26
27// TODO: Generate similar helper data structures in the `#[contract]` macro itself, maybe introduce
28//  `SimpleInternalArgs` data trait for this or something
29/// Helper data structure for [`Benchmarks::blake3_hash_chunk()`] method
30///
31/// [`Benchmarks::blake3_hash_chunk()`]: crate::Benchmarks::blake3_hash_chunk
32#[derive(Debug, Copy, Clone)]
33#[repr(C)]
34pub struct Blake3HashChunkInternalArgs {
35    chunk_ptr: u64,
36    chunk_size: u32,
37    chunk_capacity: u32,
38    result_ptr: u64,
39    chunk: [u8; CHUNK_LEN],
40    result: [u8; OUT_LEN],
41}
42
43impl Blake3HashChunkInternalArgs {
44    /// Create a new instance
45    pub fn new(internal_args_addr: u64, chunk: [u8; CHUNK_LEN]) -> Self {
46        Self {
47            chunk_ptr: internal_args_addr + offset_of!(Self, chunk) as u64,
48            chunk_size: CHUNK_LEN as u32,
49            chunk_capacity: CHUNK_LEN as u32,
50            result_ptr: internal_args_addr + offset_of!(Self, result) as u64,
51            chunk,
52            result: [0; _],
53        }
54    }
55
56    /// Extract result
57    pub fn result(&self) -> [u8; OUT_LEN] {
58        self.result
59    }
60}
61
62// TODO: Generate similar helper data structures in the `#[contract]` macro itself, maybe introduce
63//  `SimpleInternalArgs` data trait for this or something
64/// Helper data structure for [`Benchmarks::ed25519_verify()`] method
65///
66/// [`Benchmarks::ed25519_verify()`]: crate::Benchmarks::ed25519_verify
67#[derive(Debug, Copy, Clone)]
68#[repr(C)]
69pub struct Ed25519VerifyInternalArgs {
70    pub public_key_ptr: u64,
71    pub public_key_size: u32,
72    pub public_key_capacity: u32,
73    pub signature_ptr: u64,
74    pub signature_size: u32,
75    pub signature_capacity: u32,
76    pub message_ptr: u64,
77    pub message_size: u32,
78    pub message_capacity: u32,
79    pub result_ptr: u64,
80    pub public_key: Ed25519PublicKey,
81    pub signature: Ed25519Signature,
82    pub message: [u8; OUT_LEN],
83    pub result: Bool,
84}
85
86impl Ed25519VerifyInternalArgs {
87    /// Create a new instance
88    pub fn new(
89        internal_args_addr: u64,
90        public_key: Ed25519PublicKey,
91        signature: Ed25519Signature,
92        message: [u8; OUT_LEN],
93    ) -> Self {
94        Self {
95            public_key_ptr: internal_args_addr + offset_of!(Self, public_key) as u64,
96            public_key_size: Ed25519PublicKey::SIZE as u32,
97            public_key_capacity: Ed25519PublicKey::SIZE as u32,
98            signature_ptr: internal_args_addr + offset_of!(Self, signature) as u64,
99            signature_size: Ed25519Signature::SIZE as u32,
100            signature_capacity: Ed25519Signature::SIZE as u32,
101            message_ptr: internal_args_addr + offset_of!(Self, message) as u64,
102            message_size: OUT_LEN as u32,
103            message_capacity: OUT_LEN as u32,
104            result_ptr: internal_args_addr + offset_of!(Self, result) as u64,
105            public_key,
106            signature,
107            message,
108            result: Bool::new(false),
109        }
110    }
111
112    /// Extract result
113    pub fn result(&self) -> Bool {
114        self.result
115    }
116}
117
118/// Simple test memory implementation
119#[derive(Debug, Copy, Clone)]
120#[repr(align(16))]
121pub struct TestMemory<const BASE_ADDR: u64, const SIZE: usize> {
122    data: [u8; SIZE],
123}
124
125impl<const BASE_ADDR: u64, const SIZE: usize> VirtualMemory for TestMemory<BASE_ADDR, SIZE> {
126    #[inline(always)]
127    fn read<T>(&self, address: u64) -> Result<T, VirtualMemoryError>
128    where
129        T: BasicInt,
130    {
131        let offset = address.wrapping_sub(BASE_ADDR);
132
133        if offset.saturating_add(size_of::<T>() as u64) > self.data.len() as u64 {
134            cold_path();
135            return Err(VirtualMemoryError::OutOfBoundsRead { address });
136        }
137
138        // SAFETY: Only reading basic integers from initialized memory
139        unsafe {
140            Ok(self
141                .data
142                .as_ptr()
143                .cast::<T>()
144                .byte_add(offset as usize)
145                .read_unaligned())
146        }
147    }
148
149    #[inline(always)]
150    unsafe fn read_unchecked<T>(&self, address: u64) -> T
151    where
152        T: BasicInt,
153    {
154        // SAFETY: Guaranteed by function contract
155        unsafe {
156            let offset = address.unchecked_sub(BASE_ADDR) as usize;
157            self.data
158                .as_ptr()
159                .cast::<T>()
160                .byte_add(offset)
161                .read_unaligned()
162        }
163    }
164
165    fn read_slice(&self, address: u64, len: u32) -> Result<&[u8], VirtualMemoryError> {
166        let offset = address.wrapping_sub(BASE_ADDR);
167
168        if offset > self.data.len() as u64 {
169            cold_path();
170            return Err(VirtualMemoryError::OutOfBoundsRead { address });
171        }
172
173        self.data
174            .get(offset as usize..)
175            .and_then(|data| data.get(..len as usize))
176            .ok_or(VirtualMemoryError::OutOfBoundsRead { address })
177    }
178
179    fn read_slice_up_to(&self, address: u64, len: u32) -> &[u8] {
180        let offset = address.wrapping_sub(BASE_ADDR);
181
182        if offset > self.data.len() as u64 {
183            cold_path();
184            return &[];
185        }
186
187        let remaining = self.data.get(offset as usize..).unwrap_or_default();
188        remaining.get(..len as usize).unwrap_or(remaining)
189    }
190
191    #[inline(always)]
192    fn write<T>(&mut self, address: u64, value: T) -> Result<(), VirtualMemoryError>
193    where
194        T: BasicInt,
195    {
196        let offset = address.wrapping_sub(BASE_ADDR);
197
198        if offset.saturating_add(size_of::<T>() as u64) > self.data.len() as u64 {
199            cold_path();
200            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
201        }
202
203        // SAFETY: Only writing basic integers to initialized memory
204        unsafe {
205            self.data
206                .as_mut_ptr()
207                .cast::<T>()
208                .byte_add(offset as usize)
209                .write_unaligned(value);
210        }
211
212        Ok(())
213    }
214
215    fn write_slice(&mut self, address: u64, data: &[u8]) -> Result<(), VirtualMemoryError> {
216        let offset = address.wrapping_sub(BASE_ADDR);
217
218        if offset > self.data.len() as u64 {
219            cold_path();
220            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
221        }
222
223        let len = data.len();
224        let Some(target_data) = self
225            .data
226            .get_mut(offset as usize..)
227            .and_then(|data| data.get_mut(..len))
228        else {
229            cold_path();
230            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
231        };
232
233        target_data.copy_from_slice(data);
234
235        Ok(())
236    }
237}
238
239impl<const BASE_ADDR: u64, const SIZE: usize> Default for TestMemory<BASE_ADDR, SIZE> {
240    fn default() -> Self {
241        Self { data: [0; SIZE] }
242    }
243}
244
245impl<const BASE_ADDR: u64, const SIZE: usize> TestMemory<BASE_ADDR, SIZE> {
246    /// Get a mutable slice of memory
247    pub fn get_mut_bytes(
248        &mut self,
249        address: u64,
250        size: usize,
251    ) -> Result<&mut [u8], VirtualMemoryError> {
252        let Some(offset) = address.checked_sub(BASE_ADDR) else {
253            cold_path();
254            return Err(VirtualMemoryError::OutOfBoundsRead { address });
255        };
256        let offset = offset as usize;
257
258        if offset + size > self.data.len() {
259            cold_path();
260            return Err(VirtualMemoryError::OutOfBoundsRead { address });
261        }
262
263        Ok(&mut self.data[offset..][..size])
264    }
265}
266
267/// Lazy instruction fetcher implementation
268#[derive(Debug, Copy, Clone)]
269pub struct LazyInstructionFetcher {
270    return_trap_address: u64,
271    pc: u64,
272}
273
274impl<Memory> ProgramCounter<u64, Memory> for LazyInstructionFetcher
275where
276    Memory: VirtualMemory,
277{
278    #[inline(always)]
279    fn get_pc(&self) -> u64 {
280        self.pc
281    }
282
283    #[inline]
284    fn set_pc(
285        &mut self,
286        memory: &Memory,
287        pc: u64,
288    ) -> Result<ControlFlow<()>, ProgramCounterError<u64>> {
289        if pc == self.return_trap_address {
290            cold_path();
291            return Ok(ControlFlow::Break(()));
292        }
293
294        if !pc.is_multiple_of(u64::from(
295            ContractInstruction::<ContractRegister>::alignment(),
296        )) {
297            cold_path();
298            return Err(ProgramCounterError::UnalignedInstruction { address: pc });
299        }
300
301        // Note: This will not allow reading a 16-bit instruction at the very end of memory range,
302        // but that is going to be the case here anyway since code is followed by read-write memory
303        // anyway
304        if let Err(error) = memory.read::<u32>(pc) {
305            cold_path();
306            return Err(error.into());
307        }
308
309        self.pc = pc;
310
311        Ok(ControlFlow::Continue(()))
312    }
313}
314
315impl<Memory> InstructionFetcher<ContractInstruction, Memory> for LazyInstructionFetcher
316where
317    Memory: VirtualMemory,
318{
319    #[inline]
320    fn fetch_instruction(
321        &mut self,
322        memory: &Memory,
323    ) -> Result<FetchInstructionResult<ContractInstruction>, ExecutionError<u64>> {
324        // SAFETY: Constructor guarantees that the last instruction is a jump, which means going
325        // through `Self::set_pc()` method does the necessary bounds check and advancing forward by
326        // one instruction can't result in out-of-bounds access.
327        let instruction = unsafe { memory.read_unchecked(self.pc) };
328        // SAFETY: All instructions are valid, according to the constructor contract
329        let instruction =
330            unsafe { ContractInstruction::try_decode(instruction).unwrap_unchecked() };
331
332        self.pc += u64::from(instruction.size());
333
334        Ok(FetchInstructionResult::Instruction(instruction))
335    }
336}
337
338impl LazyInstructionFetcher {
339    /// Create a new instance.
340    ///
341    /// `return_trap_address` is the address at which the interpreter will stop execution
342    /// (gracefully).
343    ///
344    /// # Safety
345    /// The program counter must be valid and aligned, the instructions processed must be valid and
346    /// end with a jump instruction.
347    #[inline(always)]
348    pub unsafe fn new(return_trap_address: u64, pc: u64) -> Self {
349        Self {
350            return_trap_address,
351            pc,
352        }
353    }
354}
355
356/// Eager instruction handler eagerly decodes all instructions upfront
357#[derive(Debug, Clone)]
358#[repr(C, align(16))]
359pub struct EagerTestInstructionFetcher {
360    instructions: Box<[ContractInstruction]>,
361    instruction_offset: usize,
362    return_trap_address: u64,
363    base_addr: u64,
364}
365
366impl<Memory> ProgramCounter<u64, Memory> for EagerTestInstructionFetcher
367where
368    Memory: VirtualMemory,
369{
370    #[inline(always)]
371    fn get_pc(&self) -> u64 {
372        self.base_addr + self.instruction_offset as u64 * size_of::<u16>() as u64
373    }
374
375    #[inline]
376    fn set_pc(
377        &mut self,
378        _memory: &Memory,
379        pc: u64,
380    ) -> Result<ControlFlow<()>, ProgramCounterError<u64>> {
381        let address = pc;
382
383        if address == self.return_trap_address {
384            cold_path();
385            return Ok(ControlFlow::Break(()));
386        }
387
388        if !address.is_multiple_of(size_of::<u16>() as u64) {
389            cold_path();
390            return Err(ProgramCounterError::UnalignedInstruction { address });
391        }
392
393        let Some(offset) = address.checked_sub(self.base_addr) else {
394            cold_path();
395            return Err(ProgramCounterError::MemoryAccess(
396                VirtualMemoryError::OutOfBoundsRead { address },
397            ));
398        };
399        let offset = offset as usize;
400        let instruction_offset = offset / size_of::<u16>();
401
402        if instruction_offset >= self.instructions.len() {
403            cold_path();
404            return Err(VirtualMemoryError::OutOfBoundsRead { address }.into());
405        }
406
407        self.instruction_offset = instruction_offset;
408
409        Ok(ControlFlow::Continue(()))
410    }
411}
412
413impl<Memory> InstructionFetcher<ContractInstruction, Memory> for EagerTestInstructionFetcher
414where
415    Memory: VirtualMemory,
416{
417    #[inline(always)]
418    fn fetch_instruction(
419        &mut self,
420        _memory: &Memory,
421    ) -> Result<FetchInstructionResult<ContractInstruction>, ExecutionError<u64>> {
422        // SAFETY: Constructor guarantees that the last instruction is a jump, which means going
423        // through `Self::set_pc()` method does the necessary bounds check and advancing forward by
424        // one instruction can't result in out-of-bounds access.
425        let instruction = *unsafe { self.instructions.get_unchecked(self.instruction_offset) };
426        self.instruction_offset += usize::from(instruction.size()) / size_of::<u16>();
427
428        Ok(FetchInstructionResult::Instruction(instruction))
429    }
430}
431
432impl EagerTestInstructionFetcher {
433    /// Create a new instance with the specified instructions and base address.
434    ///
435    /// Instructions are decoded during instantiation of the instruction fetcher, and the base
436    /// address corresponds to the first instruction.
437    ///
438    /// `return_trap_address` is the address at which the interpreter will stop execution
439    /// (gracefully).
440    ///
441    /// # Safety
442    /// The program counter must be valid and aligned, the instructions processed must end with a
443    /// jump instruction.
444    #[inline(always)]
445    pub unsafe fn new(
446        instructions: &[u8],
447        return_trap_address: u64,
448        base_addr: u64,
449        pc: u64,
450    ) -> Self {
451        let mut decoded_instructions = Vec::with_capacity(instructions.len() / size_of::<u16>());
452
453        let mut offset = 0;
454        while let Some(instruction_bytes) = instructions.get(offset..offset + size_of::<u32>()) {
455            let decoded_instruction = u32::from_le_bytes([
456                instruction_bytes[0],
457                instruction_bytes[1],
458                instruction_bytes[2],
459                instruction_bytes[3],
460            ]);
461            // Use `Unimp` as a fallback, though contract is expected to only contain legal
462            // instructions
463            let decoded_instruction = Instruction::try_decode(decoded_instruction).unwrap_or(
464                ContractInstruction::Unimp {
465                    rs1: Register::ZERO,
466                    rs2: Register::ZERO,
467                },
468            );
469            decoded_instructions.push(decoded_instruction);
470            match decoded_instruction.size() {
471                2 => {
472                    offset += 2;
473                }
474                4 => {
475                    // The second half of a 32-bit instruction is a valid offset and may or may not
476                    // decode to a valid instruction on its own. Try to decode it but ignore
477                    // decoding failures.
478
479                    offset += 2;
480
481                    // Could be both 16-bit and 32-bit instruction, need to handle end of the
482                    // instruction stream
483                    let instruction_word = if let Some(instruction_bytes) =
484                        instructions.get(offset..offset + size_of::<u32>())
485                    {
486                        u32::from_le_bytes([
487                            instruction_bytes[0],
488                            instruction_bytes[1],
489                            instruction_bytes[2],
490                            instruction_bytes[3],
491                        ])
492                    } else {
493                        u32::from_le_bytes([instruction_bytes[2], instruction_bytes[3], 0, 0])
494                    };
495
496                    decoded_instructions.push(Instruction::try_decode(instruction_word).unwrap_or(
497                        ContractInstruction::Unimp {
498                            rs1: Register::ZERO,
499                            rs2: Register::ZERO,
500                        },
501                    ));
502                    offset += 2;
503                }
504                instruction_size => {
505                    unreachable!("Invalid instruction size {instruction_size}, expected 2 or 4");
506                }
507            }
508        }
509
510        let remainder_bytes = instructions.get(offset..).unwrap_or(&[]);
511
512        if remainder_bytes.len() == size_of::<u16>() {
513            let instruction_word =
514                u32::from_le_bytes([remainder_bytes[0], remainder_bytes[1], 0, 0]);
515            decoded_instructions.push(Instruction::try_decode(instruction_word).unwrap_or(
516                ContractInstruction::Unimp {
517                    rs1: Register::ZERO,
518                    rs2: Register::ZERO,
519                },
520            ));
521        }
522
523        Self {
524            instructions: decoded_instructions.into_boxed_slice(),
525            instruction_offset: (pc - base_addr) as usize / size_of::<u16>(),
526            return_trap_address,
527            base_addr,
528        }
529    }
530}
531
532/// Execute [`ContractInstruction`]s
533pub fn execute<Regs, Memory, IF>(
534    state: &mut BasicInterpreterState<Regs, (), Memory, IF, IgnoreEcallSystemInstructionHandler>,
535) -> Result<(), ExecutionError<u64>>
536where
537    Regs: RegisterFile<<ContractInstruction as Instruction>::Reg>,
538    Memory: VirtualMemory,
539    IF: InstructionFetcher<ContractInstruction, Memory>,
540{
541    loop {
542        let instruction = match state.instruction_fetcher.fetch_instruction(&state.memory) {
543            Ok(FetchInstructionResult::Instruction(instruction)) => instruction,
544            Ok(FetchInstructionResult::ControlFlow(ControlFlow::Continue(()))) => {
545                cold_path();
546                continue;
547            }
548            Ok(FetchInstructionResult::ControlFlow(ControlFlow::Break(()))) => {
549                cold_path();
550                break;
551            }
552            Err(error) => {
553                cold_path();
554                return Err(error);
555            }
556        };
557
558        let Rs1Rs2Operands { rs1, rs2 } = instruction.get_rs1_rs2_operands();
559        let rs1rs2_values = Rs1Rs2OperandValues {
560            rs1_value: state.regs.read(rs1),
561            rs2_value: state.regs.read(rs2),
562        };
563
564        match instruction.execute(
565            rs1rs2_values,
566            &mut state.regs,
567            &mut state.ext_state,
568            &mut state.memory,
569            &mut state.instruction_fetcher,
570            &mut state.system_instruction_handler,
571        ) {
572            Ok(ControlFlow::Continue((rd, rd_value))) => {
573                state.regs.write(rd, rd_value);
574                continue;
575            }
576            Ok(ControlFlow::Break(())) => {
577                cold_path();
578                break;
579            }
580            Err(error) => {
581                cold_path();
582                return Err(error);
583            }
584        }
585    }
586
587    Ok(())
588}