ab_riscv_interpreter/v/zve64x/reduction/zve64x_reduction_helpers.rs
1//! Opaque helpers for Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::arith::zve64x_arith_helpers::{
5 read_element_u64, sign_extend, write_element_u64,
6};
7use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
8use crate::{InterpreterState, ProgramCounter, VirtualMemory};
9use ab_riscv_primitives::instructions::v::Vsew;
10use ab_riscv_primitives::registers::general_purpose::Register;
11use ab_riscv_primitives::registers::vector::VReg;
12use core::fmt;
13
14/// Execute a single-width integer reduction over `vstart..vl`.
15///
16/// The initial accumulator is read from element 0 of `vs1` (always SEW-wide, single register).
17/// Active elements of `vs2` (masked by `vm` / `v0`) are folded into the accumulator using `op`.
18/// The scalar result is written to element 0 of `vd` (single register, always SEW-wide).
19///
20/// When `vl == 0` or all elements are masked out, element 0 of `vs1` is passed through to `vd[0]`
21/// unchanged, per spec §14.1.
22///
23/// `op` receives `(accumulator: u64, element: u64, sew: Vsew) -> u64`. Only the low `sew.bits()`
24/// of the returned value are significant; all arithmetic should be performed within that width.
25///
26/// # Safety
27/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
28/// - `vl <= group_regs * VLENB / sew_bytes`
29/// - `vl <= VLEN`
30#[inline(always)]
31#[expect(clippy::too_many_arguments, reason = "Internal API")]
32#[doc(hidden)]
33pub unsafe fn execute_reduce_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
34 state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
35 vd: VReg,
36 vs2: VReg,
37 vs1: VReg,
38 vm: bool,
39 vl: u32,
40 vstart: u32,
41 sew: Vsew,
42 op: F,
43) where
44 Reg: Register,
45 [(); Reg::N]:,
46 ExtState: VectorRegistersExt<Reg, CustomError>,
47 [(); ExtState::ELEN as usize]:,
48 [(); ExtState::VLEN as usize]:,
49 [(); ExtState::VLENB as usize]:,
50 Memory: VirtualMemory,
51 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
52 CustomError: fmt::Debug,
53 F: Fn(u64, u64, Vsew) -> u64,
54{
55 // Read scalar initial value from vs1[0]; vs1 is always a single register, SEW-wide
56 // SAFETY: element 0 always fits within register vs1 (0 < VLENB / sew_bytes)
57 let init =
58 unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs1.bits()), 0, sew) };
59
60 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
61 let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
62
63 let vs2_base = usize::from(vs2.bits());
64
65 // When vstart > 0 the spec says the reduction resumes from wherever it left off; since
66 // reductions are not restartable in general, vstart != 0 at entry to a reduction instruction
67 // is reserved. We follow the simplest compliant path: treat vstart as the lower bound of
68 // active elements, initialising the accumulator to vs1[0] unconditionally.
69 let mut acc = init;
70 for i in vstart..vl {
71 // Inactive elements are skipped; they do not contribute to the result
72 if !mask_bit(&mask_buf, i) {
73 continue;
74 }
75 // SAFETY: `vs2_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`,
76 // so `vs2_base + i / elems_per_reg < vs2_base + group_regs <= 32`
77 let elem = unsafe { read_element_u64(state.ext_state.read_vreg(), vs2_base, i, sew) };
78 acc = op(acc, elem, sew);
79 }
80
81 // Write scalar result to vd[0]; vd is always a single register, SEW-wide
82 // SAFETY: element 0 always fits within register vd
83 unsafe {
84 write_element_u64(state.ext_state.write_vreg(), vd.bits(), 0, sew, acc);
85 }
86
87 state.ext_state.mark_vs_dirty();
88 state.ext_state.reset_vstart();
89}
90
91/// Execute a widening integer sum reduction over `vstart..vl`.
92///
93/// `vs2` elements are SEW-wide; the accumulator and result (`vs1[0]` / `vd[0]`) are 2*SEW-wide.
94/// The `sign_extend_src` flag controls whether each `vs2` element is sign-extended (`vwredsum`)
95/// or zero-extended (`vwredsumu`) to 2*SEW before accumulation.
96///
97/// Per spec §14.2, the result SEW for the destination is 2*SEW, stored as a single element in
98/// `vd[0]` using the widened width. `vs1[0]` is also read at 2*SEW.
99///
100/// # Safety
101/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
102/// - `2 * sew.bits() <= ELEN` (verified by caller - widening constraint)
103/// - `vl <= group_regs * VLENB / sew_bytes`
104/// - `vl <= VLEN`
105#[inline(always)]
106#[expect(clippy::too_many_arguments, reason = "Internal API")]
107#[doc(hidden)]
108pub unsafe fn execute_widening_reduce_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
109 state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
110 vd: VReg,
111 vs2: VReg,
112 vs1: VReg,
113 vm: bool,
114 vl: u32,
115 vstart: u32,
116 sew: Vsew,
117 op: F,
118 sign_extend_src: bool,
119) where
120 Reg: Register,
121 [(); Reg::N]:,
122 ExtState: VectorRegistersExt<Reg, CustomError>,
123 [(); ExtState::ELEN as usize]:,
124 [(); ExtState::VLEN as usize]:,
125 [(); ExtState::VLENB as usize]:,
126 Memory: VirtualMemory,
127 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
128 CustomError: fmt::Debug,
129 F: Fn(u64, u64, Vsew) -> u64,
130{
131 // The widened SEW for vs1/vd operands
132 // Caller guarantees `2 * sew.bits() <= ELEN <= 64`, so this fits in u8
133 let wide_sew = match sew {
134 Vsew::E8 => Vsew::E16,
135 Vsew::E16 => Vsew::E32,
136 // E32 -> E64 is the max widening allowed under ELEN=64
137 Vsew::E32 => Vsew::E64,
138 // E64 widening would require 128-bit result; caller must reject this before entry
139 Vsew::E64 => {
140 // SAFETY: caller verified `2*SEW <= ELEN`; E64 widening is unreachable here
141 unsafe { core::hint::unreachable_unchecked() }
142 }
143 };
144
145 // Read initial accumulator from vs1[0] at 2*SEW
146 // SAFETY: element 0 always fits within register vs1
147 let init = unsafe {
148 read_element_u64(
149 state.ext_state.read_vreg(),
150 usize::from(vs1.bits()),
151 0,
152 wide_sew,
153 )
154 };
155
156 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
157 let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
158
159 let vs2_base = usize::from(vs2.bits());
160
161 let mut acc = init;
162 for i in vstart..vl {
163 if !mask_bit(&mask_buf, i) {
164 continue;
165 }
166 // SAFETY: same bounds argument as `execute_reduce_op`
167 let raw = unsafe { read_element_u64(state.ext_state.read_vreg(), vs2_base, i, sew) };
168 // Widen element to 2*SEW
169 let elem = if sign_extend_src {
170 sign_extend(raw, sew).cast_unsigned()
171 } else {
172 // Zero-extension: raw is already zero-extended to u64 by read_element_u64
173 raw
174 };
175 acc = op(acc, elem, wide_sew);
176 }
177
178 // Write result to vd[0] at 2*SEW
179 // SAFETY: element 0 always fits within register vd
180 unsafe {
181 write_element_u64(state.ext_state.write_vreg(), vd.bits(), 0, wide_sew, acc);
182 }
183
184 state.ext_state.mark_vs_dirty();
185 state.ext_state.reset_vstart();
186}