ab_riscv_interpreter/zvbb/zvbb_helpers.rs
1//! Opaque helpers for Zvbb extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zvexx::arith::zvexx_arith_helpers::{read_element_u64, write_element_u64};
6use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute element-wise full bit-reversal over `vstart..vl`, writing SEW-wide results into `vd`.
11///
12/// For each active element i: all bits within `vs2[i]` are reversed end-to-end
13/// (bit 0 <-> bit SEW-1). This differs from `vbrev8`, which reverses bits within each byte while
14/// preserving byte order; `vbrev` also inverts the byte order as a side effect of reversing the
15/// whole element.
16///
17/// When `vm=false`, masked-off elements are left undisturbed.
18///
19/// # Safety
20/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
21/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
22/// - `vl <= group_regs * VLENB / sew_bytes`
23#[inline(always)]
24#[doc(hidden)]
25pub unsafe fn execute_vbrev<Reg, ExtState, CustomError>(
26 ext_state: &mut ExtState,
27 vd: VReg,
28 vs2: VReg,
29 sew: Vsew,
30 vm: bool,
31) where
32 Reg: Register,
33 ExtState: VectorRegistersExt<Reg, CustomError>,
34 [(); ExtState::ELEN as usize]:,
35 [(); ExtState::VLEN as usize]:,
36 [(); ExtState::VLENB as usize]:,
37 CustomError: fmt::Debug,
38{
39 let vl = ext_state.vl();
40 let vstart = ext_state.vstart();
41 for i in u32::from(vstart)..vl {
42 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
43 continue;
44 }
45 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
46 let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
47 // `elem` is zero-extended from SEW bits to u64; reverse_bits() on the primitive type
48 // of exactly SEW width naturally handles the upper zero bits from zero-extension
49 let result = match sew {
50 Vsew::E8 => u64::from((elem as u8).reverse_bits()),
51 Vsew::E16 => u64::from((elem as u16).reverse_bits()),
52 Vsew::E32 => u64::from((elem as u32).reverse_bits()),
53 Vsew::E64 => elem.reverse_bits(),
54 };
55 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
56 unsafe {
57 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
58 }
59 }
60 ext_state.mark_vs_dirty();
61 ext_state.reset_vstart();
62}
63
64/// Execute element-wise count-leading-zeros over `vstart..vl`, writing SEW-wide results into `vd`.
65///
66/// For each active element i: `vd[i] = clz(vs2[i])`, counting within the SEW-wide field. An
67/// all-zero element produces SEW, not 64.
68///
69/// When `vm=false`, masked-off elements are left undisturbed.
70///
71/// # Safety
72/// Same register-group constraints as [`execute_vbrev`].
73#[inline(always)]
74#[doc(hidden)]
75pub unsafe fn execute_vclz<Reg, ExtState, CustomError>(
76 ext_state: &mut ExtState,
77 vd: VReg,
78 vs2: VReg,
79 sew: Vsew,
80 vm: bool,
81) where
82 Reg: Register,
83 ExtState: VectorRegistersExt<Reg, CustomError>,
84 [(); ExtState::ELEN as usize]:,
85 [(); ExtState::VLEN as usize]:,
86 [(); ExtState::VLENB as usize]:,
87 CustomError: fmt::Debug,
88{
89 let vl = ext_state.vl();
90 let vstart = ext_state.vstart();
91 let sew_bits = u32::from(sew.bits_width());
92 for i in u32::from(vstart)..vl {
93 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
94 continue;
95 }
96 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
97 let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
98 // `elem` is zero-extended from SEW bits to u64; `leading_zeros()` on a u64 therefore counts
99 // the extra (64 - SEW) upper zero bits introduced by zero-extension. Subtracting them gives
100 // the count within the SEW-wide field.
101 let clz = elem.leading_zeros() - (64 - sew_bits);
102 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
103 unsafe {
104 write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(clz));
105 }
106 }
107 ext_state.mark_vs_dirty();
108 ext_state.reset_vstart();
109}
110
111/// Execute element-wise count-trailing-zeros over `vstart..vl`, writing SEW-wide results into `vd`.
112///
113/// For each active element i: `vd[i] = ctz(vs2[i])`, counting within the SEW-wide field. An
114/// all-zero element produces SEW, not 64.
115///
116/// When `vm=false`, masked-off elements are left undisturbed.
117///
118/// # Safety
119/// Same register-group constraints as [`execute_vbrev`].
120#[inline(always)]
121#[doc(hidden)]
122pub unsafe fn execute_vctz<Reg, ExtState, CustomError>(
123 ext_state: &mut ExtState,
124 vd: VReg,
125 vs2: VReg,
126 sew: Vsew,
127 vm: bool,
128) where
129 Reg: Register,
130 ExtState: VectorRegistersExt<Reg, CustomError>,
131 [(); ExtState::ELEN as usize]:,
132 [(); ExtState::VLEN as usize]:,
133 [(); ExtState::VLENB as usize]:,
134 CustomError: fmt::Debug,
135{
136 let vl = ext_state.vl();
137 let vstart = ext_state.vstart();
138 let sew_bits = u32::from(sew.bits_width());
139 for i in u32::from(vstart)..vl {
140 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
141 continue;
142 }
143 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
144 let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
145 // For non-zero `elem`, `trailing_zeros()` on the zero-extended u64 value is correct: the
146 // upper zero bits do not affect the trailing count. For zero, `trailing_zeros()` returns
147 // 64, but the spec result is SEW; cap at `sew_bits` handles both cases.
148 let ctz = elem.trailing_zeros().min(sew_bits);
149 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
150 unsafe {
151 write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(ctz));
152 }
153 }
154 ext_state.mark_vs_dirty();
155 ext_state.reset_vstart();
156}
157
158/// Execute element-wise population count over `vstart..vl`, writing SEW-wide results into `vd`.
159///
160/// For each active element i: `vd[i] = popcount(vs2[i])`, in range `[0, SEW]`.
161///
162/// When `vm=false`, masked-off elements are left undisturbed.
163///
164/// # Safety
165/// Same register-group constraints as [`execute_vbrev`].
166#[inline(always)]
167#[doc(hidden)]
168pub unsafe fn execute_vcpop<Reg, ExtState, CustomError>(
169 ext_state: &mut ExtState,
170 vd: VReg,
171 vs2: VReg,
172 sew: Vsew,
173 vm: bool,
174) where
175 Reg: Register,
176 ExtState: VectorRegistersExt<Reg, CustomError>,
177 [(); ExtState::ELEN as usize]:,
178 [(); ExtState::VLEN as usize]:,
179 [(); ExtState::VLENB as usize]:,
180 CustomError: fmt::Debug,
181{
182 let vl = ext_state.vl();
183 let vstart = ext_state.vstart();
184 for i in u32::from(vstart)..vl {
185 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
186 continue;
187 }
188 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
189 let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
190 // `elem` is zero-extended from SEW bits; upper bits are already zero, so `count_ones()`
191 // directly gives the population count within the SEW-wide field
192 let cpop = elem.count_ones();
193 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
194 unsafe {
195 write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(cpop));
196 }
197 }
198 ext_state.mark_vs_dirty();
199 ext_state.reset_vstart();
200}
201
202/// Execute element-wise widening shift-left-logical over `vstart..vl`, writing 2*SEW-wide
203/// results into `vd`.
204///
205/// For each active element i: `vd[i] = zero_extend_to_2SEW(vs2[i]) << (src[i] % (2*SEW))`.
206/// The source operand width is SEW; the destination element width is `double_sew` (2*SEW).
207///
208/// The caller must ensure SEW <= E32 (i.e., `sew.double_width()` is `Some`); passing SEW=E64 is a
209/// programming error that would produce a result wider than u64.
210///
211/// When `vm=false`, masked-off destination elements are left undisturbed.
212///
213/// # Safety
214/// - `vd` register group satisfies alignment for EMUL = 2*LMUL: `vd.to_bits() % dest_group_regs ==
215/// 0` and `vd.to_bits() + dest_group_regs <= 32`
216/// - `vs2` register group satisfies alignment for LMUL
217/// - `src` register (if `Vreg`) satisfies the same alignment as `vs2`
218/// - `vl <= dest_group_regs * VLENB / double_sew_bytes`
219#[inline(always)]
220#[doc(hidden)]
221pub unsafe fn execute_vwsll<Reg, ExtState, CustomError>(
222 ext_state: &mut ExtState,
223 vd: VReg,
224 vs2: VReg,
225 src: OpSrc,
226 sew: Vsew,
227 double_sew: Vsew,
228 vm: bool,
229) where
230 Reg: Register,
231 ExtState: VectorRegistersExt<Reg, CustomError>,
232 [(); ExtState::ELEN as usize]:,
233 [(); ExtState::VLEN as usize]:,
234 [(); ExtState::VLENB as usize]:,
235 CustomError: fmt::Debug,
236{
237 let vl = ext_state.vl();
238 let vstart = ext_state.vstart();
239 // `double_sew_bits` is always a power of two (16, 32, or 64); `& (bits - 1)` is equivalent to
240 // `% bits` and avoids a division
241 let double_sew_bits = u64::from(double_sew.bits_width());
242 for i in u32::from(vstart)..vl {
243 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
244 continue;
245 }
246 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
247 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
248 let amount = match src {
249 OpSrc::Vreg(vs1_base) => {
250 // SAFETY: same alignment constraint as vs2; same index bound
251 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
252 }
253 OpSrc::Scalar(val) => val,
254 };
255 let shift = (amount & (double_sew_bits - 1)) as u32;
256 // `a` is zero-extended from SEW bits; `shift < double_sew_bits <= 64`, so this never shifts
257 // by >= 64. The caller guarantees SEW <= E32, hence `double_sew_bits <= 64`.
258 let result = a << shift;
259 // SAFETY: `vd % dest_group_regs == 0` and `vd + dest_group_regs <= 32`; `i < vl`;
260 // `write_element_u64` with `double_sew` writes exactly 2*SEW bits of `result`
261 unsafe {
262 write_element_u64(ext_state.write_vregs(), vd, i, double_sew, result);
263 }
264 }
265 ext_state.mark_vs_dirty();
266 ext_state.reset_vstart();
267}