Skip to main content

ab_riscv_interpreter/v/zvexx/reduction/
zvexx_reduction_helpers.rs

1//! Opaque helpers for ZveXx extension
2use crate::v::vector_registers::VectorRegistersExt;
3use crate::v::zvexx::arith::zvexx_arith_helpers::{
4    read_element_u64, sign_extend, write_element_u64,
5};
6use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9use core::hint::cold_path;
10
11/// Execute a single-width integer reduction.
12///
13/// # Safety
14/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32` (verified by caller)
15/// - `vstart == 0` (verified by caller; reductions with non-zero vstart are illegal)
16/// - `vl <= group_regs * VLENB / sew_bytes`
17/// - `vl <= VLEN`
18#[inline(always)]
19#[expect(clippy::too_many_arguments, reason = "Internal API")]
20#[doc(hidden)]
21pub unsafe fn execute_reduce_op<Reg, ExtState, CustomError, F>(
22    ext_state: &mut ExtState,
23    vd: VReg,
24    vs2: VReg,
25    vs1: VReg,
26    vm: bool,
27    vl: u32,
28    sew: Vsew,
29    op: F,
30) where
31    Reg: Register,
32    ExtState: VectorRegistersExt<Reg, CustomError>,
33    [(); ExtState::ELEN as usize]:,
34    [(); ExtState::VLEN as usize]:,
35    [(); ExtState::VLENB as usize]:,
36    CustomError: fmt::Debug,
37    F: Fn(u64, u64, Vsew) -> u64,
38{
39    // Spec ยง5.4: when vstart >= vl, no element of vd is updated. For reductions this means
40    // vl == 0 (since caller has verified vstart == 0). In that case we must not write vd and
41    // must not mark vs dirty.
42    if vl == 0 {
43        cold_path();
44        ext_state.reset_vstart();
45        return;
46    }
47    // SAFETY: element 0 always fits within register vs1
48    let init = unsafe { read_element_u64(ext_state.read_vregs(), vs1, 0, sew) };
49    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
50    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
51    let mut acc = init;
52    for i in 0..vl {
53        if !mask_bit(&mask_buf, i) {
54            continue;
55        }
56        // SAFETY: `vs2 % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`
57        let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
58        acc = op(acc, elem, sew);
59    }
60    // SAFETY: element 0 always fits within register vd
61    unsafe {
62        write_element_u64(ext_state.write_vregs(), vd, 0, sew, acc);
63    }
64    ext_state.mark_vs_dirty();
65    ext_state.reset_vstart();
66}
67
68/// Execute a widening integer sum reduction.
69///
70/// # Safety
71/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32` (verified by caller)
72/// - `sew.double_width().is_some()` (verified by caller)
73/// - `vstart == 0` (verified by caller)
74/// - `vl <= group_regs * VLENB / sew_bytes`
75/// - `vl <= VLEN`
76#[inline(always)]
77#[expect(clippy::too_many_arguments, reason = "Internal API")]
78#[doc(hidden)]
79pub unsafe fn execute_widening_reduce_op<
80    const SIGN_EXTEND_SRC: bool,
81    Reg,
82    ExtState,
83    CustomError,
84    F,
85>(
86    ext_state: &mut ExtState,
87    vd: VReg,
88    vs2: VReg,
89    vs1: VReg,
90    vm: bool,
91    vl: u32,
92    sew: Vsew,
93    op: F,
94) where
95    Reg: Register,
96    ExtState: VectorRegistersExt<Reg, CustomError>,
97    [(); ExtState::ELEN as usize]:,
98    [(); ExtState::VLEN as usize]:,
99    [(); ExtState::VLENB as usize]:,
100    CustomError: fmt::Debug,
101    F: Fn(u64, u64, Vsew) -> u64,
102{
103    let Some(wide_sew) = sew.double_width() else {
104        // SAFETY: caller verified `2*SEW <= ELEN`; E64 widening is unreachable here
105        unsafe { core::hint::unreachable_unchecked() }
106    };
107    if vl == 0 {
108        cold_path();
109        ext_state.reset_vstart();
110        return;
111    }
112    // SAFETY: element 0 always fits within register vs1
113    let init = unsafe { read_element_u64(ext_state.read_vregs(), vs1, 0, wide_sew) };
114    // SAFETY: `vl <= VLEN`
115    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
116    let mut acc = init;
117    for i in 0..vl {
118        if !mask_bit(&mask_buf, i) {
119            continue;
120        }
121        // SAFETY: same bounds argument as `execute_reduce_op`
122        let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
123        let elem = if SIGN_EXTEND_SRC {
124            sign_extend(raw, sew).cast_unsigned()
125        } else {
126            raw
127        };
128        acc = op(acc, elem, wide_sew);
129    }
130    // SAFETY: element 0 always fits within register vd
131    unsafe {
132        write_element_u64(ext_state.write_vregs(), vd, 0, wide_sew, acc);
133    }
134    ext_state.mark_vs_dirty();
135    ext_state.reset_vstart();
136}