Skip to main content

ab_riscv_interpreter/zvbb/
zvbb_helpers.rs

1//! Opaque helpers for Zvbb extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zvexx::arith::zvexx_arith_helpers::{read_element_u64, write_element_u64};
6use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute element-wise full bit-reversal over `vstart..vl`, writing SEW-wide results into `vd`.
11///
12/// For each active element i: all bits within `vs2[i]` are reversed end-to-end
13/// (bit 0 <-> bit SEW-1). This differs from `vbrev8`, which reverses bits within each byte while
14/// preserving byte order; `vbrev` also inverts the byte order as a side effect of reversing the
15/// whole element.
16///
17/// When `vm=false`, masked-off elements are left undisturbed.
18///
19/// # Safety
20/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
21/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
22/// - `vl <= group_regs * VLENB / sew_bytes`
23#[inline(always)]
24#[doc(hidden)]
25pub unsafe fn execute_vbrev<Reg, ExtState, CustomError>(
26    ext_state: &mut ExtState,
27    vd: VReg,
28    vs2: VReg,
29    sew: Vsew,
30    vm: bool,
31) where
32    Reg: Register,
33    ExtState: VectorRegistersExt<Reg, CustomError>,
34    [(); ExtState::ELEN as usize]:,
35    [(); ExtState::VLEN as usize]:,
36    [(); ExtState::VLENB as usize]:,
37    CustomError: fmt::Debug,
38{
39    let vl = ext_state.vl();
40    let vstart = ext_state.vstart();
41    for i in u32::from(vstart)..vl {
42        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
43            continue;
44        }
45        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
46        let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
47        // `elem` is zero-extended from SEW bits to u64; reverse_bits() on the primitive type
48        // of exactly SEW width naturally handles the upper zero bits from zero-extension
49        let result = match sew {
50            Vsew::E8 => u64::from((elem as u8).reverse_bits()),
51            Vsew::E16 => u64::from((elem as u16).reverse_bits()),
52            Vsew::E32 => u64::from((elem as u32).reverse_bits()),
53            Vsew::E64 => elem.reverse_bits(),
54        };
55        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
56        unsafe {
57            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
58        }
59    }
60    ext_state.mark_vs_dirty();
61    ext_state.reset_vstart();
62}
63
64/// Execute element-wise count-leading-zeros over `vstart..vl`, writing SEW-wide results into `vd`.
65///
66/// For each active element i: `vd[i] = clz(vs2[i])`, counting within the SEW-wide field. An
67/// all-zero element produces SEW, not 64.
68///
69/// When `vm=false`, masked-off elements are left undisturbed.
70///
71/// # Safety
72/// Same register-group constraints as [`execute_vbrev`].
73#[inline(always)]
74#[doc(hidden)]
75pub unsafe fn execute_vclz<Reg, ExtState, CustomError>(
76    ext_state: &mut ExtState,
77    vd: VReg,
78    vs2: VReg,
79    sew: Vsew,
80    vm: bool,
81) where
82    Reg: Register,
83    ExtState: VectorRegistersExt<Reg, CustomError>,
84    [(); ExtState::ELEN as usize]:,
85    [(); ExtState::VLEN as usize]:,
86    [(); ExtState::VLENB as usize]:,
87    CustomError: fmt::Debug,
88{
89    let vl = ext_state.vl();
90    let vstart = ext_state.vstart();
91    let sew_bits = u32::from(sew.bits_width());
92    for i in u32::from(vstart)..vl {
93        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
94            continue;
95        }
96        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
97        let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
98        // `elem` is zero-extended from SEW bits to u64; `leading_zeros()` on a u64 therefore counts
99        // the extra (64 - SEW) upper zero bits introduced by zero-extension. Subtracting them gives
100        // the count within the SEW-wide field.
101        let clz = elem.leading_zeros() - (64 - sew_bits);
102        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
103        unsafe {
104            write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(clz));
105        }
106    }
107    ext_state.mark_vs_dirty();
108    ext_state.reset_vstart();
109}
110
111/// Execute element-wise count-trailing-zeros over `vstart..vl`, writing SEW-wide results into `vd`.
112///
113/// For each active element i: `vd[i] = ctz(vs2[i])`, counting within the SEW-wide field. An
114/// all-zero element produces SEW, not 64.
115///
116/// When `vm=false`, masked-off elements are left undisturbed.
117///
118/// # Safety
119/// Same register-group constraints as [`execute_vbrev`].
120#[inline(always)]
121#[doc(hidden)]
122pub unsafe fn execute_vctz<Reg, ExtState, CustomError>(
123    ext_state: &mut ExtState,
124    vd: VReg,
125    vs2: VReg,
126    sew: Vsew,
127    vm: bool,
128) where
129    Reg: Register,
130    ExtState: VectorRegistersExt<Reg, CustomError>,
131    [(); ExtState::ELEN as usize]:,
132    [(); ExtState::VLEN as usize]:,
133    [(); ExtState::VLENB as usize]:,
134    CustomError: fmt::Debug,
135{
136    let vl = ext_state.vl();
137    let vstart = ext_state.vstart();
138    let sew_bits = u32::from(sew.bits_width());
139    for i in u32::from(vstart)..vl {
140        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
141            continue;
142        }
143        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
144        let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
145        // For non-zero `elem`, `trailing_zeros()` on the zero-extended u64 value is correct: the
146        // upper zero bits do not affect the trailing count. For zero, `trailing_zeros()` returns
147        // 64, but the spec result is SEW; cap at `sew_bits` handles both cases.
148        let ctz = elem.trailing_zeros().min(sew_bits);
149        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
150        unsafe {
151            write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(ctz));
152        }
153    }
154    ext_state.mark_vs_dirty();
155    ext_state.reset_vstart();
156}
157
158/// Execute element-wise population count over `vstart..vl`, writing SEW-wide results into `vd`.
159///
160/// For each active element i: `vd[i] = popcount(vs2[i])`, in range `[0, SEW]`.
161///
162/// When `vm=false`, masked-off elements are left undisturbed.
163///
164/// # Safety
165/// Same register-group constraints as [`execute_vbrev`].
166#[inline(always)]
167#[doc(hidden)]
168pub unsafe fn execute_vcpop<Reg, ExtState, CustomError>(
169    ext_state: &mut ExtState,
170    vd: VReg,
171    vs2: VReg,
172    sew: Vsew,
173    vm: bool,
174) where
175    Reg: Register,
176    ExtState: VectorRegistersExt<Reg, CustomError>,
177    [(); ExtState::ELEN as usize]:,
178    [(); ExtState::VLEN as usize]:,
179    [(); ExtState::VLENB as usize]:,
180    CustomError: fmt::Debug,
181{
182    let vl = ext_state.vl();
183    let vstart = ext_state.vstart();
184    for i in u32::from(vstart)..vl {
185        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
186            continue;
187        }
188        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
189        let elem = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
190        // `elem` is zero-extended from SEW bits; upper bits are already zero, so `count_ones()`
191        // directly gives the population count within the SEW-wide field
192        let cpop = elem.count_ones();
193        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
194        unsafe {
195            write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(cpop));
196        }
197    }
198    ext_state.mark_vs_dirty();
199    ext_state.reset_vstart();
200}
201
202/// Execute element-wise widening shift-left-logical over `vstart..vl`, writing 2*SEW-wide
203/// results into `vd`.
204///
205/// For each active element i: `vd[i] = zero_extend_to_2SEW(vs2[i]) << (src[i] % (2*SEW))`.
206/// The source operand width is SEW; the destination element width is `double_sew` (2*SEW).
207///
208/// The caller must ensure SEW <= E32 (i.e., `sew.double_width()` is `Some`); passing SEW=E64 is a
209/// programming error that would produce a result wider than u64.
210///
211/// When `vm=false`, masked-off destination elements are left undisturbed.
212///
213/// # Safety
214/// - `vd` register group satisfies alignment for EMUL = 2*LMUL: `vd.to_bits() % dest_group_regs ==
215///   0` and `vd.to_bits() + dest_group_regs <= 32`
216/// - `vs2` register group satisfies alignment for LMUL
217/// - `src` register (if `Vreg`) satisfies the same alignment as `vs2`
218/// - `vl <= dest_group_regs * VLENB / double_sew_bytes`
219#[inline(always)]
220#[doc(hidden)]
221pub unsafe fn execute_vwsll<Reg, ExtState, CustomError>(
222    ext_state: &mut ExtState,
223    vd: VReg,
224    vs2: VReg,
225    src: OpSrc,
226    sew: Vsew,
227    double_sew: Vsew,
228    vm: bool,
229) where
230    Reg: Register,
231    ExtState: VectorRegistersExt<Reg, CustomError>,
232    [(); ExtState::ELEN as usize]:,
233    [(); ExtState::VLEN as usize]:,
234    [(); ExtState::VLENB as usize]:,
235    CustomError: fmt::Debug,
236{
237    let vl = ext_state.vl();
238    let vstart = ext_state.vstart();
239    // `double_sew_bits` is always a power of two (16, 32, or 64); `& (bits - 1)` is equivalent to
240    // `% bits` and avoids a division
241    let double_sew_bits = u64::from(double_sew.bits_width());
242    for i in u32::from(vstart)..vl {
243        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
244            continue;
245        }
246        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
247        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
248        let amount = match src {
249            OpSrc::Vreg(vs1_base) => {
250                // SAFETY: same alignment constraint as vs2; same index bound
251                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
252            }
253            OpSrc::Scalar(val) => val,
254        };
255        let shift = (amount & (double_sew_bits - 1)) as u32;
256        // `a` is zero-extended from SEW bits; `shift < double_sew_bits <= 64`, so this never shifts
257        // by >= 64. The caller guarantees SEW <= E32, hence `double_sew_bits <= 64`.
258        let result = a << shift;
259        // SAFETY: `vd % dest_group_regs == 0` and `vd + dest_group_regs <= 32`; `i < vl`;
260        // `write_element_u64` with `double_sew` writes exactly 2*SEW bits of `result`
261        unsafe {
262            write_element_u64(ext_state.write_vregs(), vd, i, double_sew, result);
263        }
264    }
265    ext_state.mark_vs_dirty();
266    ext_state.reset_vstart();
267}