Skip to main content

ab_riscv_benchmarks/
host_utils.rs

1extern crate alloc;
2
3use ab_blake3::{CHUNK_LEN, OUT_LEN};
4use ab_contract_file::instruction::{ContractInstruction, ContractRegister};
5use ab_core_primitives::ed25519::{Ed25519PublicKey, Ed25519Signature};
6use ab_io_type::bool::Bool;
7use ab_riscv_interpreter::prelude::*;
8use ab_riscv_primitives::prelude::*;
9use alloc::boxed::Box;
10use alloc::vec::Vec;
11use core::hint::cold_path;
12use core::mem::offset_of;
13use core::ops::ControlFlow;
14
15/// Contract file bytes
16pub const RISCV_CONTRACT_BYTES: &[u8] = cfg_select! {
17    target_env = "abundance" => {
18        &[]
19    }
20    _ => {
21        include_bytes!(env!("CONTRACT_PATH"))
22    }
23};
24
25// TODO: Generate similar helper data structures in the `#[contract]` macro itself, maybe introduce
26//  `SimpleInternalArgs` data trait for this or something
27/// Helper data structure for [`Benchmarks::blake3_hash_chunk()`] method
28///
29/// [`Benchmarks::blake3_hash_chunk()`]: crate::Benchmarks::blake3_hash_chunk
30#[derive(Debug, Copy, Clone)]
31#[repr(C)]
32pub struct Blake3HashChunkInternalArgs {
33    chunk_ptr: u64,
34    chunk_size: u32,
35    chunk_capacity: u32,
36    result_ptr: u64,
37    chunk: [u8; CHUNK_LEN],
38    result: [u8; OUT_LEN],
39}
40
41impl Blake3HashChunkInternalArgs {
42    /// Create a new instance
43    pub fn new(internal_args_addr: u64, chunk: [u8; CHUNK_LEN]) -> Self {
44        Self {
45            chunk_ptr: internal_args_addr + offset_of!(Self, chunk) as u64,
46            chunk_size: CHUNK_LEN as u32,
47            chunk_capacity: CHUNK_LEN as u32,
48            result_ptr: internal_args_addr + offset_of!(Self, result) as u64,
49            chunk,
50            result: [0; _],
51        }
52    }
53
54    /// Extract result
55    pub fn result(&self) -> [u8; OUT_LEN] {
56        self.result
57    }
58}
59
60// TODO: Generate similar helper data structures in the `#[contract]` macro itself, maybe introduce
61//  `SimpleInternalArgs` data trait for this or something
62/// Helper data structure for [`Benchmarks::ed25519_verify()`] method
63///
64/// [`Benchmarks::ed25519_verify()`]: crate::Benchmarks::ed25519_verify
65#[derive(Debug, Copy, Clone)]
66#[repr(C)]
67pub struct Ed25519VerifyInternalArgs {
68    pub public_key_ptr: u64,
69    pub public_key_size: u32,
70    pub public_key_capacity: u32,
71    pub signature_ptr: u64,
72    pub signature_size: u32,
73    pub signature_capacity: u32,
74    pub message_ptr: u64,
75    pub message_size: u32,
76    pub message_capacity: u32,
77    pub result_ptr: u64,
78    pub public_key: Ed25519PublicKey,
79    pub signature: Ed25519Signature,
80    pub message: [u8; OUT_LEN],
81    pub result: Bool,
82}
83
84impl Ed25519VerifyInternalArgs {
85    /// Create a new instance
86    pub fn new(
87        internal_args_addr: u64,
88        public_key: Ed25519PublicKey,
89        signature: Ed25519Signature,
90        message: [u8; OUT_LEN],
91    ) -> Self {
92        Self {
93            public_key_ptr: internal_args_addr + offset_of!(Self, public_key) as u64,
94            public_key_size: Ed25519PublicKey::SIZE as u32,
95            public_key_capacity: Ed25519PublicKey::SIZE as u32,
96            signature_ptr: internal_args_addr + offset_of!(Self, signature) as u64,
97            signature_size: Ed25519Signature::SIZE as u32,
98            signature_capacity: Ed25519Signature::SIZE as u32,
99            message_ptr: internal_args_addr + offset_of!(Self, message) as u64,
100            message_size: OUT_LEN as u32,
101            message_capacity: OUT_LEN as u32,
102            result_ptr: internal_args_addr + offset_of!(Self, result) as u64,
103            public_key,
104            signature,
105            message,
106            result: Bool::new(false),
107        }
108    }
109
110    /// Extract result
111    pub fn result(&self) -> Bool {
112        self.result
113    }
114}
115
/// Simple test memory implementation
///
/// A flat byte buffer of `SIZE` bytes in which byte `i` backs address `BASE_ADDR + i`. The whole
/// structure is aligned to 16 bytes.
#[derive(Clone, Copy, Debug)]
#[repr(align(16))]
pub struct TestMemory<const BASE_ADDR: u64, const SIZE: usize> {
    /// Backing storage for the address range `BASE_ADDR..BASE_ADDR + SIZE`
    data: [u8; SIZE],
}
122
123impl<const BASE_ADDR: u64, const SIZE: usize> VirtualMemory for TestMemory<BASE_ADDR, SIZE> {
124    #[inline(always)]
125    fn read<T>(&self, address: u64) -> Result<T, VirtualMemoryError>
126    where
127        T: BasicInt,
128    {
129        let offset = address.wrapping_sub(BASE_ADDR);
130
131        if offset.saturating_add(size_of::<T>() as u64) > self.data.len() as u64 {
132            cold_path();
133            return Err(VirtualMemoryError::OutOfBoundsRead { address });
134        }
135
136        // SAFETY: Only reading basic integers from initialized memory
137        unsafe {
138            Ok(self
139                .data
140                .as_ptr()
141                .cast::<T>()
142                .byte_add(offset as usize)
143                .read_unaligned())
144        }
145    }
146
147    #[inline(always)]
148    unsafe fn read_unchecked<T>(&self, address: u64) -> T
149    where
150        T: BasicInt,
151    {
152        // SAFETY: Guaranteed by function contract
153        unsafe {
154            let offset = address.unchecked_sub(BASE_ADDR) as usize;
155            self.data
156                .as_ptr()
157                .cast::<T>()
158                .byte_add(offset)
159                .read_unaligned()
160        }
161    }
162
163    fn read_slice(&self, address: u64, len: u32) -> Result<&[u8], VirtualMemoryError> {
164        let offset = address.wrapping_sub(BASE_ADDR);
165
166        if offset > self.data.len() as u64 {
167            cold_path();
168            return Err(VirtualMemoryError::OutOfBoundsRead { address });
169        }
170
171        self.data
172            .get(offset as usize..)
173            .and_then(|data| data.get(..len as usize))
174            .ok_or(VirtualMemoryError::OutOfBoundsRead { address })
175    }
176
177    fn read_slice_up_to(&self, address: u64, len: u32) -> &[u8] {
178        let offset = address.wrapping_sub(BASE_ADDR);
179
180        if offset > self.data.len() as u64 {
181            cold_path();
182            return &[];
183        }
184
185        let remaining = self.data.get(offset as usize..).unwrap_or_default();
186        remaining.get(..len as usize).unwrap_or(remaining)
187    }
188
189    #[inline(always)]
190    fn write<T>(&mut self, address: u64, value: T) -> Result<(), VirtualMemoryError>
191    where
192        T: BasicInt,
193    {
194        let offset = address.wrapping_sub(BASE_ADDR);
195
196        if offset.saturating_add(size_of::<T>() as u64) > self.data.len() as u64 {
197            cold_path();
198            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
199        }
200
201        // SAFETY: Only writing basic integers to initialized memory
202        unsafe {
203            self.data
204                .as_mut_ptr()
205                .cast::<T>()
206                .byte_add(offset as usize)
207                .write_unaligned(value);
208        }
209
210        Ok(())
211    }
212
213    fn write_slice(&mut self, address: u64, data: &[u8]) -> Result<(), VirtualMemoryError> {
214        let offset = address.wrapping_sub(BASE_ADDR);
215
216        if offset > self.data.len() as u64 {
217            cold_path();
218            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
219        }
220
221        let len = data.len();
222        let Some(target_data) = self
223            .data
224            .get_mut(offset as usize..)
225            .and_then(|data| data.get_mut(..len))
226        else {
227            cold_path();
228            return Err(VirtualMemoryError::OutOfBoundsWrite { address });
229        };
230
231        target_data.copy_from_slice(data);
232
233        Ok(())
234    }
235}
236
237impl<const BASE_ADDR: u64, const SIZE: usize> Default for TestMemory<BASE_ADDR, SIZE> {
238    fn default() -> Self {
239        Self { data: [0; SIZE] }
240    }
241}
242
243impl<const BASE_ADDR: u64, const SIZE: usize> TestMemory<BASE_ADDR, SIZE> {
244    /// Get a mutable slice of memory
245    pub fn get_mut_bytes(
246        &mut self,
247        address: u64,
248        size: usize,
249    ) -> Result<&mut [u8], VirtualMemoryError> {
250        let Some(offset) = address.checked_sub(BASE_ADDR) else {
251            cold_path();
252            return Err(VirtualMemoryError::OutOfBoundsRead { address });
253        };
254        let offset = offset as usize;
255
256        let Some(slice) = self
257            .data
258            .get_mut(offset..)
259            .and_then(|data| data.get_mut(..size))
260        else {
261            cold_path();
262            return Err(VirtualMemoryError::OutOfBoundsRead { address });
263        };
264
265        Ok(slice)
266    }
267}
268
/// Lazy instruction fetcher implementation
///
/// Instructions are decoded from memory one at a time as they are fetched.
#[derive(Clone, Copy, Debug)]
pub struct LazyInstructionFetcher {
    /// Jumping to this address stops execution gracefully
    return_trap_address: u64,
    /// Address of the next instruction to fetch
    pc: u64,
}
275
276impl<Memory> ProgramCounter<u64, Memory> for LazyInstructionFetcher
277where
278    Memory: VirtualMemory,
279{
280    #[inline(always)]
281    fn get_pc(&self) -> u64 {
282        self.pc
283    }
284
285    #[inline]
286    fn set_pc(
287        &mut self,
288        memory: &Memory,
289        pc: u64,
290    ) -> Result<ControlFlow<()>, ProgramCounterError<u64>> {
291        if pc == self.return_trap_address {
292            cold_path();
293            return Ok(ControlFlow::Break(()));
294        }
295
296        if !pc.is_multiple_of(u64::from(
297            ContractInstruction::<ContractRegister>::alignment(),
298        )) {
299            cold_path();
300            return Err(ProgramCounterError::UnalignedInstruction { address: pc });
301        }
302
303        // Note: This will not allow reading a 16-bit instruction at the very end of memory range,
304        // but that is going to be the case here anyway since code is followed by read-write memory
305        // anyway
306        if let Err(error) = memory.read::<u32>(pc) {
307            cold_path();
308            return Err(error.into());
309        }
310
311        self.pc = pc;
312
313        Ok(ControlFlow::Continue(()))
314    }
315}
316
317impl<Memory> InstructionFetcher<ContractInstruction, Memory> for LazyInstructionFetcher
318where
319    Memory: VirtualMemory,
320{
321    #[inline]
322    fn fetch_instruction(
323        &mut self,
324        memory: &Memory,
325    ) -> Result<FetchInstructionResult<ContractInstruction>, ExecutionError<u64>> {
326        // SAFETY: Constructor guarantees that the last instruction is a jump, which means going
327        // through `Self::set_pc()` method does the necessary bounds check and advancing forward by
328        // one instruction can't result in out-of-bounds access.
329        let instruction = unsafe { memory.read_unchecked(self.pc) };
330        // SAFETY: All instructions are valid, according to the constructor contract
331        let instruction =
332            unsafe { ContractInstruction::try_decode(instruction).unwrap_unchecked() };
333
334        self.pc += u64::from(instruction.size());
335
336        Ok(FetchInstructionResult::Instruction(instruction))
337    }
338}
339
340impl LazyInstructionFetcher {
341    /// Create a new instance.
342    ///
343    /// `return_trap_address` is the address at which the interpreter will stop execution
344    /// (gracefully).
345    ///
346    /// # Safety
347    /// The program counter must be valid and aligned, the instructions processed must be valid and
348    /// end with a jump instruction.
349    #[inline(always)]
350    pub unsafe fn new(return_trap_address: u64, pc: u64) -> Self {
351        Self {
352            return_trap_address,
353            pc,
354        }
355    }
356}
357
358/// Eager instruction handler eagerly decodes all instructions upfront
359#[derive(Debug, Clone)]
360#[repr(C, align(16))]
361pub struct EagerTestInstructionFetcher {
362    decoded_instruction_byte_offset: usize,
363    // A simple raw pointer separate field helps LLVM with SROA and aliasing analysis, so it can
364    // retain this pointer in the native register
365    instructions: Box<[ContractInstruction]>,
366    base_addr: u64,
367    return_trap_address: u64,
368}
369
370impl<Memory> ProgramCounter<u64, Memory> for EagerTestInstructionFetcher
371where
372    Memory: VirtualMemory,
373{
374    #[inline(always)]
375    fn get_pc(&self) -> u64 {
376        self.base_addr
377            + self.decoded_instruction_byte_offset as u64 * size_of::<u16>() as u64
378                / size_of::<ContractInstruction>() as u64
379    }
380
381    #[inline]
382    fn set_pc(
383        &mut self,
384        _memory: &Memory,
385        pc: u64,
386    ) -> Result<ControlFlow<()>, ProgramCounterError<u64>> {
387        let address = pc;
388
389        if address == self.return_trap_address {
390            cold_path();
391            return Ok(ControlFlow::Break(()));
392        }
393
394        if !address.is_multiple_of(size_of::<u16>() as u64) {
395            cold_path();
396            return Err(ProgramCounterError::UnalignedInstruction { address });
397        }
398
399        let Some(offset) = address.checked_sub(self.base_addr) else {
400            cold_path();
401            return Err(ProgramCounterError::MemoryAccess(
402                VirtualMemoryError::OutOfBoundsRead { address },
403            ));
404        };
405        let offset = offset as usize;
406        let instruction_offset = offset / size_of::<u16>();
407
408        if instruction_offset >= self.instructions.len() {
409            cold_path();
410            return Err(VirtualMemoryError::OutOfBoundsRead { address }.into());
411        }
412
413        self.decoded_instruction_byte_offset =
414            instruction_offset * size_of::<ContractInstruction>();
415
416        Ok(ControlFlow::Continue(()))
417    }
418}
419
420impl<Memory> InstructionFetcher<ContractInstruction, Memory> for EagerTestInstructionFetcher
421where
422    Memory: VirtualMemory,
423{
424    #[inline(always)]
425    fn fetch_instruction(
426        &mut self,
427        _memory: &Memory,
428    ) -> Result<FetchInstructionResult<ContractInstruction>, ExecutionError<u64>> {
429        // SAFETY: Constructor guarantees that the last instruction is a jump, which means going
430        // through `Self::set_pc()` method does the necessary bounds check and advancing forward by
431        // one instruction can't result in out-of-bounds access.
432        let instruction = unsafe {
433            // Reading through byte offset rather than index to avoid extra computation (converting
434            // an index to a byte offset) on each fetch
435            self.instructions
436                .as_ptr()
437                .byte_add(self.decoded_instruction_byte_offset)
438                .read()
439        };
440        self.decoded_instruction_byte_offset +=
441            usize::from(instruction.size()) / size_of::<u16>() * size_of::<ContractInstruction>();
442
443        Ok(FetchInstructionResult::Instruction(instruction))
444    }
445}
446
447impl EagerTestInstructionFetcher {
448    /// Create a new instance with the specified instructions and base address.
449    ///
450    /// Instructions are decoded during instantiation of the instruction fetcher, and the base
451    /// address corresponds to the first instruction.
452    ///
453    /// `return_trap_address` is the address at which the interpreter will stop execution
454    /// (gracefully).
455    ///
456    /// # Safety
457    /// The program counter must be valid and aligned, the instructions processed must end with a
458    /// jump instruction.
459    #[inline(always)]
460    pub unsafe fn new(
461        instructions: &[u8],
462        return_trap_address: u64,
463        base_addr: u64,
464        pc: u64,
465    ) -> Self {
466        let mut decoded_instructions = Vec::with_capacity(instructions.len() / size_of::<u16>());
467
468        let mut offset = 0;
469        while let Some(instruction_bytes) = instructions.get(offset..offset + size_of::<u32>()) {
470            let decoded_instruction = u32::from_le_bytes([
471                instruction_bytes[0],
472                instruction_bytes[1],
473                instruction_bytes[2],
474                instruction_bytes[3],
475            ]);
476            // Use `Unimp` as a fallback, though contract is expected to only contain legal
477            // instructions
478            let decoded_instruction = Instruction::try_decode(decoded_instruction).unwrap_or(
479                ContractInstruction::Unimp {
480                    rs1: Register::ZERO,
481                    rs2: Register::ZERO,
482                },
483            );
484            decoded_instructions.push(decoded_instruction);
485            match decoded_instruction.size() {
486                2 => {
487                    offset += 2;
488                }
489                4 => {
490                    // The second half of a 32-bit instruction is a valid offset and may or may not
491                    // decode to a valid instruction on its own. Try to decode it but ignore
492                    // decoding failures.
493
494                    offset += 2;
495
496                    // Could be both 16-bit and 32-bit instruction, need to handle end of the
497                    // instruction stream
498                    let instruction_word = if let Some(instruction_bytes) =
499                        instructions.get(offset..offset + size_of::<u32>())
500                    {
501                        u32::from_le_bytes([
502                            instruction_bytes[0],
503                            instruction_bytes[1],
504                            instruction_bytes[2],
505                            instruction_bytes[3],
506                        ])
507                    } else {
508                        u32::from_le_bytes([instruction_bytes[2], instruction_bytes[3], 0, 0])
509                    };
510
511                    decoded_instructions.push(Instruction::try_decode(instruction_word).unwrap_or(
512                        ContractInstruction::Unimp {
513                            rs1: Register::ZERO,
514                            rs2: Register::ZERO,
515                        },
516                    ));
517                    offset += 2;
518                }
519                instruction_size => {
520                    unreachable!("Invalid instruction size {instruction_size}, expected 2 or 4");
521                }
522            }
523        }
524
525        let remainder_bytes = instructions.get(offset..).unwrap_or(&[]);
526
527        if remainder_bytes.len() == size_of::<u16>() {
528            let instruction_word =
529                u32::from_le_bytes([remainder_bytes[0], remainder_bytes[1], 0, 0]);
530            decoded_instructions.push(Instruction::try_decode(instruction_word).unwrap_or(
531                ContractInstruction::Unimp {
532                    rs1: Register::ZERO,
533                    rs2: Register::ZERO,
534                },
535            ));
536        }
537
538        Self {
539            decoded_instruction_byte_offset: (pc - base_addr) as usize / size_of::<u16>()
540                * size_of::<ContractInstruction>(),
541            instructions: decoded_instructions.into_boxed_slice(),
542            base_addr,
543            return_trap_address,
544        }
545    }
546}