ab_riscv_interpreter/zvbc/zvbc_helpers.rs
1//! Opaque helpers for Zvbc extension
2
3use crate::rv64::b::zbc::rv64_zbc_helpers;
4use crate::v::vector_registers::VectorRegistersExt;
5pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
6use crate::v::zvexx::arith::zvexx_arith_helpers::{read_element_u64, sew_mask, write_element_u64};
7use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10
11/// Lower SEW bits of the carry-less product of two SEW-wide values.
12///
13/// Both inputs are masked to SEW bits before the multiplication so that the VX form (where
14/// the scalar register may carry bits above the SEW boundary) behaves identically to the VV
15/// form (where `read_element_u64` already zero-extends elements to exactly SEW bits).
16#[inline(always)]
17fn vclmul_element(a: u64, b: u64, sew: Vsew) -> u64 {
18 let mask = sew_mask(sew);
19 let a = a & mask;
20 let b = b & mask;
21 rv64_zbc_helpers::clmul(a, b) & mask
22}
23
24/// Upper SEW bits of the carry-less product of two SEW-wide values.
25///
26/// Both inputs are masked to SEW bits (see [`vclmul_element()`] for rationale).
27///
28/// For SEW < 64, the product fits in 64 bits; the upper half lives at bits
29/// `[2*SEW-1 : SEW]` of `clmul(a, b)`. `clmulh` would return 0 for SEW-bit inputs
30/// since the product never reaches bit 64.
31/// For SEW = 64, `clmulh` directly returns the upper half of the 128-bit product.
32#[inline(always)]
33fn vclmulh_element(a: u64, b: u64, sew: Vsew) -> u64 {
34 let mask = sew_mask(sew);
35 let a = a & mask;
36 let b = b & mask;
37 if sew == Vsew::E64 {
38 rv64_zbc_helpers::clmulh(a, b)
39 } else {
40 // The 2*SEW-bit product fits in the 64-bit return value of clmul; extract
41 // bits [2*SEW-1 : SEW] and mask back to SEW bits.
42 (rv64_zbc_helpers::clmul(a, b) >> sew.bits_width()) & mask
43 }
44}
45
46/// Execute element-wise carry-less multiplication (lower half) over `vstart..vl`.
47///
48/// For each active element i: `vd[i] = lower_sew_bits(clmul(vs2[i], src[i]))`.
49///
50/// When `vm=true` all elements are active. When `vm=false` the mask register `v0` gates
51/// each element; masked-off elements are left undisturbed (undisturbed policy).
52///
53/// # Safety
54/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
55/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
56/// - `src` register (if `Vreg`) satisfies the same alignment as `vs2`
57/// - `vl <= group_regs * VLENB / sew_bytes`
58#[inline(always)]
59#[doc(hidden)]
60pub unsafe fn execute_vclmul<Reg, ExtState, CustomError>(
61 ext_state: &mut ExtState,
62 vd: VReg,
63 vs2: VReg,
64 src: OpSrc,
65 sew: Vsew,
66 vm: bool,
67) where
68 Reg: Register,
69 ExtState: VectorRegistersExt<Reg, CustomError>,
70 [(); ExtState::ELEN as usize]:,
71 [(); ExtState::VLEN as usize]:,
72 [(); ExtState::VLENB as usize]:,
73 CustomError: fmt::Debug,
74{
75 let vl = ext_state.vl();
76 let vstart = ext_state.vstart();
77 for i in u32::from(vstart)..vl {
78 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
79 continue;
80 }
81 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
82 // `i < vl <= group_regs * elems_per_reg`, so
83 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
84 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
85 let b = match src {
86 OpSrc::Vreg(vs1_base) => {
87 // SAFETY: caller verified the vs1 register group satisfies the same alignment
88 // constraint as vs2; the index argument is identical, so the same bound holds
89 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
90 }
91 OpSrc::Scalar(val) => val,
92 };
93 let result = vclmul_element(a, b, sew);
94 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
95 // `i < vl <= group_regs * elems_per_reg`, so
96 // `vd + i / elems_per_reg < vd + group_regs <= 32`
97 unsafe {
98 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
99 }
100 }
101 ext_state.mark_vs_dirty();
102 ext_state.reset_vstart();
103}
104
105/// Execute element-wise carry-less multiplication (upper half) over `vstart..vl`.
106///
107/// For each active element i: `vd[i] = upper_sew_bits(clmul(vs2[i], src[i]))`.
108///
109/// When `vm=false`, masked-off elements are left undisturbed.
110///
111/// # Safety
112/// Same register-group constraints as [`execute_vclmul`].
113#[inline(always)]
114#[doc(hidden)]
115pub unsafe fn execute_vclmulh<Reg, ExtState, CustomError>(
116 ext_state: &mut ExtState,
117 vd: VReg,
118 vs2: VReg,
119 src: OpSrc,
120 sew: Vsew,
121 vm: bool,
122) where
123 Reg: Register,
124 ExtState: VectorRegistersExt<Reg, CustomError>,
125 [(); ExtState::ELEN as usize]:,
126 [(); ExtState::VLEN as usize]:,
127 [(); ExtState::VLENB as usize]:,
128 CustomError: fmt::Debug,
129{
130 let vl = ext_state.vl();
131 let vstart = ext_state.vstart();
132 for i in u32::from(vstart)..vl {
133 if !vm && !mask_bit(ext_state.read_vregs().get(VReg::V0), i) {
134 continue;
135 }
136 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32`; `i < vl`
137 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
138 let b = match src {
139 OpSrc::Vreg(vs1_base) => {
140 // SAFETY: same alignment constraint as vs2; same index bound
141 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
142 }
143 OpSrc::Scalar(val) => val,
144 };
145 let result = vclmulh_element(a, b, sew);
146 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32`; `i < vl`
147 unsafe {
148 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
149 }
150 }
151 ext_state.mark_vs_dirty();
152 ext_state.reset_vstart();
153}