ab_riscv_interpreter/v/zve64x/load/zve64x_load_helpers.rs
//! Opaque load helpers for the Zve64x extension

use crate::v::vector_registers::VectorRegistersExt;
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, ProgramCounter, VirtualMemory, VirtualMemoryError};
use ab_riscv_primitives::prelude::*;
use core::fmt;

#[doc(hidden)]
pub const MAX_NF: u8 = 8;

/// Return whether mask bit `i` is set in the mask byte slice.
///
/// Bits are stored LSB-first within each byte: bit `i` is at byte `i / 8`, position `i % 8`.
/// Returns `false` for any `i` outside the slice bounds.
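///
/// # Example
///
/// An illustrative sketch of the bit addressing (kept as `ignore` since this helper is
/// crate-private; the mask bytes are made up):
///
/// ```ignore
/// // Bit 9 lives in byte 9 / 8 = 1, at position 9 % 8 = 1.
/// let mask = [0b0000_0000u8, 0b0000_0010];
/// assert!(mask_bit(&mask, 9));
/// // Indices past the end of the slice read as inactive.
/// assert!(!mask_bit(&mask, 100));
/// ```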
#[inline(always)]
pub(in super::super) fn mask_bit(mask: &[u8], i: u32) -> bool {
    mask.get((i / u8::BITS) as usize)
        .is_some_and(|b| (b >> (i % u8::BITS)) & 1 != 0)
}

/// Copy the mask bytes needed to cover `vl` elements from `v0` into a stack buffer and return
/// it. The copy releases the shared borrow on the register file so the caller can immediately
/// take an exclusive borrow for writes.
///
/// When `vm=true` (unmasked), the buffer is filled with `0xff` so that every mask bit reads as `1`.
/// This means callers can unconditionally call [`mask_bit()`] on the returned buffer without
/// branching on `vm`. Current callers short-circuit with `!vm &&` before calling [`mask_bit()`] as
/// a micro-optimization on the common unmasked path, but correctness does not depend on that guard:
/// if it were removed, the `0xff` fill ensures [`mask_bit()`] would return `true` for every
/// element, preserving the unmasked semantics.
///
/// # Safety
/// `vl.div_ceil(8)` must be `<= VLENB`. This holds when `vl <= VLEN`, which is always true
/// when `vl` is the current architectural `vl` (bounded by `VLMAX <= VLEN`).
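///
/// # Example
///
/// A minimal sketch of the "callers may skip the `vm` branch" point above, with made-up
/// values (`VLENB = 16`, `vl = 8`):
///
/// ```ignore
/// let vreg = [[0u8; 16]; 32];
/// // SAFETY: vl.div_ceil(8) = 1 <= VLENB = 16.
/// let snap = unsafe { snapshot_mask::<16>(&vreg, true, 8) };
/// // vm = true fills the snapshot with 0xff, so every bit reads as active.
/// assert!(mask_bit(&snap, 7));
/// ```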
#[inline(always)]
pub(in super::super) unsafe fn snapshot_mask<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    vm: bool,
    vl: u32,
) -> [u8; VLENB] {
    let mut buf = [0u8; VLENB];
    if vm {
        // All-ones: every element active
        buf = [0xffu8; VLENB];
    } else {
        let mask_bytes = vl.div_ceil(u8::BITS) as usize;
        // SAFETY: `mask_bytes <= VLENB` by the caller's precondition
        unsafe {
            buf.get_unchecked_mut(..mask_bytes)
                .copy_from_slice(vreg[usize::from(VReg::V0.bits())].get_unchecked(..mask_bytes));
        }
    }
    buf
}

/// Return whether register groups `[a, a+a_regs)` and `[b, b+b_regs)` overlap.
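///
/// # Example
///
/// An illustrative check, assuming the `VReg` enum exposes `V8`, `V10`, and `V12` variants
/// alongside the `V0` used elsewhere in this module:
///
/// ```ignore
/// // v8..v12 and v10..v12 share v10 and v11.
/// assert!(groups_overlap(VReg::V8, 4, VReg::V10, 2));
/// // v8..v12 and v12..v16 are adjacent but disjoint (half-open intervals).
/// assert!(!groups_overlap(VReg::V8, 4, VReg::V12, 4));
/// ```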
#[inline(always)]
#[doc(hidden)]
pub fn groups_overlap(a: VReg, a_regs: u8, b: VReg, b_regs: u8) -> bool {
    let (a, b) = (a.bits(), b.bits());
    a < b + b_regs && b < a + a_regs
}

/// Check that `vd` is aligned to `group_regs` and that the group fits within `[0, 32)`.
///
/// Per spec, the base register of every register group must be a multiple of the group size.
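///
/// For example, with `group_regs = 4` the legal base registers are `v0, v4, v8, ..., v28`;
/// a base such as `v6` fails the alignment check and is rejected as an illegal instruction.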
#[inline(always)]
#[doc(hidden)]
pub fn check_register_group_alignment<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    vd: VReg,
    group_regs: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    let vd = vd.bits();
    if !vd.is_multiple_of(group_regs) || vd + group_regs > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

/// Validate segment register layout: all `nf` field groups fit within `[0, 32)`, the base
/// register is group-aligned, and the first field group does not include `v0` when masked.
///
/// Field `f` occupies registers `[vd + f * group_regs, vd + f * group_regs + group_regs)`.
/// On `Ok`, `vd.bits() + nf * group_regs <= 32` is guaranteed.
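///
/// # Layout example
///
/// With hypothetical values `vd = v4`, `group_regs = 2`, `nf = 3`: field 0 occupies `v4..v6`,
/// field 1 occupies `v6..v8`, field 2 occupies `v8..v10`, and `4 + 3 * 2 = 10 <= 32`, so the
/// layout is accepted. With `vd = v0` and `vm = false` the first field group would include
/// `v0` and the instruction is rejected.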
#[inline(always)]
#[doc(hidden)]
pub fn validate_segment_registers<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    vd: VReg,
    vm: bool,
    group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    let group_regs = u32::from(group_regs);
    let nf = u32::from(nf);
    let vd_idx = u32::from(vd.bits());
    if vd_idx % group_regs != 0 || vd_idx + nf * group_regs > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    // When masked, no field group may contain v0 (index 0). Since groups are laid out
    // contiguously from vd and vd is group-aligned, only the first field (f=0) could contain
    // v0, which happens exactly when vd == 0.
    if !vm && vd_idx == 0 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

/// Read element `elem_i` from register group `[base_reg, base_reg + group_regs)` into a
/// `[u8; Eew::MAX_BYTES]` buffer.
///
/// The in-register position of element `elem_i` is:
/// - register `base_reg + elem_i / (VLENB / eew.bytes())`
/// - byte offset `(elem_i % (VLENB / eew.bytes())) * eew.bytes()`
///
/// The result is placed in `buf[..eew.bytes()]`; the remaining bytes are zero.
///
/// # Safety
/// `base_reg + elem_i / (VLENB / eew.bytes())` must be less than 32, i.e. `elem_i` must be
/// a valid element index within the register group.
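///
/// # Example
///
/// A worked position calculation with made-up parameters (`VLENB = 16`, `eew = E32`,
/// i.e. 4-byte elements):
///
/// ```ignore
/// let (vlenb, elem_bytes) = (16usize, 4usize);
/// let elems_per_reg = vlenb / elem_bytes; // 4 elements per register
/// let elem_i = 6usize;
/// assert_eq!(elem_i / elems_per_reg, 1); // element 6 lives in register base_reg + 1
/// assert_eq!((elem_i % elems_per_reg) * elem_bytes, 8); // ... at byte offset 8
/// ```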
#[inline(always)]
pub(in super::super) unsafe fn read_group_element<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    base_reg: usize,
    elem_i: u32,
    eew: Eew,
) -> [u8; Eew::MAX_BYTES as usize] {
    let elem_bytes = usize::from(eew.bytes());
    let elems_per_reg = VLENB / elem_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * elem_bytes;
    // SAFETY: `base_reg + reg_off < 32` by the caller's precondition.
    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
    // SAFETY: `byte_off + elem_bytes <= VLENB`: the maximum `byte_off` is
    // `(elems_per_reg - 1) * elem_bytes = VLENB - elem_bytes`, so
    // `byte_off + elem_bytes <= VLENB - elem_bytes + elem_bytes = VLENB`.
    // `elem_bytes <= Eew::MAX_BYTES`: all `Eew` variants are at most E64.
    let src = unsafe { reg.get_unchecked(byte_off..byte_off + elem_bytes) };
    let mut buf = [0; _];
    // SAFETY: `elem_bytes <= Eew::MAX_BYTES` as established above, so `..elem_bytes` is in bounds
    // for `buf`
    unsafe { buf.get_unchecked_mut(..elem_bytes) }.copy_from_slice(src);
    buf
}

/// Write `eew`-sized data from `buf[..eew.bytes()]` into element `elem_i` of register group
/// `[base_reg, base_reg + group_regs)`.
///
/// The in-register position follows the same layout as [`read_group_element`].
///
/// # Safety
/// `base_reg + elem_i / (VLENB / eew.bytes())` must be less than 32, i.e. `elem_i` must be
/// a valid element index within the register group.
#[inline(always)]
unsafe fn write_group_element<const VLENB: usize>(
    vreg: &mut [[u8; VLENB]; 32],
    base_reg: u8,
    elem_i: u32,
    eew: Eew,
    buf: [u8; Eew::MAX_BYTES as usize],
) {
    let elem_bytes = usize::from(eew.bytes());
    let elems_per_reg = VLENB / elem_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * elem_bytes;
    // SAFETY: `base_reg + reg_off < 32` by the caller's precondition
    let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
    // SAFETY: `byte_off + elem_bytes <= VLENB` and `elem_bytes <= Eew::MAX_BYTES`: same argument as
    // in `read_group_element`
    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + elem_bytes) };
    // SAFETY: `elem_bytes <= Eew::MAX_BYTES` as established above, so `..elem_bytes` is in bounds
    // for `buf`
    dst.copy_from_slice(unsafe { buf.get_unchecked(..elem_bytes) });
}

/// Read `eew`-sized data from memory at `addr` into a `[u8; Eew::MAX_BYTES]` buffer
/// (little-endian)
#[inline(always)]
fn read_mem_element(
    memory: &impl VirtualMemory,
    addr: u64,
    eew: Eew,
) -> Result<[u8; Eew::MAX_BYTES as usize], VirtualMemoryError> {
    let mut out = [0; _];
    out[..usize::from(eew.bytes())]
        .copy_from_slice(memory.read_slice(addr, u32::from(eew.bytes()))?);
    Ok(out)
}

/// Execute a unit-stride or unit-stride segment load (including fault-only-first variants).
///
/// Segment stride between elements is `nf * eew.bytes()`. Field `f` for element `i` is at
/// `base + i * nf * eew.bytes() + f * eew.bytes()`. When `nf == 1` this degenerates to a
/// plain unit-stride load.
///
/// When `fault_only_first` is set: a memory error at element `i > 0` truncates `vl` to `i`
/// and returns `Ok`. An error at element `0` always propagates.
///
/// # Safety
/// - `nf <= MAX_NF`
/// - `vd.bits() % group_regs == 0`
/// - `vd.bits() + nf * group_regs <= 32`
/// - `vl <= group_regs * VLENB / eew.bytes()` (all `vl` elements fit within the destination
///   register group; this holds when `vl` is the architectural `vl` and `group_regs` is the EMUL
///   register count for the given `eew` and `vtype`)
/// - When `vm=false`: `vd` does not overlap `v0` (i.e. `vd.bits() != 0`)
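///
/// # Addressing example
///
/// A sketch of the segment layout with hypothetical values (`nf = 2`, `eew = E32`, so the
/// segment stride is `2 * 4 = 8` bytes); the closure mirrors the formula above:
///
/// ```ignore
/// let (nf, elem_bytes, base) = (2u64, 4u64, 0x1000u64);
/// let addr = |i: u64, f: u64| base + i * nf * elem_bytes + f * elem_bytes;
/// assert_eq!(addr(0, 1), 0x1004); // element 0, field 1
/// assert_eq!(addr(2, 0), 0x1010); // element 2, field 0
/// assert_eq!(addr(2, 1), 0x1014); // element 2, field 1
/// ```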
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_unit_stride_load<Reg, ExtState, Memory, CustomError>(
    ext_state: &mut ExtState,
    memory: &Memory,
    vd: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    eew: Eew,
    group_regs: u8,
    nf: u8,
    fault_only_first: bool,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    CustomError: fmt::Debug,
{
    let elem_bytes = eew.bytes();
    let segment_stride = u64::from(nf) * u64::from(elem_bytes);

    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }

        let elem_base = base.wrapping_add(u64::from(i) * segment_stride);

        // Read all nf fields into a stack buffer before writing any of them.
        // This ensures a fault on field f > 0 leaves the destination registers
        // untouched for the faulting element, so only elements with index below
        // the new `vl` are ever written (fault-only-first semantics).
        //
        // Sized by `MAX_NF * Eew::MAX_BYTES`: the V spec allows at most 8
        // fields (nf in 1..=8), each at most 8 bytes (E64), giving 64 bytes.
        let mut field_buf = [[0u8; usize::from(Eew::MAX_BYTES)]; usize::from(MAX_NF)];

        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
            match read_mem_element(memory, addr, eew) {
                Ok(data) => {
                    // SAFETY: `f < nf` and the precondition on this function requires
                    // `nf <= MAX_NF` (the V spec encodes nf in 3 bits giving 1..=8 =
                    // MAX_NF, and the decoder enforces this before constructing the
                    // instruction). Therefore, `f as usize < nf as usize <= MAX_NF`,
                    // which is exactly the length of `field_buf`.
                    unsafe {
                        *field_buf.get_unchecked_mut(f as usize) = data;
                    }
                }
                Err(mem_err) => {
                    if fault_only_first && i > 0 {
                        ext_state.set_vl(i);
                        ext_state.mark_vs_dirty();
                        ext_state.reset_vstart();
                        return Ok(());
                    }
                    if i > vstart {
                        // Elements [vstart, i) were committed; VS is now dirty.
                        ext_state.mark_vs_dirty();
                        // vstart records the faulting element for restartability.
                        ext_state.set_vstart(i as u16);
                    }
                    return Err(ExecutionError::MemoryAccess(mem_err));
                }
            }
        }

        // All nf fields for element i were read successfully; commit to the register file.
        for f in 0..nf {
            let field_base_reg = vd.bits() + f * group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
            //
            // Let `elems_per_reg = VLENB / elem_bytes`.
            // `i < vl <= group_regs * elems_per_reg` (precondition), so
            // `i / elems_per_reg < group_regs`.
            //
            // `field_base_reg = vd.bits() + f * group_regs`. Since `f < nf` and the
            // precondition guarantees `vd.bits() + nf * group_regs <= 32`:
            // `field_base_reg + group_regs <= vd.bits() + (f+1) * group_regs
            //                              <= vd.bits() + nf * group_regs <= 32`.
            //
            // Therefore, `field_base_reg + i / elems_per_reg
            //             < field_base_reg + group_regs <= 32`.
            //
            // For `field_buf`: `f < nf <= MAX_NF` (the same argument as in the read loop
            // above), so `f as usize < MAX_NF = field_buf.len()`.
            unsafe {
                write_group_element(
                    ext_state.write_vreg(),
                    field_base_reg,
                    i,
                    eew,
                    *field_buf.get_unchecked(f as usize),
                );
            }
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
    Ok(())
}

/// Execute a strided or strided segment load.
///
/// `addr[i] = base + i * stride` where `stride` is a signed XLEN-wide value. Field `f` of
/// element `i` is at `addr[i] + f * eew.bytes()`.
///
/// # Safety
/// - `vd.bits() % group_regs == 0`
/// - `vd.bits() + nf * group_regs <= 32`
/// - `vl <= group_regs * VLENB / eew.bytes()`
/// - When `vm=false`: `vd` does not overlap `v0` (i.e. `vd.bits() != 0`)
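///
/// # Addressing example
///
/// A sketch with hypothetical values (`stride = -8`, `eew = E32`), showing that a negative
/// stride walks memory downwards; `as u64` stands in for the `cast_unsigned()` used in the body:
///
/// ```ignore
/// let (base, stride, elem_bytes) = (0x2000u64, -8i64, 4u64);
/// let addr =
///     |i: u64, f: u64| base.wrapping_add((i as i64).wrapping_mul(stride) as u64) + f * elem_bytes;
/// assert_eq!(addr(0, 0), 0x2000);
/// assert_eq!(addr(1, 0), 0x1FF8); // element 1 starts 8 bytes below base
/// assert_eq!(addr(1, 1), 0x1FFC); // field 1 of element 1
/// ```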
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_strided_load<Reg, ExtState, Memory, CustomError>(
    ext_state: &mut ExtState,
    memory: &Memory,
    vd: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    stride: i64,
    eew: Eew,
    group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    CustomError: fmt::Debug,
{
    let elem_bytes = eew.bytes();

    // SAFETY: `vl <= VLMAX <= VLEN` (precondition), so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }

        let elem_base = base.wrapping_add(i64::from(i).wrapping_mul(stride).cast_unsigned());

        for f in 0..nf {
            let addr = elem_base.wrapping_add(u64::from(f) * u64::from(elem_bytes));
            let data = match read_mem_element(memory, addr, eew) {
                Ok(data) => data,
                Err(mem_err) => {
                    if f > 0 || i > vstart {
                        ext_state.mark_vs_dirty();
                        ext_state.set_vstart(i as u16);
                    }
                    return Err(ExecutionError::MemoryAccess(mem_err));
                }
            };
            let field_base_reg = vd.bits() + f * group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / elem_bytes) < 32`.
            //
            // Let `elems_per_reg = VLENB / elem_bytes`.
            // `i < vl <= group_regs * elems_per_reg` (precondition), so
            // `i / elems_per_reg < group_regs`.
            //
            // `field_base_reg = vd.bits() + f * group_regs`. Since `f < nf` and
            // `vd.bits() + nf * group_regs <= 32` (precondition):
            // `field_base_reg + group_regs <= vd.bits() + (f+1) * group_regs
            //                              <= vd.bits() + nf * group_regs <= 32`.
            //
            // Therefore, `field_base_reg + i / elems_per_reg < field_base_reg + group_regs <= 32`.
            unsafe {
                write_group_element(ext_state.write_vreg(), field_base_reg, i, eew, data);
            }
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
    Ok(())
}

/// Execute an indexed (unordered or ordered) or indexed segment load.
///
/// For element `i`, reads `index_eew`-sized bytes from register group `vs2` at element `i`
/// to obtain a zero-extended byte offset, then loads `nf` data fields from
/// `base + offset + f * data_eew.bytes()`. The unordered and ordered variants are functionally
/// identical in a software interpreter.
///
/// # Safety
/// - `vd.bits() % data_group_regs == 0`
/// - `vd.bits() + nf * data_group_regs <= 32`
/// - `vs2.bits() + (vl - 1) / (VLENB / index_eew.bytes()) < 32` (all `vl` index elements fit within
///   the register file; satisfied when `vs2` is alignment-checked against `EMUL_index` and `vl` is
///   the architectural `vl` bounded by `VLMAX`)
/// - `vl <= data_group_regs * VLENB / data_eew.bytes()` (all `vl` elements fit in a data group)
/// - When `vm=false`: `vd` does not overlap `v0` (i.e. `vd.bits() != 0`)
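///
/// # Addressing example
///
/// A sketch with hypothetical index values (`data_eew = E32`); the offsets stand in for the
/// zero-extended elements read from `vs2`:
///
/// ```ignore
/// let base = 0x3000u64;
/// let offsets = [0u64, 12, 4]; // byte offsets held in vs2, already zero-extended
/// let data_elem_bytes = 4u64;
/// // Field f of element i is loaded from base + offsets[i] + f * data_elem_bytes.
/// assert_eq!(base + offsets[1], 0x300C); // element 1, field 0
/// assert_eq!(base + offsets[1] + data_elem_bytes, 0x3010); // element 1, field 1
/// assert_eq!(base + offsets[2], 0x3004); // element 2, field 0
/// ```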
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_indexed_load<Reg, ExtState, Memory, CustomError>(
    ext_state: &mut ExtState,
    memory: &Memory,
    vd: VReg,
    vs2: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    base: u64,
    data_eew: Eew,
    index_eew: Eew,
    data_group_regs: u8,
    nf: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    CustomError: fmt::Debug,
{
    let index_base_reg = usize::from(vs2.bits());

    // SAFETY: `vl <= VLMAX <= VLEN` (precondition), so `vl.div_ceil(8) <= VLEN / 8 = VLENB`.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    for i in vstart..vl {
        if !vm && !mask_bit(&mask_buf, i) {
            continue;
        }

        // SAFETY: need `index_base_reg + i / (VLENB / index_eew.bytes()) < 32`.
        //
        // The caller verified `vs2` is aligned to `EMUL_index` registers and that
        // `vs2.bits() + EMUL_index <= 32`. `EMUL_index` is defined so that
        // `EMUL_index * (VLENB / index_eew.bytes()) = VLMAX`. Since `i < vl <= VLMAX`,
        // `i / (VLENB / index_eew.bytes()) < EMUL_index`, and therefore
        // `index_base_reg + i / (VLENB / index_eew.bytes()) < index_base_reg + EMUL_index <= 32`.
        let index_buf =
            unsafe { read_group_element(ext_state.read_vreg(), index_base_reg, i, index_eew) };
        let offset = u64::from_le_bytes(index_buf);
        let elem_addr = base.wrapping_add(offset);

        let data_elem_bytes = data_eew.bytes();
        for f in 0..nf {
            let addr = elem_addr.wrapping_add(u64::from(f) * u64::from(data_elem_bytes));
            let data = match read_mem_element(memory, addr, data_eew) {
                Ok(data) => data,
                Err(mem_err) => {
                    if f > 0 || i > vstart {
                        ext_state.mark_vs_dirty();
                        ext_state.set_vstart(i as u16);
                    }
                    return Err(ExecutionError::MemoryAccess(mem_err));
                }
            };
            let field_base_reg = vd.bits() + f * data_group_regs;
            // SAFETY: need `field_base_reg + i / (VLENB / data_eew.bytes()) < 32`.
            //
            // Let `data_elems_per_reg = VLENB / data_eew.bytes()`.
            // `i < vl <= data_group_regs * data_elems_per_reg` (precondition), so
            // `i / data_elems_per_reg < data_group_regs`.
            //
            // `field_base_reg = vd.bits() + f * data_group_regs`. Since `f < nf` and
            // `vd.bits() + nf * data_group_regs <= 32` (precondition):
            // `field_base_reg + data_group_regs <= vd.bits() + (f+1) * data_group_regs
            //                                   <= vd.bits() + nf * data_group_regs <= 32`.
            //
            // Therefore,
            // `field_base_reg + i / data_elems_per_reg < field_base_reg + data_group_regs <= 32`.
            unsafe {
                write_group_element(ext_state.write_vreg(), field_base_reg, i, data_eew, data);
            }
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
    Ok(())
}