ab_riscv_interpreter/v/zve64x/load/zve64x_load_helpers.rs

//! Load helpers for the Zve64x extension

use crate::v::vector_registers::VectorRegistersExt;
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, ProgramCounter, VirtualMemory, VirtualMemoryError};
use ab_riscv_primitives::prelude::*;
use core::fmt;

#[doc(hidden)]
pub const MAX_NF: u8 = 8;

/// Return whether mask bit `i` is set in the mask byte slice.
///
/// Bits are stored LSB-first within each byte: bit `i` is at byte `i / 8`, position `i % 8`.
/// Returns `false` for any `i` outside the slice bounds.
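///
/// A minimal illustration (made-up values):
///
/// ```ignore
/// let mask = [0b0000_0101u8]; // bits 0 and 2 set
/// assert!(mask_bit(&mask, 0));
/// assert!(!mask_bit(&mask, 1));
/// assert!(mask_bit(&mask, 2));
/// assert!(!mask_bit(&mask, 8)); // beyond the slice -> false
/// ```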
#[inline(always)]
pub(in super::super) fn mask_bit(mask: &[u8], i: u32) -> bool {
    mask.get((i / u8::BITS) as usize)
        .is_some_and(|b| (b >> (i % u8::BITS)) & 1 != 0)
}

/// Copy the mask bytes needed to cover `vl` elements from `v0` into a stack buffer and return
/// it. The copy releases the shared borrow on the register file so the caller can immediately
/// take an exclusive borrow for writes.
///
/// When `vm=true` (unmasked), the buffer is filled with `0xff` so that every mask bit reads as `1`.
/// This means callers can unconditionally call [`mask_bit()`] on the returned buffer without
/// branching on `vm`. Current callers short-circuit with `!vm &&` before calling [`mask_bit()`] as
/// a micro-optimization on the common unmasked path, but correctness does not depend on that guard:
/// if it were removed, the `0xff` fill ensures [`mask_bit()`] would return `true` for every
/// element, preserving the unmasked semantics.
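///
/// A minimal sketch of both paths (illustrative `VLENB = 16`, `vl = 64`; the safety
/// precondition below holds since `64.div_ceil(8) = 8 <= 16`):
///
/// ```ignore
/// let vreg = [[0u8; 16]; 32]; // v0 is all zeros
/// // Unmasked: every bit reads as set regardless of v0's contents.
/// let buf = unsafe { snapshot_mask(&vreg, true, 64) };
/// assert!(mask_bit(&buf, 63));
/// // Masked: bits come from v0, which is zero here.
/// let buf = unsafe { snapshot_mask(&vreg, false, 64) };
/// assert!(!mask_bit(&buf, 0));
/// ```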
///
/// # Safety
/// `vl.div_ceil(8)` must be `<= VLENB`. This holds when `vl <= VLEN`, which is always true
/// when `vl` is the current architectural `vl` (bounded by `VLMAX <= VLEN`).
#[inline(always)]
pub(in super::super) unsafe fn snapshot_mask<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    vm: bool,
    vl: u32,
) -> [u8; VLENB] {
    let mut buf = [0u8; VLENB];
    if vm {
        // All-ones: every element active
        buf = [0xffu8; VLENB];
    } else {
        let mask_bytes = vl.div_ceil(u8::BITS) as usize;
        // SAFETY: `mask_bytes <= VLENB` by the caller's precondition
        unsafe {
            buf.get_unchecked_mut(..mask_bytes)
                .copy_from_slice(vreg[usize::from(VReg::V0.bits())].get_unchecked(..mask_bytes));
        }
    }
    buf
}

/// Return whether register groups `[a, a+a_regs)` and `[b, b+b_regs)` overlap.
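///
/// A minimal illustration (assumes `VReg` exposes variants such as `V8`):
///
/// ```ignore
/// // [8, 12) overlaps [10, 12); [8, 12) and [12, 16) only touch.
/// assert!(groups_overlap(VReg::V8, 4, VReg::V10, 2));
/// assert!(!groups_overlap(VReg::V8, 4, VReg::V12, 4));
/// ```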
#[inline(always)]
#[doc(hidden)]
pub fn groups_overlap(a: VReg, a_regs: u8, b: VReg, b_regs: u8) -> bool {
    let (a, b) = (a.bits(), b.bits());
    a < b + b_regs && b < a + a_regs
}

/// Check that `vd` is aligned to `group_regs` and that the group fits within `[0, 32)`.
///
/// Per spec, the base register of every register group must be a multiple of the group size.
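///
/// For example, with `group_regs = 4` the legal bases are `v0, v4, ..., v28`; `v6` fails the
/// alignment check.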
#[inline(always)]
#[doc(hidden)]
pub fn check_register_group_alignment<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    vd: VReg,
    group_regs: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    let vd = vd.bits();
    if !vd.is_multiple_of(group_regs) || vd + group_regs > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

/// Validate segment register layout: all `nf` field groups fit within `[0, 32)`, the base
/// register is group-aligned, and the first field group does not include `v0` when masked.
///
/// Field `f` occupies registers `[vd + f * group_regs, vd + f * group_regs + group_regs)`.
/// On `Ok`, `vd.bits() + nf * group_regs <= 32` is guaranteed.
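///
/// For example, `vd = v4`, `group_regs = 2`, `nf = 3` lays out field groups `{v4, v5}`,
/// `{v6, v7}`, `{v8, v9}`, ending at register 10, which fits within `[0, 32)`.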
#[inline(always)]
#[doc(hidden)]
pub fn validate_segment_registers<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    vd: VReg,
    vm: bool,
    group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    let group_regs = u32::from(group_regs);
    let nf = u32::from(nf);
    let vd_idx = u32::from(vd.bits());
    if vd_idx % group_regs != 0 || vd_idx + nf * group_regs > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    // When masked, no field group may contain v0 (index 0). Since groups are laid out
    // contiguously from vd and vd is group-aligned, only the first field (f=0) could contain
    // v0, which happens exactly when vd == 0.
    if !vm && vd_idx == 0 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

/// Read element `elem_i` from register group `[base_reg, base_reg + group_regs)` into a
/// `[u8; Eew::MAX_BYTES]` buffer.
///
/// The in-register position of element `elem_i` is:
///   - register `base_reg + elem_i / (VLENB / eew.bytes())`
///   - byte offset `(elem_i % (VLENB / eew.bytes())) * eew.bytes()`
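///
/// For example (illustrative sizes): with `VLENB = 16` and `eew` = E32 (4 bytes),
/// `VLENB / eew.bytes() = 4`, so element 5 lives in register `base_reg + 1` at byte offset 4.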
///
/// The result is placed in `buf[..eew.bytes()]`; the remaining bytes are zero.
///
/// # Safety
/// `base_reg + elem_i / (VLENB / eew.bytes())` must be less than 32, i.e. `elem_i` must be
/// a valid element index within the register group.
#[inline(always)]
pub(in super::super) unsafe fn read_group_element<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    base_reg: usize,
    elem_i: u32,
    eew: Eew,
) -> [u8; Eew::MAX_BYTES as usize] {
    let elem_bytes = usize::from(eew.bytes());
    let elems_per_reg = VLENB / elem_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * elem_bytes;
    // SAFETY: `base_reg + reg_off < 32` by the caller's precondition.
    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
    // SAFETY: `byte_off + elem_bytes <= VLENB`: the maximum `byte_off` is
    // `(elems_per_reg - 1) * elem_bytes = VLENB - elem_bytes`, so
    // `byte_off + elem_bytes <= VLENB - elem_bytes + elem_bytes = VLENB`.
    // `elem_bytes <= Eew::MAX_BYTES`: all `Eew` variants are at most E64.
    let src = unsafe { reg.get_unchecked(byte_off..byte_off + elem_bytes) };
    let mut buf = [0; _];
    // SAFETY: `elem_bytes <= Eew::MAX_BYTES` as established above, so `..elem_bytes` is in bounds
    // for `buf`
    unsafe { buf.get_unchecked_mut(..elem_bytes) }.copy_from_slice(src);
    buf
}

/// Write `eew`-sized data from `buf[..eew.bytes()]` into element `elem_i` of register group
/// `[base_reg, base_reg + group_regs)`.
///
/// The in-register position follows the same layout as [`read_group_element`].
///
/// # Safety
/// `base_reg + elem_i / (VLENB / eew.bytes())` must be less than 32, i.e. `elem_i` must be
/// a valid element index within the register group.
#[inline(always)]
unsafe fn write_group_element<const VLENB: usize>(
    vreg: &mut [[u8; VLENB]; 32],
    base_reg: u8,
    elem_i: u32,
    eew: Eew,
    buf: [u8; Eew::MAX_BYTES as usize],
) {
    let elem_bytes = usize::from(eew.bytes());
    let elems_per_reg = VLENB / elem_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * elem_bytes;
    // SAFETY: `base_reg + reg_off < 32` by the caller's precondition
    let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
    // SAFETY: `byte_off + elem_bytes <= VLENB` and `elem_bytes <= Eew::MAX_BYTES`: same argument as
    // in `read_group_element`
    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + elem_bytes) };
    // SAFETY: `elem_bytes <= Eew::MAX_BYTES` as established above, so `..elem_bytes` is in bounds
    // for `buf`
    dst.copy_from_slice(unsafe { buf.get_unchecked(..elem_bytes) });
}

/// Read `eew`-sized data from memory at `addr` into a `[u8; Eew::MAX_BYTES]` buffer
/// (little-endian)
#[inline(always)]
fn read_mem_element(
    memory: &impl VirtualMemory,
    addr: u64,
    eew: Eew,
) -> Result<[u8; Eew::MAX_BYTES as usize], VirtualMemoryError> {
    let mut out = [0; _];
    out[..usize::from(eew.bytes())]
        .copy_from_slice(memory.read_slice(addr, u32::from(eew.bytes()))?);
    Ok(out)
}

/// Execute a unit-stride or unit-stride segment load (including fault-only-first variants).
///
/// Segment stride between elements is `nf * eew.bytes()`. Field `f` for element `i` is at
/// `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1` this degenerates to a
/// plain unit-stride load.
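///
/// For example (illustrative numbers): with `nf = 2` and `eew` = E32, the segment stride is
/// 8 bytes, so element `i` reads field 0 from `base + 8 * i` and field 1 from
/// `base + 8 * i + 4`.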
///
/// When `fault_only_first` is set: a memory error at element `i > 0` truncates `vl` to `i`
/// and returns `Ok`. An error at element `0` always propagates.
///
/// # Safety
/// - `nf <= MAX_NF`
/// - `vd.bits() % group_regs == 0`
/// - `vd.bits() + nf * group_regs <= 32`
/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the destination
///   register group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL
///   register count for the given `eew` and `vtype`)
/// - When `vm=false`: `vd` does not overlap `v0` (i.e. `vd.bits() != 0`)
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_unit_stride_load<Reg, ExtState, Memory, CustomError>(
    ext_state: &mut ExtState,
    memory: &Memory,
    vd: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    eew: Eew,
    group_regs: u8,
    nf: u8,
    fault_only_first: bool,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    CustomError: fmt::Debug,
{
    let elem_bytes = eew.bytes();
    let segment_stride = u64::from(nf) * u64::from(elem_bytes);

    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }

        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);

        // Read all nf fields into a stack buffer before writing any of them.
        // This ensures a fault on field f>0 leaves the destination registers
        // untouched for the faulting element, so only elements with index below
        // the new vl are ever written (fault-only-first semantics).
        //
        // Sized by `MAX_NF * Eew::MAX_BYTES`: the V spec allows at most 8
        // fields (nf in 1..=8), each at most 8 bytes (E64), giving 64 bytes.
        let mut field_buf = [[0u8; usize::from(Eew::MAX_BYTES)]; usize::from(MAX_NF)];

        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
            match read_mem_element(memory, addr, eew) {
                Ok(data) => {
                    // SAFETY: `f < nf` and the precondition on this function requires
                    // `nf <= MAX_NF` (the V spec encodes nf in 3 bits giving 1..=8 =
                    // MAX_NF, and the decoder enforces this before constructing the
                    // instruction). Therefore, `f as usize < nf as usize <= MAX_NF`,
                    // which is exactly the length of `field_buf`.
                    unsafe {
                        *field_buf.get_unchecked_mut(f as usize) = data;
                    }
                }
                Err(mem_err) => {
                    if fault_only_first && i > 0 {
                        ext_state.set_vl(i);
                        ext_state.mark_vs_dirty();
                        ext_state.reset_vstart();
                        return Ok(());
                    }
                    if i > vstart {
                        // Elements [vstart, i) were committed; VS is now dirty.
                        ext_state.mark_vs_dirty();
                        // vstart records the faulting element for restartability.
                        ext_state.set_vstart(i as u16);
                    }
                    return Err(ExecutionError::MemoryAccess(mem_err));
                }
            }
        }

        // All nf fields for element i were read successfully; commit to the register file.
        for f in 0..nf {
            let field_base_reg = vd.bits() + f * group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
            //
            // Let `elems_per_reg = VLENB / elem_bytes`.
            // `i < vl <= group_regs * elems_per_reg` (precondition), so
            // `i / elems_per_reg < group_regs`.
            //
            // `field_base_reg = vd.bits() + f * group_regs`. Since `f < nf` and the
            // precondition guarantees `vd.bits() + nf * group_regs <= 32`:
            // `field_base_reg + group_regs <= vd.bits() + (f+1) * group_regs
            //                             <= vd.bits() + nf * group_regs <= 32`.
            //
            // Therefore, `field_base_reg + i / elems_per_reg
            //            < field_base_reg + group_regs <= 32`.
            //
            // For `field_buf`: `f < nf <= MAX_NF` (the same argument as in the read loop
            // above), so `f as usize < MAX_NF = field_buf.len()`.
            unsafe {
                write_group_element(
                    ext_state.write_vreg(),
                    field_base_reg,
                    i,
                    eew,
                    *field_buf.get_unchecked(f as usize),
                );
            }
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
    Ok(())
}

/// Execute a strided or strided segment load.
///
/// `addr[i] = base + i * stride` where `stride` is a signed XLEN-wide value. Field `f` of
/// element `i` is at `addr[i] + f * eew.bytes()`.
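///
/// For example (illustrative numbers): `stride = -4` with `eew` = E32 walks memory backwards,
/// so `addr[0] = base`, `addr[1] = base - 4`, `addr[2] = base - 8`, with each field `f` of
/// element `i` read from `addr[i] + 4 * f`.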
///
/// # Safety
/// - `vd.bits() % group_regs == 0`
/// - `vd.bits() + nf * group_regs <= 32`
/// - `vl <= group_regs * VLENB / eew.bytes()`
/// - When `vm=false`: `vd` does not overlap `v0` (i.e. `vd.bits() != 0`)
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_strided_load<Reg, ExtState, Memory, CustomError>(
    ext_state: &mut ExtState,
    memory: &Memory,
    vd: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    stride: i64,
    eew: Eew,
    group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    CustomError: fmt::Debug,
{
    let elem_bytes = eew.bytes();

    // SAFETY: `vl <= VLMAX <= VLEN` (precondition), so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }

        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());

        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
            let data = match read_mem_element(memory, addr, eew) {
                Ok(data) => data,
                Err(mem_err) => {
                    if f > 0 || i > vstart {
                        ext_state.mark_vs_dirty();
                        ext_state.set_vstart(i as u16);
                    }
                    return Err(ExecutionError::MemoryAccess(mem_err));
                }
            };
            let field_base_reg = vd.bits() + f * group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
            //
            // Let `elems_per_reg = VLENB / elem_bytes`.
            // `i < vl <= group_regs * elems_per_reg` (precondition), so
            // `i / elems_per_reg < group_regs`.
            //
            // `field_base_reg = vd.bits() + f * group_regs`. Since `f < nf` and
            // `vd.bits() + nf * group_regs <= 32` (precondition):
            // `field_base_reg + group_regs <= vd.bits() + (f+1) * group_regs
            //                             <= vd.bits() + nf * group_regs <= 32`.
            //
            // Therefore, `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
            unsafe {
                write_group_element(ext_state.write_vreg(), field_base_reg, i, eew, data);
            }
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
    Ok(())
}

/// Execute an indexed (unordered or ordered) or indexed segment load.
///
/// For element `i`, reads `index_eew`-sized bytes from register group `vs2` at element `i`
/// to obtain a zero-extended byte offset, then loads `nf` data fields from
/// `base + offset + f * data_eew.bytes()`. Unordered vs ordered is functionally identical in
/// a software interpreter.
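///
/// For example (illustrative values): with `index_eew` = E16, index elements `[0x00, 0x10, 0x04]`
/// in `vs2`, and `nf = 1`, element 2's single field is loaded from `base + 0x04`.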
///
/// # Safety
/// - `vd.bits() % data_group_regs == 0`
/// - `vd.bits() + nf * data_group_regs <= 32`
/// - `vs2.bits() + (vl - 1) / (VLENB / index_eew.bytes()) < 32` (all `vl` index elements fit within
///   the register file; satisfied when `vs2` is alignment-checked against `EMUL_index` and `vl` is
///   the architectural `vl` bounded by `VLMAX`)
/// - `vl <= data_group_regs * VLENB / data_eew.bytes()` (all `vl` elements fit in a data group)
/// - When `vm=false`: `vd` does not overlap `v0` (i.e. `vd.bits() != 0`)
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_indexed_load<Reg, ExtState, Memory, CustomError>(
    ext_state: &mut ExtState,
    memory: &Memory,
    vd: VReg,
    vs2: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    data_eew: Eew,
    index_eew: Eew,
    data_group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    CustomError: fmt::Debug,
{
    let index_base_reg = usize::from(vs2.bits());

    // SAFETY: `vl <= VLMAX <= VLEN` (precondition), so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }

        // SAFETY: need `index_base_reg + i / (VLENB / index_eew.bytes()) < 32`.
        //
        // The caller verified `vs2` is aligned to `EMUL_index` registers and that
        // `vs2.bits() + EMUL_index <= 32`. `EMUL_index` is defined so that
        // `EMUL_index * (VLENB / index_eew.bytes()) = VLMAX`. Since `i < vl <= VLMAX`,
        // `i / (VLENB / index_eew.bytes()) < EMUL_index`, and therefore
        // `index_base_reg + i / (VLENB / index_eew.bytes()) < index_base_reg + EMUL_index <= 32`.
        let index_buf =
            unsafe { read_group_element(ext_state.read_vreg(), index_base_reg, i, index_eew) };
        let offset = u64::from_le_bytes(index_buf);
        let elem_addr = base.wrapping_add(offset);

        let data_elem_bytes = data_eew.bytes();
        for f in 0..nf {
            let addr = elem_addr.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
            let data = match read_mem_element(memory, addr, data_eew) {
                Ok(data) => data,
                Err(mem_err) => {
                    if f > 0 || i > vstart {
                        ext_state.mark_vs_dirty();
                        ext_state.set_vstart(i as u16);
                    }
                    return Err(ExecutionError::MemoryAccess(mem_err));
                }
            };
            let field_base_reg = vd.bits() + f * data_group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / data_eew.bytes()) < 32`.
            //
            // Let `data_elems_per_reg = VLENB / data_eew.bytes()`.
            // `i < vl <= data_group_regs * data_elems_per_reg` (precondition), so
            // `i / data_elems_per_reg < data_group_regs`.
            //
            // `field_base_reg = vd.bits() + f * data_group_regs`. Since `f < nf` and
            // `vd.bits() + nf * data_group_regs <= 32` (precondition):
            // `field_base_reg + data_group_regs <= vd.bits() + (f+1) * data_group_regs
            //                                  <= vd.bits() + nf * data_group_regs <= 32`.
            //
            // Therefore,
            // `field_base_reg + i / data_elems_per_reg < field_base_reg + data_group_regs <= 32`.
            unsafe {
                write_group_element(ext_state.write_vreg(), field_base_reg, i, data_eew, data);
            }
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
    Ok(())
}