ab_riscv_interpreter/v/zve64x/store/zve64x_store_helpers.rs
//! Opaque store helpers for the Zve64x extension

use crate::v::vector_registers::VectorRegistersExt;
use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, read_group_element, snapshot_mask};
use crate::{ExecutionError, InterpreterState, ProgramCounter, VirtualMemory, VirtualMemoryError};
use ab_riscv_primitives::instructions::v::Eew;
use ab_riscv_primitives::registers::general_purpose::Register;
use ab_riscv_primitives::registers::vector::VReg;
use core::fmt;

/// Interpret `buf[..index_eew.bytes()]` as a little-endian unsigned integer and return it as
/// `u64`. Used to convert a packed index element into a byte offset.
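///
/// For example (a sketch; the values are chosen purely for illustration, and `Eew::MAX_BYTES`
/// is 8, as the `E64` arm requires), a 16-bit index packed little-endian as `[0x04, 0x02]`
/// decodes to the byte offset `0x0204 == 516`:
///
/// ```text
/// let buf = [0x04, 0x02, 0, 0, 0, 0, 0, 0];
/// // SAFETY: `Eew::E16.bytes() == 2 <= Eew::MAX_BYTES`
/// assert_eq!(unsafe { index_buf_to_u64(buf, Eew::E16) }, 0x0204);
/// ```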
///
/// # Safety
/// `index_eew.bytes() <= Eew::MAX_BYTES`, which is always true by construction.
#[inline(always)]
unsafe fn index_buf_to_u64(buf: [u8; Eew::MAX_BYTES as usize], index_eew: Eew) -> u64 {
    match index_eew {
        Eew::E8 => u64::from(buf[0]),
        Eew::E16 => u64::from(u16::from_le_bytes([buf[0], buf[1]])),
        Eew::E32 => u64::from(u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]])),
        Eew::E64 => u64::from_le_bytes(buf),
    }
}

/// Write `eew`-sized data from `buf[..eew.bytes()]` to memory at `addr` (little-endian).
#[inline(always)]
fn write_mem_element(
    memory: &mut impl VirtualMemory,
    addr: u64,
    eew: Eew,
    buf: [u8; Eew::MAX_BYTES as usize],
) -> Result<(), VirtualMemoryError> {
    memory.write_slice(addr, &buf[..usize::from(eew.bytes())])
}

/// Execute a unit-stride or unit-stride segment store.
///
/// The segment stride between consecutive elements is `nf * eew.bytes()`, so field `f` of
/// element `i` is stored at `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1`
/// this degenerates to a plain unit-stride store.
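///
/// For example (an illustrative layout only), with `nf = 3` and `eew = Eew::E16` the segment
/// stride is `3 * 2 = 6` bytes, and element `i = 2`, field `f = 1` lands at
/// `base + 2 * 6 + 1 * 2 = base + 14`:
///
/// ```text
/// base + 0:  segment 0 -> [f0][f1][f2]
/// base + 6:  segment 1 -> [f0][f1][f2]
/// base + 12: segment 2 -> [f0][f1][f2]   (f1 of segment 2 starts at base + 14)
/// ```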
///
/// # Safety
/// - `vs3.bits() % group_regs == 0`
/// - `vs3.bits() + nf * group_regs <= 32`
/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the source register
///   group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL register
///   count for the given `eew` and `vtype`)
/// - When `vm=false`: `vs3` does not overlap `v0` (i.e. `vs3.bits() != 0`)
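///
/// For example (an illustrative configuration only), `vs3 = v8` with `group_regs = 2` and
/// `nf = 3` satisfies these preconditions: `8 % 2 == 0` and `8 + 3 * 2 = 14 <= 32`, with the
/// three fields occupying `v8`-`v9`, `v10`-`v11`, and `v12`-`v13`.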
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_unit_stride_store<Reg, ExtState, Memory, PC, IH, CustomError>(
    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    vs3: VReg,
    vm: bool,
    vl: u32,
    vstart: u16,
    base: u64,
    eew: Eew,
    group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    [(); Reg::N]:,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
    CustomError: fmt::Debug,
{
    let elem_bytes = eew.bytes();
    let segment_stride = u64::from(nf * elem_bytes);
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
    for i in u32::from(vstart)..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }
        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);
        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
            let field_base_reg = vs3.bits() + f * group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
            //
            // Let `elems_per_reg = VLENB / elem_bytes`.
            // `i < vl <= group_regs * elems_per_reg` (precondition), so
            // `i / elems_per_reg < group_regs`.
            //
            // `field_base_reg = vs3.bits() + f * group_regs`. Since `f < nf` and the
            // precondition guarantees `vs3.bits() + nf * group_regs <= 32`:
            // `field_base_reg + group_regs <= vs3.bits() + (f + 1) * group_regs
            //     <= vs3.bits() + nf * group_regs <= 32`.
            //
            // Therefore,
            // `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
            let data = unsafe {
                read_group_element(
                    state.ext_state.read_vreg(),
                    usize::from(field_base_reg),
                    i,
                    eew,
                )
            };
            // Record the current element index in `vstart` so that, on a memory fault, the
            // failing element can be identified and the operation restarted from that element.
            if let Err(error) = write_mem_element(&mut state.memory, addr, eew, data) {
                state.ext_state.set_vstart(i as u16);
                return Err(ExecutionError::MemoryAccess(error));
            }
        }
    }
    state.ext_state.reset_vstart();
    Ok(())
}

/// Execute a strided or strided-segment store.
///
/// The address of element `i`, field `f` is:
/// `base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned()).wrapping_add(f * eew.bytes())`
///
/// `stride` is the raw XLEN register value reinterpreted as a signed integer, matching the RVV
/// specification, where the stride operand is a two's-complement signed byte offset.
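///
/// For example (values chosen purely for illustration), a stride of `-4` with `base = 0x2000`
/// stores element `0` at `0x2000`, element `1` at `0x1FFC`, and element `2` at `0x1FF8`; the
/// wrapping arithmetic keeps this well-defined even though the product is negative:
///
/// ```text
/// // i = 2, stride = -4:
/// // i64::from(2).wrapping_mul(-4)                 == -8
/// // (-8i64).cast_unsigned()                       == 0xFFFF_FFFF_FFFF_FFF8
/// // 0x2000u64.wrapping_add(0xFFFF_FFFF_FFFF_FFF8) == 0x1FF8
/// ```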
///
/// # Safety
/// - `vs3.bits() % group_regs == 0`
/// - `vs3.bits() + nf * group_regs <= 32`
/// - `vl <= group_regs * VLENB / eew.bytes()`
/// - When `vm=false`: `vs3.bits() != 0`
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_strided_store<Reg, ExtState, Memory, PC, IH, CustomError>(
    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    vs3: VReg,
    vm: bool,
    vl: u32,
    vstart: u16,
    base: u64,
    stride: i64,
    eew: Eew,
    group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    [(); Reg::N]:,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
    CustomError: fmt::Debug,
{
    let elem_bytes = eew.bytes();
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
    for i in u32::from(vstart)..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }
        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());
        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
            let field_base_reg = vs3.bits() + f * group_regs;
            // SAFETY: same argument as `execute_unit_stride_store`; `field_base_reg +
            // i / elems_per_reg < field_base_reg + group_regs <= vs3.bits() + nf *
            // group_regs <= 32`.
            let data = unsafe {
                read_group_element(
                    state.ext_state.read_vreg(),
                    usize::from(field_base_reg),
                    i,
                    eew,
                )
            };
            // Record the current element index in `vstart` so that, on a memory fault, the
            // failing element can be identified and the operation restarted from that element.
            if let Err(error) = write_mem_element(&mut state.memory, addr, eew, data) {
                state.ext_state.set_vstart(i as u16);
                return Err(ExecutionError::MemoryAccess(error));
            }
        }
    }
    state.ext_state.reset_vstart();
    Ok(())
}

/// Execute an indexed (unordered or ordered) store or indexed-segment store.
///
/// The effective address of element `i`, field `f` is:
/// `base + index[i] + f * data_eew.bytes()`
/// where `index[i]` is element `i` of the index register group `vs2`, interpreted as an
/// unsigned integer of width `index_eew`.
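///
/// For example (a sketch with illustrative values), with `base = 0x1000`,
/// `index_eew = Eew::E16`, `data_eew = Eew::E32`, `nf = 2`, and `index[3] = 0x40`, element
/// `i = 3`, field `f = 1` is stored at `0x1000 + 0x40 + 1 * 4 = 0x1044`.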
///
/// `data_eew` is the element width of the data being stored (from `vtype.vsew`).
/// `index_eew` is the element width of the indices (from the instruction encoding).
///
/// # Safety
/// - `vs3.bits() % data_group_regs == 0`
/// - `vs3.bits() + nf * data_group_regs <= 32`
/// - `vs2` register group is aligned and fits within `[0, 32)` (caller must verify via
///   `check_register_group_alignment` before calling)
/// - `vl <= data_group_regs * VLENB / data_eew.bytes()`
/// - `vl <= index_group_regs * VLENB / index_eew.bytes()` (caller must verify)
/// - When `vm=false`: `vs3.bits() != 0`
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_indexed_store<Reg, ExtState, Memory, PC, IH, CustomError>(
    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    vs3: VReg,
    vs2: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    data_eew: Eew,
    index_eew: Eew,
    data_group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    [(); Reg::N]:,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
    CustomError: fmt::Debug,
{
    let data_elem_bytes = data_eew.bytes();
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: `i < vl <= index_group_regs * VLENB / index_eew.bytes()` (precondition), so
        // `vs2.bits() + i / (VLENB / index_eew.bytes()) < vs2.bits() + index_group_regs <= 32`.
        let index_buf = unsafe {
            read_group_element(
                state.ext_state.read_vreg(),
                usize::from(vs2.bits()),
                i,
                index_eew,
            )
        };
        // SAFETY: `index_eew.bytes() <= Eew::MAX_BYTES` always holds.
        let offset = unsafe { index_buf_to_u64(index_buf, index_eew) };
        let elem_base = base.wrapping_add(offset);
        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
            let field_base_reg = vs3.bits() + f * data_group_regs;
            // SAFETY: `i < vl <= data_group_regs * VLENB / data_eew.bytes()` (precondition), so
            // `field_base_reg + i / elems_per_reg < field_base_reg + data_group_regs
            //     <= vs3.bits() + nf * data_group_regs <= 32`.
            let data = unsafe {
                read_group_element(
                    state.ext_state.read_vreg(),
                    usize::from(field_base_reg),
                    i,
                    data_eew,
                )
            };
            // Record the current element index in `vstart` so that, on a memory fault, the
            // failing element can be identified and the operation restarted from that element.
            if let Err(error) = write_mem_element(&mut state.memory, addr, data_eew, data) {
                state.ext_state.set_vstart(i as u16);
                return Err(ExecutionError::MemoryAccess(error));
            }
        }
    }
    state.ext_state.reset_vstart();
    Ok(())
}