Skip to main content

ab_riscv_interpreter/v/zvexx/store/
zvexx_store_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zvexx::load::zvexx_load_helpers::{
5    check_register_group_alignment, mask_bit, read_group_element, snapshot_mask,
6};
7use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
8use crate::{ExecutionError, ProgramCounter, VirtualMemory, VirtualMemoryError};
9use ab_riscv_primitives::prelude::*;
10use core::fmt;
11use core::hint::cold_path;
12
13/// Interpret `buf[..index_eew.bytes()]` as a little-endian unsigned integer and return it as
14/// `u64`. Used to convert a packed index element into a byte offset.
15///
16/// # Safety
17/// `index_eew.bytes() <= Eew::MAX_BYTES`, which is always true by construction.
18#[inline(always)]
19unsafe fn index_buf_to_u64(buf: [u8; Eew::MAX_BYTES as usize], index_eew: Eew) -> u64 {
20    match index_eew {
21        Eew::E8 => u64::from(buf[0]),
22        Eew::E16 => u64::from(u16::from_le_bytes([buf[0], buf[1]])),
23        Eew::E32 => u64::from(u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]])),
24        Eew::E64 => u64::from_le_bytes(buf),
25    }
26}
27
28/// Write `eew`-sized data from `buf[..eew.bytes()]` to memory at `addr` (little-endian)
29#[inline(always)]
30fn write_mem_element(
31    memory: &mut impl VirtualMemory,
32    addr: u64,
33    eew: Eew,
34    buf: [u8; Eew::MAX_BYTES as usize],
35) -> Result<(), VirtualMemoryError> {
36    memory.write_slice(addr, &buf[..usize::from(eew.bytes_width())])
37}
38
39/// Validate a segment store's destination register group.
40///
41/// Like [`validate_segment_registers`] but omits the v0-overlap check, since
42/// segment stores read `vs3` as a source and the source/v0 overlap restriction
43/// applies only to load destinations.
44#[inline(always)]
45#[doc(hidden)]
46pub fn validate_segment_store_registers<Reg, Memory, PC, CustomError>(
47    program_counter: &PC,
48    vs3: VReg,
49    group_regs: u8,
50    nf: Nf,
51) -> Result<(), ExecutionError<Reg::Type, CustomError>>
52where
53    Reg: Register,
54    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
55{
56    if let Err(error) =
57        check_register_group_alignment::<Reg, _, _, _>(program_counter, vs3, group_regs)
58    {
59        cold_path();
60        return Err(error);
61    }
62    let total =
63        u32::from(vs3.to_bits()) + u32::from(nf.fields_per_segment()) * u32::from(group_regs);
64    if total > 32 {
65        cold_path();
66        return Err(ExecutionError::IllegalInstruction {
67            address: program_counter.old_pc(INSTRUCTION_SIZE),
68        });
69    }
70    Ok(())
71}
72
73/// Execute a unit-stride or unit-stride segment store.
74///
75/// Segment stride between elements is `nf * eew.bytes()`. Field `f` for element `i` is at
76/// `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1` this degenerates to a
77/// plain unit-stride store.
78///
79/// # Safety
80/// - `vs3.to_bits() % group_regs == 0`
81/// - `vs3.to_bits() + nf * group_regs <= 32`
82/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the source register
83///   group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL register
84///   count for the given `eew` and `vtype`)
85/// - When `vm=false`: `vs3` does not overlap `v0` (i.e. `vs3.to_bits() != 0`)
86#[inline(always)]
87#[expect(clippy::too_many_arguments, reason = "Internal API")]
88#[doc(hidden)]
89pub unsafe fn execute_unit_stride_store<Reg, ExtState, Memory, CustomError>(
90    ext_state: &mut ExtState,
91    memory: &mut Memory,
92    vs3: VReg,
93    vm: bool,
94    base: u64,
95    eew: Eew,
96    group_regs: u8,
97    nf: Nf,
98) -> Result<(), ExecutionError<Reg::Type, CustomError>>
99where
100    Reg: Register,
101    ExtState: VectorRegistersExt<Reg, CustomError>,
102    [(); ExtState::ELEN as usize]:,
103    [(); ExtState::VLEN as usize]:,
104    [(); ExtState::VLENB as usize]:,
105    Memory: VirtualMemory,
106    CustomError: fmt::Debug,
107{
108    let vl = ext_state.vl();
109    let vstart = ext_state.vstart();
110    let elem_bytes = eew.bytes_width();
111    let segment_stride = u64::from(nf.fields_per_segment() * elem_bytes);
112    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
113    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
114    for i in u32::from(vstart)..vl {
115        if !vm && !mask_bit(&mask_buf, i) {
116            continue;
117        }
118        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);
119        for f in 0..nf.fields_per_segment() {
120            let addr = elem_base.wrapping_add(u64::from(f * elem_bytes));
121            // SAFETY: Guaranteed by function contract
122            let field_base_reg =
123                unsafe { VReg::from_bits(vs3.to_bits() + f * group_regs).unwrap_unchecked() };
124            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
125            //
126            // Let `elems_per_reg = VLENB / elem_bytes`.
127            // `i < vl <= group_regs * elems_per_reg` (precondition), so
128            // `i / elems_per_reg < group_regs`.
129            //
130            // `field_base_reg = vs3.to_bits() + f * group_regs`. Since `f < nf` and the
131            // precondition guarantees `vs3.to_bits() + nf * group_regs <= 32`:
132            // `field_base_reg + group_regs <= vs3.to_bits() + (f+1) * group_regs
133            //                             <= vs3.to_bits() + nf * group_regs <= 32`.
134            //
135            // Therefore,
136            // `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
137            let data =
138                unsafe { read_group_element(ext_state.read_vregs(), field_base_reg, i, eew) };
139            // Record the current element index in `vstart` so that, on a memory fault, the failing
140            // element can be identified and the operation can be restarted
141            if let Err(error) = write_mem_element(memory, addr, eew, data) {
142                cold_path();
143                ext_state.set_vstart(i as u16);
144                return Err(ExecutionError::MemoryAccess(error));
145            }
146        }
147    }
148    ext_state.reset_vstart();
149    Ok(())
150}
151
152/// Execute a strided or strided-segment store.
153///
154/// The address of element `i`, field `f` is:
155///   `base.wrapping_add(i.wrapping_mul(stride) as u64).wrapping_add(f * eew.bytes())`
156///
157/// `stride` is the raw XLEN register value reinterpreted as a signed integer, matching the RVV
158/// specification where the stride operand is a two's-complement signed offset.
159///
160/// # Safety
161/// - `vs3.to_bits() % group_regs == 0`
162/// - `vs3.to_bits() + nf * group_regs <= 32`
163/// - `vl <= group_regs * VLENB / eew.bytes()`
164/// - When `vm=false`: `vs3.to_bits() != 0`
165#[inline(always)]
166#[expect(clippy::too_many_arguments, reason = "Internal API")]
167#[doc(hidden)]
168pub unsafe fn execute_strided_store<Reg, ExtState, Memory, CustomError>(
169    ext_state: &mut ExtState,
170    memory: &mut Memory,
171    vs3: VReg,
172    vm: bool,
173    base: u64,
174    stride: i64,
175    eew: Eew,
176    group_regs: u8,
177    nf: Nf,
178) -> Result<(), ExecutionError<Reg::Type, CustomError>>
179where
180    Reg: Register,
181    ExtState: VectorRegistersExt<Reg, CustomError>,
182    [(); ExtState::ELEN as usize]:,
183    [(); ExtState::VLEN as usize]:,
184    [(); ExtState::VLENB as usize]:,
185    Memory: VirtualMemory,
186    CustomError: fmt::Debug,
187{
188    let vl = ext_state.vl();
189    let vstart = ext_state.vstart();
190    let elem_bytes = eew.bytes_width();
191    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
192    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
193    for i in u32::from(vstart)..vl {
194        if !vm && !mask_bit(&mask_buf, i) {
195            continue;
196        }
197        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());
198        for f in 0..nf.fields_per_segment() {
199            let addr = elem_base.wrapping_add(u64::from(f * elem_bytes));
200            // SAFETY: Guaranteed by function contract
201            let field_base_reg =
202                unsafe { VReg::from_bits(vs3.to_bits() + f * group_regs).unwrap_unchecked() };
203            // SAFETY: same argument as `execute_unit_stride_store`; `field_base_reg +
204            // i / elems_per_reg < field_base_reg + group_regs <= vs3.to_bits() + nf *
205            // group_regs <= 32`.
206            let data =
207                unsafe { read_group_element(ext_state.read_vregs(), field_base_reg, i, eew) };
208            // Record the current element index in `vstart` so that, on a memory fault, the failing
209            // element can be identified and the operation can be restarted
210            if let Err(error) = write_mem_element(memory, addr, eew, data) {
211                cold_path();
212                ext_state.set_vstart(i as u16);
213                return Err(ExecutionError::MemoryAccess(error));
214            }
215        }
216    }
217    ext_state.reset_vstart();
218    Ok(())
219}
220
221/// Execute an indexed (unordered or ordered) store or indexed-segment store.
222///
223/// The effective address of element `i`, field `f` is:
224///   `base + index[i] + f * eew.bytes()`
225/// where `index[i]` is element `i` of the index register group `vs2`, interpreted as an
226/// unsigned integer of width `index_eew`.
227///
228/// `data_eew` is the element width of the data being stored (from `vtype.vsew`).
229/// `index_eew` is the element width of the indices (from the instruction encoding).
230///
231/// # Safety
232/// - `vs3.to_bits() % data_group_regs == 0`
233/// - `vs3.to_bits() + nf * data_group_regs <= 32`
234/// - `vs2` register group is aligned and fits within `[0, 32)` (caller must verify via
235///   `check_register_group_alignment` before calling)
236/// - `vl <= data_group_regs * VLENB / data_eew.bytes()`
237/// - `vl <= index_group_regs * VLENB / index_eew.bytes()` (caller must verify)
238/// - When `vm=false`: `vs3.to_bits() != 0`
239#[inline(always)]
240#[expect(clippy::too_many_arguments, reason = "Internal API")]
241#[doc(hidden)]
242pub unsafe fn execute_indexed_store<Reg, ExtState, Memory, CustomError>(
243    ext_state: &mut ExtState,
244    memory: &mut Memory,
245    vs3: VReg,
246    vs2: VReg,
247    vm: bool,
248    base: u64,
249    data_eew: Eew,
250    index_eew: Eew,
251    data_group_regs: u8,
252    nf: Nf,
253) -> Result<(), ExecutionError<Reg::Type, CustomError>>
254where
255    Reg: Register,
256    ExtState: VectorRegistersExt<Reg, CustomError>,
257    [(); ExtState::ELEN as usize]:,
258    [(); ExtState::VLEN as usize]:,
259    [(); ExtState::VLENB as usize]:,
260    Memory: VirtualMemory,
261    CustomError: fmt::Debug,
262{
263    let vl = ext_state.vl();
264    let vstart = ext_state.vstart();
265    let data_elem_bytes = data_eew.bytes_width();
266    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
267    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
268    for i in u32::from(vstart)..vl {
269        if !vm && !mask_bit(&mask_buf, i) {
270            continue;
271        }
272        // SAFETY: `i < vl <= index_group_regs * VLENB / index_eew.bytes()` (precondition), so
273        // `vs2.to_bits() + i / (VLENB / index_eew.bytes()) <
274        //     vs2.to_bits() + index_group_regs <= 32`
275        let index_buf = unsafe { read_group_element(ext_state.read_vregs(), vs2, i, index_eew) };
276        // SAFETY: `index_eew.bytes() <= Eew::MAX_BYTES` always holds.
277        let offset = unsafe { index_buf_to_u64(index_buf, index_eew) };
278        let elem_base = base.wrapping_add(offset);
279        for f in 0..nf.fields_per_segment() {
280            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
281            // SAFETY: Guaranteed by function contract
282            let field_base_reg =
283                unsafe { VReg::from_bits(vs3.to_bits() + f * data_group_regs).unwrap_unchecked() };
284            // SAFETY: `i < vl <= data_group_regs * VLENB / data_eew.bytes()` (precondition), so
285            // `field_base_reg + i / elems_per_reg < field_base_reg + data_group_regs
286            //                                    <= vs3.to_bits() + nf * data_group_regs <= 32`.
287            let data =
288                unsafe { read_group_element(ext_state.read_vregs(), field_base_reg, i, data_eew) };
289            // Record the current element index in `vstart` so that, on a memory fault, the failing
290            // element can be identified and the operation can be restarted
291            if let Err(error) = write_mem_element(memory, addr, data_eew, data) {
292                cold_path();
293                ext_state.set_vstart(i as u16);
294                return Err(ExecutionError::MemoryAccess(error));
295            }
296        }
297    }
298    ext_state.reset_vstart();
299    Ok(())
300}