Skip to main content

ab_riscv_interpreter/v/zvexx/store/
zvexx_store_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zvexx::load::zvexx_load_helpers::{
5    check_register_group_alignment, mask_bit, read_group_element, snapshot_mask,
6};
7use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
8use crate::{ExecutionError, ProgramCounter, VirtualMemory, VirtualMemoryError};
9use ab_riscv_primitives::prelude::*;
10use core::fmt;
11
12/// Interpret `buf[..index_eew.bytes()]` as a little-endian unsigned integer and return it as
13/// `u64`. Used to convert a packed index element into a byte offset.
14///
15/// # Safety
16/// `index_eew.bytes() <= Eew::MAX_BYTES`, which is always true by construction.
17#[inline(always)]
18unsafe fn index_buf_to_u64(buf: [u8; Eew::MAX_BYTES as usize], index_eew: Eew) -> u64 {
19    match index_eew {
20        Eew::E8 => u64::from(buf[0]),
21        Eew::E16 => u64::from(u16::from_le_bytes([buf[0], buf[1]])),
22        Eew::E32 => u64::from(u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]])),
23        Eew::E64 => u64::from_le_bytes(buf),
24    }
25}
26
27/// Write `eew`-sized data from `buf[..eew.bytes()]` to memory at `addr` (little-endian)
28#[inline(always)]
29fn write_mem_element(
30    memory: &mut impl VirtualMemory,
31    addr: u64,
32    eew: Eew,
33    buf: [u8; Eew::MAX_BYTES as usize],
34) -> Result<(), VirtualMemoryError> {
35    memory.write_slice(addr, &buf[..usize::from(eew.bytes_width())])
36}
37
38/// Validate a segment store's destination register group.
39///
40/// Like [`validate_segment_registers`] but omits the v0-overlap check, since
41/// segment stores read `vs3` as a source and the source/v0 overlap restriction
42/// applies only to load destinations.
43#[inline(always)]
44#[doc(hidden)]
45pub fn validate_segment_store_registers<Reg, Memory, PC, CustomError>(
46    program_counter: &PC,
47    vs3: VReg,
48    group_regs: u8,
49    nf: Nf,
50) -> Result<(), ExecutionError<Reg::Type, CustomError>>
51where
52    Reg: Register,
53    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
54{
55    check_register_group_alignment::<Reg, _, _, _>(program_counter, vs3, group_regs)?;
56    let total =
57        u32::from(vs3.to_bits()) + u32::from(nf.fields_per_segment()) * u32::from(group_regs);
58    if total > 32 {
59        return Err(ExecutionError::IllegalInstruction {
60            address: program_counter.old_pc(INSTRUCTION_SIZE),
61        });
62    }
63    Ok(())
64}
65
66/// Execute a unit-stride or unit-stride segment store.
67///
68/// Segment stride between elements is `nf * eew.bytes()`. Field `f` for element `i` is at
69/// `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1` this degenerates to a
70/// plain unit-stride store.
71///
72/// # Safety
73/// - `vs3.to_bits() % group_regs == 0`
74/// - `vs3.to_bits() + nf * group_regs <= 32`
75/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the source register
76///   group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL register
77///   count for the given `eew` and `vtype`)
78/// - When `vm=false`: `vs3` does not overlap `v0` (i.e. `vs3.to_bits() != 0`)
79#[inline(always)]
80#[expect(clippy::too_many_arguments, reason = "Internal API")]
81#[doc(hidden)]
82pub unsafe fn execute_unit_stride_store<Reg, ExtState, Memory, CustomError>(
83    ext_state: &mut ExtState,
84    memory: &mut Memory,
85    vs3: VReg,
86    vm: bool,
87    base: u64,
88    eew: Eew,
89    group_regs: u8,
90    nf: Nf,
91) -> Result<(), ExecutionError<Reg::Type, CustomError>>
92where
93    Reg: Register,
94    ExtState: VectorRegistersExt<Reg, CustomError>,
95    [(); ExtState::ELEN as usize]:,
96    [(); ExtState::VLEN as usize]:,
97    [(); ExtState::VLENB as usize]:,
98    Memory: VirtualMemory,
99    CustomError: fmt::Debug,
100{
101    let vl = ext_state.vl();
102    let vstart = ext_state.vstart();
103    let elem_bytes = eew.bytes_width();
104    let segment_stride = u64::from(nf.fields_per_segment() * elem_bytes);
105    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
106    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
107    for i in u32::from(vstart)..vl {
108        if !vm && !mask_bit(&mask_buf, i) {
109            continue;
110        }
111        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);
112        for f in 0..nf.fields_per_segment() {
113            let addr = elem_base.wrapping_add(u64::from(f * elem_bytes));
114            // SAFETY: Guaranteed by function contract
115            let field_base_reg =
116                unsafe { VReg::from_bits(vs3.to_bits() + f * group_regs).unwrap_unchecked() };
117            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
118            //
119            // Let `elems_per_reg = VLENB / elem_bytes`.
120            // `i < vl <= group_regs * elems_per_reg` (precondition), so
121            // `i / elems_per_reg < group_regs`.
122            //
123            // `field_base_reg = vs3.to_bits() + f * group_regs`. Since `f < nf` and the
124            // precondition guarantees `vs3.to_bits() + nf * group_regs <= 32`:
125            // `field_base_reg + group_regs <= vs3.to_bits() + (f+1) * group_regs
126            //                             <= vs3.to_bits() + nf * group_regs <= 32`.
127            //
128            // Therefore,
129            // `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
130            let data =
131                unsafe { read_group_element(ext_state.read_vregs(), field_base_reg, i, eew) };
132            // Record the current element index in `vstart` so that, on a memory fault, the failing
133            // element can be identified and the operation can be restarted
134            if let Err(error) = write_mem_element(memory, addr, eew, data) {
135                ext_state.set_vstart(i as u16);
136                return Err(ExecutionError::MemoryAccess(error));
137            }
138        }
139    }
140    ext_state.reset_vstart();
141    Ok(())
142}
143
144/// Execute a strided or strided-segment store.
145///
146/// The address of element `i`, field `f` is:
147///   `base.wrapping_add(i.wrapping_mul(stride) as u64).wrapping_add(f * eew.bytes())`
148///
149/// `stride` is the raw XLEN register value reinterpreted as a signed integer, matching the RVV
150/// specification where the stride operand is a two's-complement signed offset.
151///
152/// # Safety
153/// - `vs3.to_bits() % group_regs == 0`
154/// - `vs3.to_bits() + nf * group_regs <= 32`
155/// - `vl <= group_regs * VLENB / eew.bytes()`
156/// - When `vm=false`: `vs3.to_bits() != 0`
157#[inline(always)]
158#[expect(clippy::too_many_arguments, reason = "Internal API")]
159#[doc(hidden)]
160pub unsafe fn execute_strided_store<Reg, ExtState, Memory, CustomError>(
161    ext_state: &mut ExtState,
162    memory: &mut Memory,
163    vs3: VReg,
164    vm: bool,
165    base: u64,
166    stride: i64,
167    eew: Eew,
168    group_regs: u8,
169    nf: Nf,
170) -> Result<(), ExecutionError<Reg::Type, CustomError>>
171where
172    Reg: Register,
173    ExtState: VectorRegistersExt<Reg, CustomError>,
174    [(); ExtState::ELEN as usize]:,
175    [(); ExtState::VLEN as usize]:,
176    [(); ExtState::VLENB as usize]:,
177    Memory: VirtualMemory,
178    CustomError: fmt::Debug,
179{
180    let vl = ext_state.vl();
181    let vstart = ext_state.vstart();
182    let elem_bytes = eew.bytes_width();
183    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
184    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
185    for i in u32::from(vstart)..vl {
186        if !vm && !mask_bit(&mask_buf, i) {
187            continue;
188        }
189        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());
190        for f in 0..nf.fields_per_segment() {
191            let addr = elem_base.wrapping_add(u64::from(f * elem_bytes));
192            // SAFETY: Guaranteed by function contract
193            let field_base_reg =
194                unsafe { VReg::from_bits(vs3.to_bits() + f * group_regs).unwrap_unchecked() };
195            // SAFETY: same argument as `execute_unit_stride_store`; `field_base_reg +
196            // i / elems_per_reg < field_base_reg + group_regs <= vs3.to_bits() + nf *
197            // group_regs <= 32`.
198            let data =
199                unsafe { read_group_element(ext_state.read_vregs(), field_base_reg, i, eew) };
200            // Record the current element index in `vstart` so that, on a memory fault, the failing
201            // element can be identified and the operation can be restarted
202            if let Err(error) = write_mem_element(memory, addr, eew, data) {
203                ext_state.set_vstart(i as u16);
204                return Err(ExecutionError::MemoryAccess(error));
205            }
206        }
207    }
208    ext_state.reset_vstart();
209    Ok(())
210}
211
212/// Execute an indexed (unordered or ordered) store or indexed-segment store.
213///
214/// The effective address of element `i`, field `f` is:
215///   `base + index[i] + f * eew.bytes()`
216/// where `index[i]` is element `i` of the index register group `vs2`, interpreted as an
217/// unsigned integer of width `index_eew`.
218///
219/// `data_eew` is the element width of the data being stored (from `vtype.vsew`).
220/// `index_eew` is the element width of the indices (from the instruction encoding).
221///
222/// # Safety
223/// - `vs3.to_bits() % data_group_regs == 0`
224/// - `vs3.to_bits() + nf * data_group_regs <= 32`
225/// - `vs2` register group is aligned and fits within `[0, 32)` (caller must verify via
226///   `check_register_group_alignment` before calling)
227/// - `vl <= data_group_regs * VLENB / data_eew.bytes()`
228/// - `vl <= index_group_regs * VLENB / index_eew.bytes()` (caller must verify)
229/// - When `vm=false`: `vs3.to_bits() != 0`
230#[inline(always)]
231#[expect(clippy::too_many_arguments, reason = "Internal API")]
232#[doc(hidden)]
233pub unsafe fn execute_indexed_store<Reg, ExtState, Memory, CustomError>(
234    ext_state: &mut ExtState,
235    memory: &mut Memory,
236    vs3: VReg,
237    vs2: VReg,
238    vm: bool,
239    base: u64,
240    data_eew: Eew,
241    index_eew: Eew,
242    data_group_regs: u8,
243    nf: Nf,
244) -> Result<(), ExecutionError<Reg::Type, CustomError>>
245where
246    Reg: Register,
247    ExtState: VectorRegistersExt<Reg, CustomError>,
248    [(); ExtState::ELEN as usize]:,
249    [(); ExtState::VLEN as usize]:,
250    [(); ExtState::VLENB as usize]:,
251    Memory: VirtualMemory,
252    CustomError: fmt::Debug,
253{
254    let vl = ext_state.vl();
255    let vstart = ext_state.vstart();
256    let data_elem_bytes = data_eew.bytes_width();
257    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
258    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
259    for i in u32::from(vstart)..vl {
260        if !vm && !mask_bit(&mask_buf, i) {
261            continue;
262        }
263        // SAFETY: `i < vl <= index_group_regs * VLENB / index_eew.bytes()` (precondition), so
264        // `vs2.to_bits() + i / (VLENB / index_eew.bytes()) <
265        //     vs2.to_bits() + index_group_regs <= 32`
266        let index_buf = unsafe { read_group_element(ext_state.read_vregs(), vs2, i, index_eew) };
267        // SAFETY: `index_eew.bytes() <= Eew::MAX_BYTES` always holds.
268        let offset = unsafe { index_buf_to_u64(index_buf, index_eew) };
269        let elem_base = base.wrapping_add(offset);
270        for f in 0..nf.fields_per_segment() {
271            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
272            // SAFETY: Guaranteed by function contract
273            let field_base_reg =
274                unsafe { VReg::from_bits(vs3.to_bits() + f * data_group_regs).unwrap_unchecked() };
275            // SAFETY: `i < vl <= data_group_regs * VLENB / data_eew.bytes()` (precondition), so
276            // `field_base_reg + i / elems_per_reg < field_base_reg + data_group_regs
277            //                                    <= vs3.to_bits() + nf * data_group_regs <= 32`.
278            let data =
279                unsafe { read_group_element(ext_state.read_vregs(), field_base_reg, i, data_eew) };
280            // Record the current element index in `vstart` so that, on a memory fault, the failing
281            // element can be identified and the operation can be restarted
282            if let Err(error) = write_mem_element(memory, addr, data_eew, data) {
283                ext_state.set_vstart(i as u16);
284                return Err(ExecutionError::MemoryAccess(error));
285            }
286        }
287    }
288    ext_state.reset_vstart();
289    Ok(())
290}