// ab_riscv_interpreter/v/zve64x/store/zve64x_store_helpers.rs
//! Opaque helpers for Zve64x extension

use crate::v::vector_registers::VectorRegistersExt;
use crate::v::zve64x::load::zve64x_load_helpers::{
    check_register_group_alignment, mask_bit, read_group_element, snapshot_mask,
};
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, ProgramCounter, VirtualMemory, VirtualMemoryError};
use ab_riscv_primitives::prelude::*;
use core::fmt;
12/// Interpret `buf[..index_eew.bytes()]` as a little-endian unsigned integer and return it as
13/// `u64`. Used to convert a packed index element into a byte offset.
14///
15/// # Safety
16/// `index_eew.bytes() <= Eew::MAX_BYTES`, which is always true by construction.
17#[inline(always)]
18unsafe fn index_buf_to_u64(buf: [u8; Eew::MAX_BYTES as usize], index_eew: Eew) -> u64 {
19    match index_eew {
20        Eew::E8 => u64::from(buf[0]),
21        Eew::E16 => u64::from(u16::from_le_bytes([buf[0], buf[1]])),
22        Eew::E32 => u64::from(u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]])),
23        Eew::E64 => u64::from_le_bytes(buf),
24    }
25}
26
27/// Write `eew`-sized data from `buf[..eew.bytes()]` to memory at `addr` (little-endian)
28#[inline(always)]
29fn write_mem_element(
30    memory: &mut impl VirtualMemory,
31    addr: u64,
32    eew: Eew,
33    buf: [u8; Eew::MAX_BYTES as usize],
34) -> Result<(), VirtualMemoryError> {
35    memory.write_slice(addr, &buf[..usize::from(eew.bytes())])
36}
37
38/// Validate a segment store's destination register group.
39///
40/// Like [`validate_segment_registers`] but omits the v0-overlap check, since
41/// segment stores read `vs3` as a source and the source/v0 overlap restriction
42/// applies only to load destinations.
43#[inline(always)]
44#[doc(hidden)]
45pub fn validate_segment_store_registers<Reg, Memory, PC, CustomError>(
46    program_counter: &PC,
47    vs3: VReg,
48    group_regs: u8,
49    nf: u8,
50) -> Result<(), ExecutionError<Reg::Type, CustomError>>
51where
52    Reg: Register,
53    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
54{
55    check_register_group_alignment::<Reg, _, _, _>(program_counter, vs3, group_regs)?;
56    let total = u32::from(vs3.bits()) + u32::from(nf) * u32::from(group_regs);
57    if total > 32 {
58        return Err(ExecutionError::IllegalInstruction {
59            address: program_counter.old_pc(INSTRUCTION_SIZE),
60        });
61    }
62    Ok(())
63}
64
65/// Execute a unit-stride or unit-stride segment store.
66///
67/// Segment stride between elements is `nf * eew.bytes()`. Field `f` for element `i` is at
68/// `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1` this degenerates to a
69/// plain unit-stride store.
70///
71/// # Safety
72/// - `vs3.bits() % group_regs == 0`
73/// - `vs3.bits() + nf * group_regs <= 32`
74/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the source register
75///   group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL register
76///   count for the given `eew` and `vtype`)
77/// - When `vm=false`: `vs3` does not overlap `v0` (i.e. `vs3.bits() != 0`)
78#[inline(always)]
79#[expect(clippy::too_many_arguments, reason = "Internal API")]
80#[doc(hidden)]
81pub unsafe fn execute_unit_stride_store<Reg, ExtState, Memory, CustomError>(
82    ext_state: &mut ExtState,
83    memory: &mut Memory,
84    vs3: VReg,
85    vm: bool,
86    vl: u32,
87    vstart: u16,
88    base: u64,
89    eew: Eew,
90    group_regs: u8,
91    nf: u8,
92) -> Result<(), ExecutionError<Reg::Type, CustomError>>
93where
94    Reg: Register,
95    ExtState: VectorRegistersExt<Reg, CustomError>,
96    [(); ExtState::ELEN as usize]:,
97    [(); ExtState::VLEN as usize]:,
98    [(); ExtState::VLENB as usize]:,
99    Memory: VirtualMemory,
100    CustomError: fmt::Debug,
101{
102    let elem_bytes = eew.bytes();
103    let segment_stride = u64::from(nf * elem_bytes);
104    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
105    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
106    for i in u32::from(vstart)..vl {
107        if !vm && !mask_bit(&mask_buf, i) {
108            continue;
109        }
110        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);
111        for f in 0..nf {
112            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
113            let field_base_reg = vs3.bits() + f * group_regs;
114            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
115            //
116            // Let `elems_per_reg = VLENB / elem_bytes`.
117            // `i < vl <= group_regs * elems_per_reg` (precondition), so
118            // `i / elems_per_reg < group_regs`.
119            //
120            // `field_base_reg = vs3.bits() + f * group_regs`. Since `f < nf` and the
121            // precondition guarantees `vs3.bits() + nf * group_regs <= 32`:
122            // `field_base_reg + group_regs <= vs3.bits() + (f+1) * group_regs
123            //                             <= vs3.bits() + nf * group_regs <= 32`.
124            //
125            // Therefore,
126            // `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
127            let data = unsafe {
128                read_group_element(ext_state.read_vreg(), usize::from(field_base_reg), i, eew)
129            };
130            // Record the current element index in `vstart` so that, on a memory fault, the failing
131            // element can be identified and the operation can be restarted
132            if let Err(error) = write_mem_element(memory, addr, eew, data) {
133                ext_state.set_vstart(i as u16);
134                return Err(ExecutionError::MemoryAccess(error));
135            }
136        }
137    }
138    ext_state.reset_vstart();
139    Ok(())
140}
141
142/// Execute a strided or strided-segment store.
143///
144/// The address of element `i`, field `f` is:
145///   `base.wrapping_add(i.wrapping_mul(stride) as u64).wrapping_add(f * eew.bytes())`
146///
147/// `stride` is the raw XLEN register value reinterpreted as a signed integer, matching the RVV
148/// specification where the stride operand is a two's-complement signed offset.
149///
150/// # Safety
151/// - `vs3.bits() % group_regs == 0`
152/// - `vs3.bits() + nf * group_regs <= 32`
153/// - `vl <= group_regs * VLENB / eew.bytes()`
154/// - When `vm=false`: `vs3.bits() != 0`
155#[inline(always)]
156#[expect(clippy::too_many_arguments, reason = "Internal API")]
157#[doc(hidden)]
158pub unsafe fn execute_strided_store<Reg, ExtState, Memory, CustomError>(
159    ext_state: &mut ExtState,
160    memory: &mut Memory,
161    vs3: VReg,
162    vm: bool,
163    vl: u32,
164    vstart: u16,
165    base: u64,
166    stride: i64,
167    eew: Eew,
168    group_regs: u8,
169    nf: u8,
170) -> Result<(), ExecutionError<Reg::Type, CustomError>>
171where
172    Reg: Register,
173    ExtState: VectorRegistersExt<Reg, CustomError>,
174    [(); ExtState::ELEN as usize]:,
175    [(); ExtState::VLEN as usize]:,
176    [(); ExtState::VLENB as usize]:,
177    Memory: VirtualMemory,
178    CustomError: fmt::Debug,
179{
180    let elem_bytes = eew.bytes();
181    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
182    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
183    for i in u32::from(vstart)..vl {
184        if !vm && !mask_bit(&mask_buf, i) {
185            continue;
186        }
187        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());
188        for f in 0..nf {
189            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
190            let field_base_reg = vs3.bits() + f * group_regs;
191            // SAFETY: same argument as `execute_unit_stride_store`; `field_base_reg +
192            // i / elems_per_reg < field_base_reg + group_regs <= vs3.bits() + nf *
193            // group_regs <= 32`.
194            let data = unsafe {
195                read_group_element(ext_state.read_vreg(), usize::from(field_base_reg), i, eew)
196            };
197            // Record the current element index in `vstart` so that, on a memory fault, the failing
198            // element can be identified and the operation can be restarted
199            if let Err(error) = write_mem_element(memory, addr, eew, data) {
200                ext_state.set_vstart(i as u16);
201                return Err(ExecutionError::MemoryAccess(error));
202            }
203        }
204    }
205    ext_state.reset_vstart();
206    Ok(())
207}
208
209/// Execute an indexed (unordered or ordered) store or indexed-segment store.
210///
211/// The effective address of element `i`, field `f` is:
212///   `base + index[i] + f * eew.bytes()`
213/// where `index[i]` is element `i` of the index register group `vs2`, interpreted as an
214/// unsigned integer of width `index_eew`.
215///
216/// `data_eew` is the element width of the data being stored (from `vtype.vsew`).
217/// `index_eew` is the element width of the indices (from the instruction encoding).
218///
219/// # Safety
220/// - `vs3.bits() % data_group_regs == 0`
221/// - `vs3.bits() + nf * data_group_regs <= 32`
222/// - `vs2` register group is aligned and fits within `[0, 32)` (caller must verify via
223///   `check_register_group_alignment` before calling)
224/// - `vl <= data_group_regs * VLENB / data_eew.bytes()`
225/// - `vl <= index_group_regs * VLENB / index_eew.bytes()` (caller must verify)
226/// - When `vm=false`: `vs3.bits() != 0`
227#[inline(always)]
228#[expect(clippy::too_many_arguments, reason = "Internal API")]
229#[doc(hidden)]
230pub unsafe fn execute_indexed_store<Reg, ExtState, Memory, CustomError>(
231    ext_state: &mut ExtState,
232    memory: &mut Memory,
233    vs3: VReg,
234    vs2: VReg,
235    vm: bool,
236    vl: u32,
237    vstart: u32,
238    base: u64,
239    data_eew: Eew,
240    index_eew: Eew,
241    data_group_regs: u8,
242    nf: u8,
243) -> Result<(), ExecutionError<Reg::Type, CustomError>>
244where
245    Reg: Register,
246    ExtState: VectorRegistersExt<Reg, CustomError>,
247    [(); ExtState::ELEN as usize]:,
248    [(); ExtState::VLEN as usize]:,
249    [(); ExtState::VLENB as usize]:,
250    Memory: VirtualMemory,
251    CustomError: fmt::Debug,
252{
253    let data_elem_bytes = data_eew.bytes();
254    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
255    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
256    for i in vstart..vl {
257        if !vm && !mask_bit(&mask_buf, i) {
258            continue;
259        }
260        // SAFETY: `i < vl <= index_group_regs * VLENB / index_eew.bytes()` (precondition), so
261        // `vs2.bits() + i / (VLENB / index_eew.bytes()) < vs2.bits() + index_group_regs <= 32`.
262        let index_buf = unsafe {
263            read_group_element(ext_state.read_vreg(), usize::from(vs2.bits()), i, index_eew)
264        };
265        // SAFETY: `index_eew.bytes() <= Eew::MAX_BYTES` always holds.
266        let offset = unsafe { index_buf_to_u64(index_buf, index_eew) };
267        let elem_base = base.wrapping_add(offset);
268        for f in 0..nf {
269            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
270            let field_base_reg = vs3.bits() + f * data_group_regs;
271            // SAFETY: `i < vl <= data_group_regs * VLENB / data_eew.bytes()` (precondition), so
272            // `field_base_reg + i / elems_per_reg < field_base_reg + data_group_regs
273            //                                    <= vs3.bits() + nf * data_group_regs <= 32`.
274            let data = unsafe {
275                read_group_element(
276                    ext_state.read_vreg(),
277                    usize::from(field_base_reg),
278                    i,
279                    data_eew,
280                )
281            };
282            // Record the current element index in `vstart` so that, on a memory fault, the failing
283            // element can be identified and the operation can be restarted
284            if let Err(error) = write_mem_element(memory, addr, data_eew, data) {
285                ext_state.set_vstart(i as u16);
286                return Err(ExecutionError::MemoryAccess(error));
287            }
288        }
289    }
290    ext_state.reset_vstart();
291    Ok(())
292}