ab_riscv_interpreter/v/zve64x/reduction/zve64x_reduction_helpers.rs
1//! Opaque helpers for Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::arith::zve64x_arith_helpers::{
5    read_element_u64, sign_extend, write_element_u64,
6};
7use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
8use crate::{InterpreterState, ProgramCounter, VirtualMemory};
9use ab_riscv_primitives::instructions::v::Vsew;
10use ab_riscv_primitives::registers::general_purpose::Register;
11use ab_riscv_primitives::registers::vector::VReg;
12use core::fmt;
13
14/// Execute a single-width integer reduction over `vstart..vl`.
15///
16/// The initial accumulator is read from element 0 of `vs1` (always SEW-wide, single register).
17/// Active elements of `vs2` (masked by `vm` / `v0`) are folded into the accumulator using `op`.
18/// The scalar result is written to element 0 of `vd` (single register, always SEW-wide).
19///
20/// When `vl == 0` or all elements are masked out, element 0 of `vs1` is passed through to `vd[0]`
21/// unchanged, per spec §14.1.
22///
23/// `op` receives `(accumulator: u64, element: u64, sew: Vsew) -> u64`. Only the low `sew.bits()`
24/// of the returned value are significant; all arithmetic should be performed within that width.
25///
26/// # Safety
27/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
28/// - `vl <= group_regs * VLENB / sew_bytes`
29/// - `vl <= VLEN`
30#[inline(always)]
31#[expect(clippy::too_many_arguments, reason = "Internal API")]
32#[doc(hidden)]
33pub unsafe fn execute_reduce_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
34    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
35    vd: VReg,
36    vs2: VReg,
37    vs1: VReg,
38    vm: bool,
39    vl: u32,
40    vstart: u32,
41    sew: Vsew,
42    op: F,
43) where
44    Reg: Register,
45    [(); Reg::N]:,
46    ExtState: VectorRegistersExt<Reg, CustomError>,
47    [(); ExtState::ELEN as usize]:,
48    [(); ExtState::VLEN as usize]:,
49    [(); ExtState::VLENB as usize]:,
50    Memory: VirtualMemory,
51    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
52    CustomError: fmt::Debug,
53    F: Fn(u64, u64, Vsew) -> u64,
54{
55    // Read scalar initial value from vs1[0]; vs1 is always a single register, SEW-wide
56    // SAFETY: element 0 always fits within register vs1 (0 < VLENB / sew_bytes)
57    let init =
58        unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs1.bits()), 0, sew) };
59
60    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
61    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
62
63    let vs2_base = usize::from(vs2.bits());
64
65    // When vstart > 0 the spec says the reduction resumes from wherever it left off; since
66    // reductions are not restartable in general, vstart != 0 at entry to a reduction instruction
67    // is reserved. We follow the simplest compliant path: treat vstart as the lower bound of
68    // active elements, initialising the accumulator to vs1[0] unconditionally.
69    let mut acc = init;
70    for i in vstart..vl {
71        // Inactive elements are skipped; they do not contribute to the result
72        if !mask_bit(&mask_buf, i) {
73            continue;
74        }
75        // SAFETY: `vs2_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`,
76        // so `vs2_base + i / elems_per_reg < vs2_base + group_regs <= 32`
77        let elem = unsafe { read_element_u64(state.ext_state.read_vreg(), vs2_base, i, sew) };
78        acc = op(acc, elem, sew);
79    }
80
81    // Write scalar result to vd[0]; vd is always a single register, SEW-wide
82    // SAFETY: element 0 always fits within register vd
83    unsafe {
84        write_element_u64(state.ext_state.write_vreg(), vd.bits(), 0, sew, acc);
85    }
86
87    state.ext_state.mark_vs_dirty();
88    state.ext_state.reset_vstart();
89}
90
91/// Execute a widening integer sum reduction over `vstart..vl`.
92///
93/// `vs2` elements are SEW-wide; the accumulator and result (`vs1[0]` / `vd[0]`) are 2*SEW-wide.
94/// The `sign_extend_src` flag controls whether each `vs2` element is sign-extended (`vwredsum`)
95/// or zero-extended (`vwredsumu`) to 2*SEW before accumulation.
96///
97/// Per spec §14.2, the result SEW for the destination is 2*SEW, stored as a single element in
98/// `vd[0]` using the widened width. `vs1[0]` is also read at 2*SEW.
99///
100/// # Safety
101/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
102/// - `2 * sew.bits() <= ELEN` (verified by caller - widening constraint)
103/// - `vl <= group_regs * VLENB / sew_bytes`
104/// - `vl <= VLEN`
105#[inline(always)]
106#[expect(clippy::too_many_arguments, reason = "Internal API")]
107#[doc(hidden)]
108pub unsafe fn execute_widening_reduce_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
109    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
110    vd: VReg,
111    vs2: VReg,
112    vs1: VReg,
113    vm: bool,
114    vl: u32,
115    vstart: u32,
116    sew: Vsew,
117    op: F,
118    sign_extend_src: bool,
119) where
120    Reg: Register,
121    [(); Reg::N]:,
122    ExtState: VectorRegistersExt<Reg, CustomError>,
123    [(); ExtState::ELEN as usize]:,
124    [(); ExtState::VLEN as usize]:,
125    [(); ExtState::VLENB as usize]:,
126    Memory: VirtualMemory,
127    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
128    CustomError: fmt::Debug,
129    F: Fn(u64, u64, Vsew) -> u64,
130{
131    // The widened SEW for vs1/vd operands
132    // Caller guarantees `2 * sew.bits() <= ELEN <= 64`, so this fits in u8
133    let wide_sew = match sew {
134        Vsew::E8 => Vsew::E16,
135        Vsew::E16 => Vsew::E32,
136        // E32 -> E64 is the max widening allowed under ELEN=64
137        Vsew::E32 => Vsew::E64,
138        // E64 widening would require 128-bit result; caller must reject this before entry
139        Vsew::E64 => {
140            // SAFETY: caller verified `2*SEW <= ELEN`; E64 widening is unreachable here
141            unsafe { core::hint::unreachable_unchecked() }
142        }
143    };
144
145    // Read initial accumulator from vs1[0] at 2*SEW
146    // SAFETY: element 0 always fits within register vs1
147    let init = unsafe {
148        read_element_u64(
149            state.ext_state.read_vreg(),
150            usize::from(vs1.bits()),
151            0,
152            wide_sew,
153        )
154    };
155
156    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
157    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
158
159    let vs2_base = usize::from(vs2.bits());
160
161    let mut acc = init;
162    for i in vstart..vl {
163        if !mask_bit(&mask_buf, i) {
164            continue;
165        }
166        // SAFETY: same bounds argument as `execute_reduce_op`
167        let raw = unsafe { read_element_u64(state.ext_state.read_vreg(), vs2_base, i, sew) };
168        // Widen element to 2*SEW
169        let elem = if sign_extend_src {
170            sign_extend(raw, sew).cast_unsigned()
171        } else {
172            // Zero-extension: raw is already zero-extended to u64 by read_element_u64
173            raw
174        };
175        acc = op(acc, elem, wide_sew);
176    }
177
178    // Write result to vd[0] at 2*SEW
179    // SAFETY: element 0 always fits within register vd
180    unsafe {
181        write_element_u64(state.ext_state.write_vreg(), vd.bits(), 0, wide_sew, acc);
182    }
183
184    state.ext_state.mark_vs_dirty();
185    state.ext_state.reset_vstart();
186}