Skip to main content

ab_riscv_interpreter/v/zve64x/store/
zve64x_store_helpers.rs

1//! Opaque helpers for Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, read_group_element, snapshot_mask};
5use crate::{ExecutionError, InterpreterState, ProgramCounter, VirtualMemory, VirtualMemoryError};
6use ab_riscv_primitives::instructions::v::Eew;
7use ab_riscv_primitives::registers::general_purpose::Register;
8use ab_riscv_primitives::registers::vector::VReg;
9use core::fmt;
10
11/// Interpret `buf[..index_eew.bytes()]` as a little-endian unsigned integer and return it as
12/// `u64`. Used to convert a packed index element into a byte offset.
13///
14/// # Safety
15/// `index_eew.bytes() <= Eew::MAX_BYTES`, which is always true by construction.
16#[inline(always)]
17unsafe fn index_buf_to_u64(buf: [u8; Eew::MAX_BYTES as usize], index_eew: Eew) -> u64 {
18    match index_eew {
19        Eew::E8 => u64::from(buf[0]),
20        Eew::E16 => u64::from(u16::from_le_bytes([buf[0], buf[1]])),
21        Eew::E32 => u64::from(u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]])),
22        Eew::E64 => u64::from_le_bytes(buf),
23    }
24}
25
26/// Write `eew`-sized data from `buf[..eew.bytes()]` to memory at `addr` (little-endian)
27#[inline(always)]
28fn write_mem_element(
29    memory: &mut impl VirtualMemory,
30    addr: u64,
31    eew: Eew,
32    buf: [u8; Eew::MAX_BYTES as usize],
33) -> Result<(), VirtualMemoryError> {
34    memory.write_slice(addr, &buf[..usize::from(eew.bytes())])
35}
36
37/// Execute a unit-stride or unit-stride segment store.
38///
39/// Segment stride between elements is `nf * eew.bytes()`. Field `f` for element `i` is at
40/// `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1` this degenerates to a
41/// plain unit-stride store.
42///
43/// # Safety
44/// - `vs3.bits() % group_regs == 0`
45/// - `vs3.bits() + nf * group_regs <= 32`
46/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the source register
47///   group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL register
48///   count for the given `eew` and `vtype`)
49/// - When `vm=false`: `vs3` does not overlap `v0` (i.e. `vs3.bits() != 0`)
50#[inline(always)]
51#[expect(clippy::too_many_arguments, reason = "Internal API")]
52#[doc(hidden)]
53pub unsafe fn execute_unit_stride_store<Reg, ExtState, Memory, PC, IH, CustomError>(
54    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
55    vs3: VReg,
56    vm: bool,
57    vl: u32,
58    vstart: u16,
59    base: u64,
60    eew: Eew,
61    group_regs: u8,
62    nf: u8,
63) -> Result<(), ExecutionError<Reg::Type, CustomError>>
64where
65    Reg: Register,
66    [(); Reg::N]:,
67    ExtState: VectorRegistersExt<Reg, CustomError>,
68    [(); ExtState::ELEN as usize]:,
69    [(); ExtState::VLEN as usize]:,
70    [(); ExtState::VLENB as usize]:,
71    Memory: VirtualMemory,
72    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
73    CustomError: fmt::Debug,
74{
75    let elem_bytes = eew.bytes();
76    let segment_stride = u64::from(nf * elem_bytes);
77    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
78    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
79    for i in u32::from(vstart)..vl {
80        if !vm && !mask_bit(&mask_buf, i) {
81            continue;
82        }
83        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);
84        for f in 0..nf {
85            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
86            let field_base_reg = vs3.bits() + f * group_regs;
87            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
88            //
89            // Let `elems_per_reg = VLENB / elem_bytes`.
90            // `i < vl <= group_regs * elems_per_reg` (precondition), so
91            // `i / elems_per_reg < group_regs`.
92            //
93            // `field_base_reg = vs3.bits() + f * group_regs`. Since `f < nf` and the
94            // precondition guarantees `vs3.bits() + nf * group_regs <= 32`:
95            // `field_base_reg + group_regs <= vs3.bits() + (f+1) * group_regs
96            //                             <= vs3.bits() + nf * group_regs <= 32`.
97            //
98            // Therefore,
99            // `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
100            let data = unsafe {
101                read_group_element(
102                    state.ext_state.read_vreg(),
103                    usize::from(field_base_reg),
104                    i,
105                    eew,
106                )
107            };
108            // Record the current element index in `vstart` so that, on a memory fault, the failing
109            // element can be identified and the operation can be restarted
110            if let Err(error) = write_mem_element(&mut state.memory, addr, eew, data) {
111                state.ext_state.set_vstart(i as u16);
112                return Err(ExecutionError::MemoryAccess(error));
113            }
114        }
115    }
116    state.ext_state.reset_vstart();
117    Ok(())
118}
119
120/// Execute a strided or strided-segment store.
121///
122/// The address of element `i`, field `f` is:
123///   `base.wrapping_add(i.wrapping_mul(stride) as u64).wrapping_add(f * eew.bytes())`
124///
125/// `stride` is the raw XLEN register value reinterpreted as a signed integer, matching the RVV
126/// specification where the stride operand is a two's-complement signed offset.
127///
128/// # Safety
129/// - `vs3.bits() % group_regs == 0`
130/// - `vs3.bits() + nf * group_regs <= 32`
131/// - `vl <= group_regs * VLENB / eew.bytes()`
132/// - When `vm=false`: `vs3.bits() != 0`
133#[inline(always)]
134#[expect(clippy::too_many_arguments, reason = "Internal API")]
135#[doc(hidden)]
136pub unsafe fn execute_strided_store<Reg, ExtState, Memory, PC, IH, CustomError>(
137    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
138    vs3: VReg,
139    vm: bool,
140    vl: u32,
141    vstart: u16,
142    base: u64,
143    stride: i64,
144    eew: Eew,
145    group_regs: u8,
146    nf: u8,
147) -> Result<(), ExecutionError<Reg::Type, CustomError>>
148where
149    Reg: Register,
150    [(); Reg::N]:,
151    ExtState: VectorRegistersExt<Reg, CustomError>,
152    [(); ExtState::ELEN as usize]:,
153    [(); ExtState::VLEN as usize]:,
154    [(); ExtState::VLENB as usize]:,
155    Memory: VirtualMemory,
156    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
157    CustomError: fmt::Debug,
158{
159    let elem_bytes = eew.bytes();
160    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
161    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
162    for i in u32::from(vstart)..vl {
163        if !vm && !mask_bit(&mask_buf, i) {
164            continue;
165        }
166        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());
167        for f in 0..nf {
168            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
169            let field_base_reg = vs3.bits() + f * group_regs;
170            // SAFETY: same argument as `execute_unit_stride_store`; `field_base_reg +
171            // i / elems_per_reg < field_base_reg + group_regs <= vs3.bits() + nf *
172            // group_regs <= 32`.
173            let data = unsafe {
174                read_group_element(
175                    state.ext_state.read_vreg(),
176                    usize::from(field_base_reg),
177                    i,
178                    eew,
179                )
180            };
181            // Record the current element index in `vstart` so that, on a memory fault, the failing
182            // element can be identified and the operation can be restarted
183            if let Err(error) = write_mem_element(&mut state.memory, addr, eew, data) {
184                state.ext_state.set_vstart(i as u16);
185                return Err(ExecutionError::MemoryAccess(error));
186            }
187        }
188    }
189    state.ext_state.reset_vstart();
190    Ok(())
191}
192
193/// Execute an indexed (unordered or ordered) store or indexed-segment store.
194///
195/// The effective address of element `i`, field `f` is:
196///   `base + index[i] + f * eew.bytes()`
197/// where `index[i]` is element `i` of the index register group `vs2`, interpreted as an
198/// unsigned integer of width `index_eew`.
199///
200/// `data_eew` is the element width of the data being stored (from `vtype.vsew`).
201/// `index_eew` is the element width of the indices (from the instruction encoding).
202///
203/// # Safety
204/// - `vs3.bits() % data_group_regs == 0`
205/// - `vs3.bits() + nf * data_group_regs <= 32`
206/// - `vs2` register group is aligned and fits within `[0, 32)` (caller must verify via
207///   `check_register_group_alignment` before calling)
208/// - `vl <= data_group_regs * VLENB / data_eew.bytes()`
209/// - `vl <= index_group_regs * VLENB / index_eew.bytes()` (caller must verify)
210/// - When `vm=false`: `vs3.bits() != 0`
211#[inline(always)]
212#[expect(clippy::too_many_arguments, reason = "Internal API")]
213#[doc(hidden)]
214pub unsafe fn execute_indexed_store<Reg, ExtState, Memory, PC, IH, CustomError>(
215    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
216    vs3: VReg,
217    vs2: VReg,
218    vm: bool,
219    vl: u32,
220    vstart: u32,
221    base: u64,
222    data_eew: Eew,
223    index_eew: Eew,
224    data_group_regs: u8,
225    nf: u8,
226) -> Result<(), ExecutionError<Reg::Type, CustomError>>
227where
228    Reg: Register,
229    [(); Reg::N]:,
230    ExtState: VectorRegistersExt<Reg, CustomError>,
231    [(); ExtState::ELEN as usize]:,
232    [(); ExtState::VLEN as usize]:,
233    [(); ExtState::VLENB as usize]:,
234    Memory: VirtualMemory,
235    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
236    CustomError: fmt::Debug,
237{
238    let data_elem_bytes = data_eew.bytes();
239    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
240    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
241    for i in vstart..vl {
242        if !vm && !mask_bit(&mask_buf, i) {
243            continue;
244        }
245        // SAFETY: `i < vl <= index_group_regs * VLENB / index_eew.bytes()` (precondition), so
246        // `vs2.bits() + i / (VLENB / index_eew.bytes()) < vs2.bits() + index_group_regs <= 32`.
247        let index_buf = unsafe {
248            read_group_element(
249                state.ext_state.read_vreg(),
250                usize::from(vs2.bits()),
251                i,
252                index_eew,
253            )
254        };
255        // SAFETY: `index_eew.bytes() <= Eew::MAX_BYTES` always holds.
256        let offset = unsafe { index_buf_to_u64(index_buf, index_eew) };
257        let elem_base = base.wrapping_add(offset);
258        for f in 0..nf {
259            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
260            let field_base_reg = vs3.bits() + f * data_group_regs;
261            // SAFETY: `i < vl <= data_group_regs * VLENB / data_eew.bytes()` (precondition), so
262            // `field_base_reg + i / elems_per_reg < field_base_reg + data_group_regs
263            //                                    <= vs3.bits() + nf * data_group_regs <= 32`.
264            let data = unsafe {
265                read_group_element(
266                    state.ext_state.read_vreg(),
267                    usize::from(field_base_reg),
268                    i,
269                    data_eew,
270                )
271            };
272            // Record the current element index in `vstart` so that, on a memory fault, the failing
273            // element can be identified and the operation can be restarted
274            if let Err(error) = write_mem_element(&mut state.memory, addr, data_eew, data) {
275                state.ext_state.set_vstart(i as u16);
276                return Err(ExecutionError::MemoryAccess(error));
277            }
278        }
279    }
280    state.ext_state.reset_vstart();
281    Ok(())
282}