// Source: ab_riscv_interpreter/zvbc/zvbc_helpers.rs
//! Opaque helpers for the Zvbc (vector carry-less multiplication) extension.
2
3use crate::rv64::b::zbc::rv64_zbc_helpers;
4use crate::v::vector_registers::VectorRegistersExt;
5pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
6use crate::v::zvexx::arith::zvexx_arith_helpers::{read_element_u64, sew_mask, write_element_u64};
7use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10
11/// Lower SEW bits of the carry-less product of two SEW-wide values.
12///
13/// Both inputs are masked to SEW bits before the multiplication so that the VX form (where
14/// the scalar register may carry bits above the SEW boundary) behaves identically to the VV
15/// form (where `read_element_u64` already zero-extends elements to exactly SEW bits).
16#[inline(always)]
17fn vclmul_element(a: u64, b: u64, sew: Vsew) -> u64 {
18    let mask = sew_mask(sew);
19    let a = a & mask;
20    let b = b & mask;
21    rv64_zbc_helpers::clmul(a, b) & mask
22}
23
24/// Upper SEW bits of the carry-less product of two SEW-wide values.
25///
26/// Both inputs are masked to SEW bits (see [`vclmul_element()`] for rationale).
27///
28/// For SEW < 64, the product fits in 64 bits; the upper half lives at bits
29/// `[2*SEW-1 : SEW]` of `clmul(a, b)`. `clmulh` would return 0 for SEW-bit inputs
30/// since the product never reaches bit 64.
31/// For SEW = 64, `clmulh` directly returns the upper half of the 128-bit product.
32#[inline(always)]
33fn vclmulh_element(a: u64, b: u64, sew: Vsew) -> u64 {
34    let mask = sew_mask(sew);
35    let a = a & mask;
36    let b = b & mask;
37    if sew == Vsew::E64 {
38        rv64_zbc_helpers::clmulh(a, b)
39    } else {
40        // The 2*SEW-bit product fits in the 64-bit return value of clmul; extract
41        // bits [2*SEW-1 : SEW] and mask back to SEW bits.
42        (rv64_zbc_helpers::clmul(a, b) >> sew.bits_width()) & mask
43    }
44}
45
46/// Execute element-wise carry-less multiplication (lower half) over `vstart..vl`.
47///
48/// For each active element i: `vd[i] = lower_sew_bits(clmul(vs2[i], src[i]))`.
49///
50/// When `vm=true` all elements are active. When `vm=false` the mask register `v0` gates
51/// each element; masked-off elements are left undisturbed (undisturbed policy).
52///
53/// # Safety
54/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
55/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
56/// - `src` register (if `Vreg`) satisfies the same alignment as `vs2`
57/// - `vl <= group_regs * VLENB / sew_bytes`
58#[inline(always)]
59#[doc(hidden)]
60pub unsafe fn execute_vclmul<Reg, ExtState, CustomError>(
61    ext_state: &mut ExtState,
62    vd: VReg,
63    vs2: VReg,
64    src: OpSrc,
65    sew: Vsew,
66    vm: bool,
67) where
68    Reg: Register,
69    ExtState: VectorRegistersExt<Reg, CustomError>,
70    [(); ExtState::ELEN as usize]:,
71    [(); ExtState::VLEN as usize]:,
72    [(); ExtState::VLENB as usize]:,
73    CustomError: fmt::Debug,
74{
75    let vl = ext_state.vl();
76    let vstart = ext_state.vstart();
77    for i in u32::from(vstart)..vl {
78        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
79            continue;
80        }
81        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
82        // `i < vl <= group_regs * elems_per_reg`, so
83        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
84        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
85        let b = match src {
86            OpSrc::Vreg(vs1_base) => {
87                // SAFETY: caller verified the vs1 register group satisfies the same alignment
88                // constraint as vs2; the index argument is identical, so the same bound holds
89                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
90            }
91            OpSrc::Scalar(val) => val,
92        };
93        let result = vclmul_element(a, b, sew);
94        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
95        // `i < vl <= group_regs * elems_per_reg`, so
96        // `vd + i / elems_per_reg < vd + group_regs <= 32`
97        unsafe {
98            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
99        }
100    }
101    ext_state.mark_vs_dirty();
102    ext_state.reset_vstart();
103}
104
105/// Execute element-wise carry-less multiplication (upper half) over `vstart..vl`.
106///
107/// For each active element i: `vd[i] = upper_sew_bits(clmul(vs2[i], src[i]))`.
108///
109/// When `vm=false`, masked-off elements are left undisturbed.
110///
111/// # Safety
112/// Same register-group constraints as [`execute_vclmul`].
113#[inline(always)]
114#[doc(hidden)]
115pub unsafe fn execute_vclmulh<Reg, ExtState, CustomError>(
116    ext_state: &mut ExtState,
117    vd: VReg,
118    vs2: VReg,
119    src: OpSrc,
120    sew: Vsew,
121    vm: bool,
122) where
123    Reg: Register,
124    ExtState: VectorRegistersExt<Reg, CustomError>,
125    [(); ExtState::ELEN as usize]:,
126    [(); ExtState::VLEN as usize]:,
127    [(); ExtState::VLENB as usize]:,
128    CustomError: fmt::Debug,
129{
130    let vl = ext_state.vl();
131    let vstart = ext_state.vstart();
132    for i in u32::from(vstart)..vl {
133        if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
134            continue;
135        }
136        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
137        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
138        let b = match src {
139            OpSrc::Vreg(vs1_base) => {
140                // SAFETY: same alignment constraint as vs2; same index bound
141                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
142            }
143            OpSrc::Scalar(val) => val,
144        };
145        let result = vclmulh_element(a, b, sew);
146        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
147        unsafe {
148            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
149        }
150    }
151    ext_state.mark_vs_dirty();
152    ext_state.reset_vstart();
153}