Skip to main content

ab_riscv_interpreter/v/zvexx/arith/
zvexx_arith_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9use core::hint::cold_path;
10
11/// Check that `vreg` (`vd`/`vs`) is aligned to `group_regs` and fits within `[0, 32)`
12#[inline(always)]
13#[doc(hidden)]
14pub fn check_vreg_group_alignment<Reg, Memory, PC, CustomError>(
15    program_counter: &PC,
16    vreg: VReg,
17    group_regs: u8,
18) -> Result<(), ExecutionError<Reg::Type, CustomError>>
19where
20    Reg: Register,
21    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
22{
23    let vreg_idx = vreg.to_bits();
24    if !vreg_idx.is_multiple_of(group_regs) || vreg_idx + group_regs > 32 {
25        cold_path();
26        return Err(ExecutionError::IllegalInstruction {
27            address: program_counter.old_pc(INSTRUCTION_SIZE),
28        });
29    }
30    Ok(())
31}
32
33/// Check mask-destination / source overlap constraint for compare instructions.
34///
35/// Per RVV §11.8: a mask destination register may overlap a source register group only when
36/// the source group occupies a single register (LMUL ≤ 1, i.e. `group_regs == 1`). Otherwise
37/// the encoding is reserved and raises an illegal instruction.
38#[inline(always)]
39#[doc(hidden)]
40pub fn check_mask_dest_no_overlap<Reg, Memory, PC, CustomError>(
41    program_counter: &PC,
42    vd: VReg,
43    src_base: VReg,
44    group_regs: u8,
45) -> Result<(), ExecutionError<Reg::Type, CustomError>>
46where
47    Reg: Register,
48    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
49{
50    if group_regs > 1 {
51        let vd_idx = vd.to_bits();
52        let src = src_base.to_bits();
53        if vd_idx >= src && vd_idx < src + group_regs {
54            cold_path();
55            return Err(ExecutionError::IllegalInstruction {
56                address: program_counter.old_pc(INSTRUCTION_SIZE),
57            });
58        }
59    }
60    Ok(())
61}
62
63/// Read a SEW-wide element from register group `[base_reg, base_reg + group_regs)` as `u64`.
64///
65/// Element `elem_i` occupies bytes at:
66///   - register `base_reg + elem_i / elems_per_reg`
67///   - byte offset `(elem_i % elems_per_reg) * sew_bytes`
68///
69/// The value is zero-extended to `u64`.
70///
71/// # Safety
72/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
73#[inline(always)]
74pub(crate) unsafe fn read_element_u64<const VLENB: usize>(
75    vregs: &VectorRegisterFile<VLENB>,
76    base_reg: VReg,
77    elem_i: u32,
78    sew: Vsew,
79) -> u64 {
80    let sew_bytes = usize::from(sew.bytes_width());
81    let elems_per_reg = VLENB / sew_bytes;
82    let reg_off = elem_i as usize / elems_per_reg;
83    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
84    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
85    let reg = vregs
86        .get(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
87    // SAFETY: `byte_off + sew_bytes <= VLENB` because `byte_off` is at most
88    // `(elems_per_reg - 1) * sew_bytes = VLENB - sew_bytes`
89    let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
90    let mut buf = [0u8; 8];
91    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
92    unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
93    u64::from_le_bytes(buf)
94}
95
96/// Write a SEW-wide element (low `sew_bytes` of `value`) into register group
97/// `[base_reg, base_reg + group_regs)` at element index `elem_i`.
98///
99/// # Safety
100/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
101#[inline(always)]
102pub(crate) unsafe fn write_element_u64<const VLENB: usize>(
103    vregs: &mut VectorRegisterFile<VLENB>,
104    base_reg: VReg,
105    elem_i: u32,
106    sew: Vsew,
107    value: u64,
108) {
109    let sew_bytes = usize::from(sew.bytes_width());
110    let elems_per_reg = VLENB / sew_bytes;
111    let reg_off = elem_i as usize / elems_per_reg;
112    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
113    let buf = value.to_le_bytes();
114    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
115    let reg = vregs
116        .get_mut(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
117    // SAFETY: `byte_off + sew_bytes <= VLENB` - same argument as `read_element_u64`.
118    // `sew_bytes <= 8` for all `Vsew` variants.
119    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
120    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
121    dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
122}
123
124/// Write one mask bit (the comparison result for element `elem_i`) into register `vd`.
125///
126/// Bits are stored LSB-first: element `i` lives at byte `i / 8`, bit `i % 8`.
127/// Only the target bit is modified; all other bits are undisturbed (tail-undisturbed semantics
128/// required for mask destinations per spec §5.3).
129///
130/// # Safety
131/// `elem_i / 8 < VLENB` must hold, i.e. `elem_i < VLEN`. This is guaranteed when
132/// `elem_i < vl <= VLMAX <= VLEN`.
133#[inline(always)]
134pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
135    vregs: &mut VectorRegisterFile<VLENB>,
136    vd: VReg,
137    elem_i: u32,
138    result: bool,
139) {
140    let byte_idx = (elem_i / u8::BITS) as usize;
141    let bit_idx = elem_i % u8::BITS;
142    // SAFETY: `byte_idx < VLENB` by the caller's precondition
143    let byte = unsafe { vregs.get_mut(vd).get_unchecked_mut(byte_idx) };
144    if result {
145        *byte |= 1 << bit_idx;
146    } else {
147        *byte &= !(1 << bit_idx);
148    }
149}
150
151/// Operand source
152#[derive(Debug)]
153#[doc(hidden)]
154pub enum OpSrc {
155    /// Vector-vector: source register index
156    Vreg(VReg),
157    /// Vector-scalar: scalar value (sign- or zero-extended to u64)
158    Scalar(u64),
159}
160
161/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
162///
163/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result (only the
164/// low `sew.bits_width()` are written back).
165///
166/// # Safety
167/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (verified by caller)
168/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
169/// - `vl <= group_regs * VLENB / sew_bytes` (all `vl` elements fit within the register group)
170/// - When `vm=false`: `vd.to_bits() != 0` (vd does not overlap v0)
171#[inline(always)]
172#[doc(hidden)]
173pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
174    ext_state: &mut ExtState,
175    vd: VReg,
176    vs2: VReg,
177    src: OpSrc,
178    vm: bool,
179    sew: Vsew,
180    op: F,
181) where
182    Reg: Register,
183    ExtState: VectorRegistersExt<Reg, CustomError>,
184    [(); ExtState::ELEN as usize]:,
185    [(); ExtState::VLEN as usize]:,
186    [(); ExtState::VLENB as usize]:,
187    CustomError: fmt::Debug,
188    F: Fn(u64, u64, Vsew) -> u64,
189{
190    let vl = ext_state.vl();
191    let vstart = ext_state.vstart();
192    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
193    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
194
195    for i in u32::from(vstart)..vl {
196        if !mask_bit(&mask_buf, i) {
197            continue;
198        }
199
200        // SAFETY: `vs2 % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`, so
201        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
202        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
203
204        let b = match src {
205            OpSrc::Vreg(vs1_base) => {
206                // SAFETY: same argument as vs2
207                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
208            }
209            OpSrc::Scalar(val) => val,
210        };
211
212        let result = op(a, b, sew);
213
214        // SAFETY: `vd % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`, so
215        // `vd + i / elems_per_reg < vd + group_regs <= 32`
216        unsafe {
217            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
218        }
219    }
220
221    ext_state.mark_vs_dirty();
222    ext_state.reset_vstart();
223}
224
225/// Execute a single-width element-wise integer compare over `vstart..vl`, writing one result
226/// bit per element into the mask register `vd`.
227///
228/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew) -> bool`.
229///
230/// Mask destination tail bits (indices `>= vl`) are always left undisturbed per spec §5.3,
231/// regardless of `vta`. Only bits in `vstart..vl` are written.
232///
233/// # Safety
234/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32` (verified by caller)
235/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
236/// - `vl <= group_regs * VLENB / sew_bytes`
237/// - `vl <= VLEN` (so every element index fits within the mask register)
238#[inline(always)]
239#[doc(hidden)]
240pub unsafe fn execute_compare_op<Reg, ExtState, CustomError, F>(
241    ext_state: &mut ExtState,
242    vd: VReg,
243    vs2: VReg,
244    src: OpSrc,
245    vm: bool,
246    sew: Vsew,
247    op: F,
248) where
249    Reg: Register,
250    ExtState: VectorRegistersExt<Reg, CustomError>,
251    [(); ExtState::ELEN as usize]:,
252    [(); ExtState::VLEN as usize]:,
253    [(); ExtState::VLENB as usize]:,
254    CustomError: fmt::Debug,
255    F: Fn(u64, u64, Vsew) -> bool,
256{
257    let vl = ext_state.vl();
258    let vstart = ext_state.vstart();
259    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
260    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
261
262    for i in u32::from(vstart)..vl {
263        // When masked, inactive elements in the destination mask register are left undisturbed
264        // (spec §12.8: "mask register results follow mask-undisturbed policy")
265        if !mask_bit(&mask_buf, i) {
266            continue;
267        }
268
269        // SAFETY: same argument as in `execute_arith_op`
270        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
271
272        let b = match src {
273            OpSrc::Vreg(vs1_base) => {
274                // SAFETY: same argument as vs2
275                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
276            }
277            OpSrc::Scalar(val) => val,
278        };
279
280        let result = op(a, b, sew);
281
282        // SAFETY: `i < vl <= VLMAX <= VLEN`, so `i / 8 < VLEN / 8 = VLENB`
283        unsafe {
284            write_mask_bit(ext_state.write_vregs(), vd, i, result);
285        }
286    }
287
288    ext_state.mark_vs_dirty();
289    ext_state.reset_vstart();
290}
291
292/// Sign-extend the low `sew.bits_width()` of `val` to a full `i64`
293#[inline(always)]
294#[doc(hidden)]
295pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
296    let shift = u64::BITS - u32::from(sew.bits_width());
297    (val.cast_signed() << shift) >> shift
298}
299
300/// Mask off the upper bits of a `u64` to leave only the low `sew.bits_width()`.
301///
302/// Used for unsigned arithmetic and comparisons where only the SEW-wide portion is significant. For
303/// SEW = 64 this is a no-op (all bits are significant).
304#[inline(always)]
305#[doc(hidden)]
306pub fn sew_mask(sew: Vsew) -> u64 {
307    if u32::from(sew.bits_width()) == u64::BITS {
308        u64::MAX
309    } else {
310        (1u64 << sew.bits_width()) - 1
311    }
312}