// ab_riscv_interpreter/v/zve64x/arith/zve64x_arith_helpers.rs

1//! Opaque helpers for Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Check that `vreg` (`vd`/`vs`) is aligned to `group_regs` and fits within `[0, 32)`
11#[inline(always)]
12#[doc(hidden)]
13pub fn check_vreg_group_alignment<Reg, Memory, PC, CustomError>(
14    program_counter: &PC,
15    vreg: VReg,
16    group_regs: u8,
17) -> Result<(), ExecutionError<Reg::Type, CustomError>>
18where
19    Reg: Register,
20    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
21{
22    let vreg_idx = vreg.bits();
23    if !vreg_idx.is_multiple_of(group_regs) || vreg_idx + group_regs > 32 {
24        return Err(ExecutionError::IllegalInstruction {
25            address: program_counter.old_pc(INSTRUCTION_SIZE),
26        });
27    }
28    Ok(())
29}
30
31/// Check mask-destination / source overlap constraint for compare instructions.
32///
33/// Per RVV §11.8: a mask destination register may overlap a source register group only when
34/// the source group occupies a single register (LMUL ≤ 1, i.e. `group_regs == 1`). Otherwise
35/// the encoding is reserved and raises an illegal instruction.
36#[inline(always)]
37#[doc(hidden)]
38pub fn check_mask_dest_no_overlap<Reg, Memory, PC, CustomError>(
39    program_counter: &PC,
40    vd: VReg,
41    src_base: VReg,
42    group_regs: u8,
43) -> Result<(), ExecutionError<Reg::Type, CustomError>>
44where
45    Reg: Register,
46    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
47{
48    if group_regs > 1 {
49        let vd_idx = vd.bits();
50        let src = src_base.bits();
51        if vd_idx >= src && vd_idx < src + group_regs {
52            return Err(ExecutionError::IllegalInstruction {
53                address: program_counter.old_pc(INSTRUCTION_SIZE),
54            });
55        }
56    }
57    Ok(())
58}
59
60/// Read a SEW-wide element from register group `[base_reg, base_reg + group_regs)` as `u64`.
61///
62/// Element `elem_i` occupies bytes at:
63///   - register `base_reg + elem_i / elems_per_reg`
64///   - byte offset `(elem_i % elems_per_reg) * sew_bytes`
65///
66/// The value is zero-extended to `u64`.
67///
68/// # Safety
69/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
70#[inline(always)]
71pub(in super::super) unsafe fn read_element_u64<const VLENB: usize>(
72    vreg: &[[u8; VLENB]; 32],
73    base_reg: usize,
74    elem_i: u32,
75    sew: Vsew,
76) -> u64 {
77    let sew_bytes = usize::from(sew.bytes());
78    let elems_per_reg = VLENB / sew_bytes;
79    let reg_off = elem_i as usize / elems_per_reg;
80    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
81    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
82    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
83    // SAFETY: `byte_off + sew_bytes <= VLENB` because `byte_off` is at most
84    // `(elems_per_reg - 1) * sew_bytes = VLENB - sew_bytes`
85    let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
86    let mut buf = [0u8; 8];
87    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
88    unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
89    u64::from_le_bytes(buf)
90}
91
92/// Write a SEW-wide element (low `sew_bytes` of `value`) into register group
93/// `[base_reg, base_reg + group_regs)` at element index `elem_i`.
94///
95/// # Safety
96/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
97#[inline(always)]
98pub(in super::super) unsafe fn write_element_u64<const VLENB: usize>(
99    vreg: &mut [[u8; VLENB]; 32],
100    base_reg: u8,
101    elem_i: u32,
102    sew: Vsew,
103    value: u64,
104) {
105    let sew_bytes = usize::from(sew.bytes());
106    let elems_per_reg = VLENB / sew_bytes;
107    let reg_off = elem_i as usize / elems_per_reg;
108    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
109    let buf = value.to_le_bytes();
110    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
111    let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
112    // SAFETY: `byte_off + sew_bytes <= VLENB` - same argument as `read_element_u64`.
113    // `sew_bytes <= 8` for all `Vsew` variants.
114    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
115    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
116    dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
117}
118
119/// Write one mask bit (the comparison result for element `elem_i`) into register `vd`.
120///
121/// Bits are stored LSB-first: element `i` lives at byte `i / 8`, bit `i % 8`.
122/// Only the target bit is modified; all other bits are undisturbed (tail-undisturbed semantics
123/// required for mask destinations per spec §5.3).
124///
125/// # Safety
126/// `elem_i / 8 < VLENB` must hold, i.e. `elem_i < VLEN`. This is guaranteed when
127/// `elem_i < vl <= VLMAX <= VLEN`.
128#[inline(always)]
129pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
130    vreg: &mut [[u8; VLENB]; 32],
131    vd: VReg,
132    elem_i: u32,
133    result: bool,
134) {
135    let byte_idx = (elem_i / u8::BITS) as usize;
136    let bit_idx = elem_i % u8::BITS;
137    // SAFETY: `byte_idx < VLENB` by the caller's precondition
138    let byte = unsafe {
139        vreg.get_unchecked_mut(usize::from(vd.bits()))
140            .get_unchecked_mut(byte_idx)
141    };
142    if result {
143        *byte |= 1 << bit_idx;
144    } else {
145        *byte &= !(1 << bit_idx);
146    }
147}
148
/// Second-operand source for element-wise vector operations.
///
/// Selects between the vector-vector form (operand read per-element from a register group)
/// and the vector-scalar form (one value applied to every element).
#[derive(Debug)]
#[doc(hidden)]
pub enum OpSrc {
    /// Vector-vector: index of the `vs1` source register group
    Vreg(u8),
    /// Vector-scalar: scalar value, already sign- or zero-extended to `u64` by the caller
    Scalar(u64),
}
158
159/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
160///
161/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result (only the
162/// low `sew.bits()` are written back).
163///
164/// # Safety
165/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (verified by caller)
166/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
167/// - `vl <= group_regs * VLENB / sew_bytes` (all `vl` elements fit within the register group)
168/// - When `vm=false`: `vd.bits() != 0` (vd does not overlap v0)
169#[inline(always)]
170#[expect(clippy::too_many_arguments, reason = "Internal API")]
171#[doc(hidden)]
172pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
173    ext_state: &mut ExtState,
174    vd: VReg,
175    vs2: VReg,
176    src: OpSrc,
177    vm: bool,
178    vl: u32,
179    vstart: u32,
180    sew: Vsew,
181    op: F,
182) where
183    Reg: Register,
184    ExtState: VectorRegistersExt<Reg, CustomError>,
185    [(); ExtState::ELEN as usize]:,
186    [(); ExtState::VLEN as usize]:,
187    [(); ExtState::VLENB as usize]:,
188    CustomError: fmt::Debug,
189    F: Fn(u64, u64, Vsew) -> u64,
190{
191    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
192    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
193
194    let vd_base = vd.bits();
195    let vs2_base = vs2.bits();
196
197    for i in vstart..vl {
198        if !mask_bit(&mask_buf, i) {
199            continue;
200        }
201
202        // SAFETY: `vs2_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`,
203        // so `vs2_base + i / elems_per_reg < vs2_base + group_regs <= 32`
204        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
205
206        let b = match &src {
207            OpSrc::Vreg(vs1_base) => {
208                // SAFETY: same argument as vs2
209                unsafe { read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew) }
210            }
211            OpSrc::Scalar(val) => *val,
212        };
213
214        let result = op(a, b, sew);
215
216        // SAFETY: `vd_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`,
217        // so `vd_base + i / elems_per_reg < vd_base + group_regs <= 32`
218        unsafe {
219            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
220        }
221    }
222
223    ext_state.mark_vs_dirty();
224    ext_state.reset_vstart();
225}
226
227/// Execute a single-width element-wise integer compare over `vstart..vl`, writing one result
228/// bit per element into the mask register `vd`.
229///
230/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew) -> bool`.
231///
232/// Mask destination tail bits (indices `>= vl`) are always left undisturbed per spec §5.3,
233/// regardless of `vta`. Only bits in `vstart..vl` are written.
234///
235/// # Safety
236/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
237/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
238/// - `vl <= group_regs * VLENB / sew_bytes`
239/// - `vl <= VLEN` (so every element index fits within the mask register)
240#[inline(always)]
241#[expect(clippy::too_many_arguments, reason = "Internal API")]
242#[doc(hidden)]
243pub unsafe fn execute_compare_op<Reg, ExtState, CustomError, F>(
244    ext_state: &mut ExtState,
245    vd: VReg,
246    vs2: VReg,
247    src: OpSrc,
248    vm: bool,
249    vl: u32,
250    vstart: u32,
251    sew: Vsew,
252    op: F,
253) where
254    Reg: Register,
255    ExtState: VectorRegistersExt<Reg, CustomError>,
256    [(); ExtState::ELEN as usize]:,
257    [(); ExtState::VLEN as usize]:,
258    [(); ExtState::VLENB as usize]:,
259    CustomError: fmt::Debug,
260    F: Fn(u64, u64, Vsew) -> bool,
261{
262    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
263    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
264
265    let vs2_base = vs2.bits();
266
267    for i in vstart..vl {
268        // When masked, inactive elements in the destination mask register are left undisturbed
269        // (spec §12.8: "mask register results follow mask-undisturbed policy")
270        if !mask_bit(&mask_buf, i) {
271            continue;
272        }
273
274        // SAFETY: same argument as in `execute_arith_op`
275        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
276
277        let b = match &src {
278            OpSrc::Vreg(vs1_base) => {
279                // SAFETY: same argument as vs2
280                unsafe { read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew) }
281            }
282            OpSrc::Scalar(val) => *val,
283        };
284
285        let result = op(a, b, sew);
286
287        // SAFETY: `i < vl <= VLMAX <= VLEN`, so `i / 8 < VLEN / 8 = VLENB`
288        unsafe {
289            write_mask_bit(ext_state.write_vreg(), vd, i, result);
290        }
291    }
292
293    ext_state.mark_vs_dirty();
294    ext_state.reset_vstart();
295}
296
297/// Sign-extend the low `sew.bits()` of `val` to a full `i64`
298#[inline(always)]
299#[doc(hidden)]
300pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
301    let shift = u64::BITS - u32::from(sew.bits());
302    (val.cast_signed() << shift) >> shift
303}
304
305/// Mask off the upper bits of a `u64` to leave only the low `sew.bits()`.
306///
307/// Used for unsigned arithmetic and comparisons where only the SEW-wide portion is significant. For
308/// SEW = 64 this is a no-op (all bits are significant).
309#[inline(always)]
310#[doc(hidden)]
311pub fn sew_mask(sew: Vsew) -> u64 {
312    if u32::from(sew.bits()) == u64::BITS {
313        u64::MAX
314    } else {
315        (1u64 << sew.bits()) - 1
316    }
317}