Skip to main content

ab_riscv_interpreter/v/zvexx/reduction/
zvexx_reduction_helpers.rs

1//! Opaque helpers for ZveXx extension
2use crate::v::vector_registers::VectorRegistersExt;
3use crate::v::zvexx::arith::zvexx_arith_helpers::{
4    read_element_u64, sign_extend, write_element_u64,
5};
6use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute a single-width integer reduction.
11///
12/// # Safety
13/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32` (verified by caller)
14/// - `vstart == 0` (verified by caller; reductions with non-zero vstart are illegal)
15/// - `vl <= group_regs * VLENB / sew_bytes`
16/// - `vl <= VLEN`
17#[inline(always)]
18#[expect(clippy::too_many_arguments, reason = "Internal API")]
19#[doc(hidden)]
20pub unsafe fn execute_reduce_op<Reg, ExtState, CustomError, F>(
21    ext_state: &mut ExtState,
22    vd: VReg,
23    vs2: VReg,
24    vs1: VReg,
25    vm: bool,
26    vl: u32,
27    sew: Vsew,
28    op: F,
29) where
30    Reg: Register,
31    ExtState: VectorRegistersExt<Reg, CustomError>,
32    [(); ExtState::ELEN as usize]:,
33    [(); ExtState::VLEN as usize]:,
34    [(); ExtState::VLENB as usize]:,
35    CustomError: fmt::Debug,
36    F: Fn(u64, u64, Vsew) -> u64,
37{
38    // Spec ยง5.4: when vstart >= vl, no element of vd is updated. For reductions this means
39    // vl == 0 (since caller has verified vstart == 0). In that case we must not write vd and
40    // must not mark vs dirty.
41    if vl == 0 {
42        ext_state.reset_vstart();
43        return;
44    }
45    // SAFETY: element 0 always fits within register vs1
46    let init = unsafe { read_element_u64(ext_state.read_vregs(), vs1, 0, sew) };
47    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
48    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
49    let mut acc = init;
50    for i in 0..vl {
51        if !mask_bit(&mask_buf, i) {
52            continue;
53        }
54        // SAFETY: `vs2 % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`
55        let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
56        acc = op(acc, elem, sew);
57    }
58    // SAFETY: element 0 always fits within register vd
59    unsafe {
60        write_element_u64(ext_state.write_vregs(), vd, 0, sew, acc);
61    }
62    ext_state.mark_vs_dirty();
63    ext_state.reset_vstart();
64}
65
66/// Execute a widening integer sum reduction.
67///
68/// # Safety
69/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32` (verified by caller)
70/// - `sew.double_width().is_some()` (verified by caller)
71/// - `vstart == 0` (verified by caller)
72/// - `vl <= group_regs * VLENB / sew_bytes`
73/// - `vl <= VLEN`
74#[inline(always)]
75#[expect(clippy::too_many_arguments, reason = "Internal API")]
76#[doc(hidden)]
77pub unsafe fn execute_widening_reduce_op<Reg, ExtState, CustomError, F>(
78    ext_state: &mut ExtState,
79    vd: VReg,
80    vs2: VReg,
81    vs1: VReg,
82    vm: bool,
83    vl: u32,
84    sew: Vsew,
85    op: F,
86    sign_extend_src: bool,
87) where
88    Reg: Register,
89    ExtState: VectorRegistersExt<Reg, CustomError>,
90    [(); ExtState::ELEN as usize]:,
91    [(); ExtState::VLEN as usize]:,
92    [(); ExtState::VLENB as usize]:,
93    CustomError: fmt::Debug,
94    F: Fn(u64, u64, Vsew) -> u64,
95{
96    let Some(wide_sew) = sew.double_width() else {
97        // SAFETY: caller verified `2*SEW <= ELEN`; E64 widening is unreachable here
98        unsafe { core::hint::unreachable_unchecked() }
99    };
100    if vl == 0 {
101        ext_state.reset_vstart();
102        return;
103    }
104    // SAFETY: element 0 always fits within register vs1
105    let init = unsafe { read_element_u64(ext_state.read_vregs(), vs1, 0, wide_sew) };
106    // SAFETY: `vl <= VLEN`
107    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
108    let mut acc = init;
109    for i in 0..vl {
110        if !mask_bit(&mask_buf, i) {
111            continue;
112        }
113        // SAFETY: same bounds argument as `execute_reduce_op`
114        let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
115        let elem = if sign_extend_src {
116            sign_extend(raw, sew).cast_unsigned()
117        } else {
118            raw
119        };
120        acc = op(acc, elem, wide_sew);
121    }
122    // SAFETY: element 0 always fits within register vd
123    unsafe {
124        write_element_u64(ext_state.write_vregs(), vd, 0, wide_sew, acc);
125    }
126    ext_state.mark_vs_dirty();
127    ext_state.reset_vstart();
128}