Skip to main content

ab_riscv_interpreter/v/zve64x/reduction/
zve64x_reduction_helpers.rs

1//! Opaque helpers for Zve64x extension
2use crate::v::vector_registers::VectorRegistersExt;
3use crate::v::zve64x::arith::zve64x_arith_helpers::{
4    read_element_u64, sign_extend, write_element_u64,
5};
6use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute a single-width integer reduction.
11///
12/// # Safety
13/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
14/// - `vstart == 0` (verified by caller; reductions with non-zero vstart are illegal)
15/// - `vl <= group_regs * VLENB / sew_bytes`
16/// - `vl <= VLEN`
17#[inline(always)]
18#[expect(clippy::too_many_arguments, reason = "Internal API")]
19#[doc(hidden)]
20pub unsafe fn execute_reduce_op<Reg, ExtState, CustomError, F>(
21    ext_state: &mut ExtState,
22    vd: VReg,
23    vs2: VReg,
24    vs1: VReg,
25    vm: bool,
26    vl: u32,
27    sew: Vsew,
28    op: F,
29) where
30    Reg: Register,
31    ExtState: VectorRegistersExt<Reg, CustomError>,
32    [(); ExtState::ELEN as usize]:,
33    [(); ExtState::VLEN as usize]:,
34    [(); ExtState::VLENB as usize]:,
35    CustomError: fmt::Debug,
36    F: Fn(u64, u64, Vsew) -> u64,
37{
38    // Spec ยง5.4: when vstart >= vl, no element of vd is updated. For reductions this means
39    // vl == 0 (since caller has verified vstart == 0). In that case we must not write vd and
40    // must not mark vs dirty.
41    if vl == 0 {
42        ext_state.reset_vstart();
43        return;
44    }
45    // SAFETY: element 0 always fits within register vs1
46    let init = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs1.bits()), 0, sew) };
47    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
48    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
49    let vs2_base = usize::from(vs2.bits());
50    let mut acc = init;
51    for i in 0..vl {
52        if !mask_bit(&mask_buf, i) {
53            continue;
54        }
55        // SAFETY: `vs2_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`
56        let elem = unsafe { read_element_u64(ext_state.read_vreg(), vs2_base, i, sew) };
57        acc = op(acc, elem, sew);
58    }
59    // SAFETY: element 0 always fits within register vd
60    unsafe {
61        write_element_u64(ext_state.write_vreg(), vd.bits(), 0, sew, acc);
62    }
63    ext_state.mark_vs_dirty();
64    ext_state.reset_vstart();
65}
66
67/// Execute a widening integer sum reduction.
68///
69/// # Safety
70/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
71/// - `2 * sew.bits() <= ELEN` (verified by caller)
72/// - `vstart == 0` (verified by caller)
73/// - `vl <= group_regs * VLENB / sew_bytes`
74/// - `vl <= VLEN`
75#[inline(always)]
76#[expect(clippy::too_many_arguments, reason = "Internal API")]
77#[doc(hidden)]
78pub unsafe fn execute_widening_reduce_op<Reg, ExtState, CustomError, F>(
79    ext_state: &mut ExtState,
80    vd: VReg,
81    vs2: VReg,
82    vs1: VReg,
83    vm: bool,
84    vl: u32,
85    sew: Vsew,
86    op: F,
87    sign_extend_src: bool,
88) where
89    Reg: Register,
90    ExtState: VectorRegistersExt<Reg, CustomError>,
91    [(); ExtState::ELEN as usize]:,
92    [(); ExtState::VLEN as usize]:,
93    [(); ExtState::VLENB as usize]:,
94    CustomError: fmt::Debug,
95    F: Fn(u64, u64, Vsew) -> u64,
96{
97    let wide_sew = match sew {
98        Vsew::E8 => Vsew::E16,
99        Vsew::E16 => Vsew::E32,
100        Vsew::E32 => Vsew::E64,
101        // SAFETY: caller verified `2*SEW <= ELEN`; E64 widening is unreachable here
102        Vsew::E64 => unsafe { core::hint::unreachable_unchecked() },
103    };
104    if vl == 0 {
105        ext_state.reset_vstart();
106        return;
107    }
108    // SAFETY: element 0 always fits within register vs1
109    let init =
110        unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs1.bits()), 0, wide_sew) };
111    // SAFETY: `vl <= VLEN`
112    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
113    let vs2_base = usize::from(vs2.bits());
114    let mut acc = init;
115    for i in 0..vl {
116        if !mask_bit(&mask_buf, i) {
117            continue;
118        }
119        // SAFETY: same bounds argument as `execute_reduce_op`
120        let raw = unsafe { read_element_u64(ext_state.read_vreg(), vs2_base, i, sew) };
121        let elem = if sign_extend_src {
122            sign_extend(raw, sew).cast_unsigned()
123        } else {
124            raw
125        };
126        acc = op(acc, elem, wide_sew);
127    }
128    // SAFETY: element 0 always fits within register vd
129    unsafe {
130        write_element_u64(ext_state.write_vreg(), vd.bits(), 0, wide_sew, acc);
131    }
132    ext_state.mark_vs_dirty();
133    ext_state.reset_vstart();
134}