ab_riscv_interpreter/v/zve64x/mask/zve64x_mask_helpers.rs
1//! Opaque helpers for Zve64x extension
2
3use crate::RegisterFile;
4use crate::v::vector_registers::VectorRegistersExt;
5use crate::v::zve64x::arith::zve64x_arith_helpers::{write_element_u64, write_mask_bit};
6use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute a mask-register logical operation on the full `VLENB`-byte mask registers.
11///
12/// Operates on the entire register width independent of `vl` or `vtype`, per spec §16.1.
13/// `op` receives `(vs2_byte: u8, vs1_byte: u8) -> u8`.
14///
15/// # Safety
16/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg` type).
17/// The operation snaps both sources before writing, so `vd` may safely overlap either source.
18#[inline(always)]
19#[doc(hidden)]
20pub unsafe fn execute_mask_logical_op<Reg, ExtState, CustomError, F>(
21 ext_state: &mut ExtState,
22 vd: VReg,
23 vs2: VReg,
24 vs1: VReg,
25 op: F,
26) where
27 Reg: Register,
28 ExtState: VectorRegistersExt<Reg, CustomError>,
29 [(); ExtState::ELEN as usize]:,
30 [(); ExtState::VLEN as usize]:,
31 [(); ExtState::VLENB as usize]:,
32 CustomError: fmt::Debug,
33 F: Fn(u8, u8) -> u8,
34{
35 // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
36 // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
37 let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
38 // SAFETY: `vs1.bits() < 32`; `VReg` values are always in `0..32`
39 let vs1_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs1.bits())) };
40 // SAFETY: `vd.bits() < 32`; `VReg` values are always in `0..32`
41 let vd_reg = unsafe {
42 ext_state
43 .write_vreg()
44 .get_unchecked_mut(usize::from(vd.bits()))
45 };
46 for byte_i in 0..ExtState::VLENB as usize {
47 // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
48 // `vs2_snap` / `vs1_snap` are `[u8; VLENB]` arrays
49 let a = unsafe { *vs2_snap.get_unchecked(byte_i) };
50 // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
51 // `vs2_snap` / `vs1_snap` are `[u8; VLENB]` arrays
52 let b = unsafe { *vs1_snap.get_unchecked(byte_i) };
53 // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
54 // `vd_reg` points to a `[u8; VLENB]` register row
55 *unsafe { vd_reg.get_unchecked_mut(byte_i) } = op(a, b);
56 }
57 ext_state.mark_vs_dirty();
58 ext_state.reset_vstart();
59}
60
61/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
62///
63/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
64/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
65///
66/// # Safety
67/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
68/// - `vstart <= vl`
69#[inline(always)]
70#[doc(hidden)]
71pub unsafe fn execute_vcpop<Reg, Regs, ExtState, CustomError>(
72 regs: &mut Regs,
73 ext_state: &mut ExtState,
74 rd: Reg,
75 vs2: VReg,
76 vm: bool,
77 vl: u32,
78 vstart: u32,
79) where
80 Reg: Register,
81 Regs: RegisterFile<Reg>,
82 ExtState: VectorRegistersExt<Reg, CustomError>,
83 [(); ExtState::ELEN as usize]:,
84 [(); ExtState::VLEN as usize]:,
85 [(); ExtState::VLENB as usize]:,
86 CustomError: fmt::Debug,
87{
88 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
89 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
90 // SAFETY: `vs2` is a valid VReg index (< 32)
91 let vs2_reg = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
92 let mut count = 0u32;
93 for i in vstart..vl {
94 if !mask_bit(&mask_buf, i) {
95 continue;
96 }
97 if mask_bit(&vs2_reg, i) {
98 count += 1;
99 }
100 }
101 regs.write(rd, Reg::Type::from(count));
102 ext_state.mark_vs_dirty();
103 ext_state.reset_vstart();
104}
105
106/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
107/// write result (or -1 if none) to `rd`.
108///
109/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
110/// `-1` (all-ones) if no active element of vs2 is set.
111///
112/// # Safety
113/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
114/// - `vstart <= vl`
115#[inline(always)]
116#[doc(hidden)]
117pub unsafe fn execute_vfirst<Reg, Regs, ExtState, CustomError>(
118 regs: &mut Regs,
119 ext_state: &mut ExtState,
120 rd: Reg,
121 vs2: VReg,
122 vm: bool,
123 vl: u32,
124 vstart: u32,
125) where
126 Reg: Register,
127 Regs: RegisterFile<Reg>,
128 ExtState: VectorRegistersExt<Reg, CustomError>,
129 [(); ExtState::ELEN as usize]:,
130 [(); ExtState::VLEN as usize]:,
131 [(); ExtState::VLENB as usize]:,
132 CustomError: fmt::Debug,
133{
134 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
135 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
136 // SAFETY: `vs2` is a valid VReg index (< 32)
137 let vs2_reg = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
138 // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
139 let not_found = u64::MAX;
140 let mut result = not_found;
141 for i in vstart..vl {
142 if !mask_bit(&mask_buf, i) {
143 continue;
144 }
145 if mask_bit(&vs2_reg, i) {
146 result = u64::from(i);
147 break;
148 }
149 }
150 // Write -1 (all-ones for XLEN bits) or the found index.
151 // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
152 // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
153 // without depending on `From<u64>` (which is not in the `Register` trait bounds).
154 // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
155 let reg_value = if result == not_found {
156 !Reg::Type::from(0u8)
157 } else {
158 Reg::Type::from(result as u32)
159 };
160 regs.write(rd, reg_value);
161 ext_state.mark_vs_dirty();
162 ext_state.reset_vstart();
163}
164
165/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
166///
167/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
168/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
169/// destination bits are cleared.
170///
171/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
172///
173/// # Safety
174/// - `vd` does not overlap `vs2` (checked by caller)
175/// - `vm=false` implies `vd != v0` (checked by caller)
176/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
177#[inline(always)]
178#[doc(hidden)]
179pub unsafe fn execute_vmsbf<Reg, ExtState, CustomError>(
180 ext_state: &mut ExtState,
181 vd: VReg,
182 vs2: VReg,
183 vm: bool,
184 vl: u32,
185) where
186 Reg: Register,
187 ExtState: VectorRegistersExt<Reg, CustomError>,
188 [(); ExtState::ELEN as usize]:,
189 [(); ExtState::VLEN as usize]:,
190 [(); ExtState::VLENB as usize]:,
191 CustomError: fmt::Debug,
192{
193 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
194 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
195 // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
196 let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
197 let mut found_first = false;
198 for i in 0..vl {
199 // Inactive elements: undisturbed
200 if !mask_bit(&mask_buf, i) {
201 continue;
202 }
203 let vs2_bit = mask_bit(&vs2_snap, i);
204 // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
205 let result = !found_first && !vs2_bit;
206 if vs2_bit {
207 found_first = true;
208 }
209 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
210 unsafe { write_mask_bit(ext_state.write_vreg(), vd, i, result) };
211 }
212 ext_state.mark_vs_dirty();
213 // vstart is already zero, doesn't need to be reset
214}
215
216/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
217///
218/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
219/// vs2 has a set bit. All other active destination bits are cleared.
220///
221/// # Safety
222/// Same as [`execute_vmsbf`].
223#[inline(always)]
224#[doc(hidden)]
225pub unsafe fn execute_vmsof<Reg, ExtState, CustomError>(
226 ext_state: &mut ExtState,
227 vd: VReg,
228 vs2: VReg,
229 vm: bool,
230 vl: u32,
231) where
232 Reg: Register,
233 ExtState: VectorRegistersExt<Reg, CustomError>,
234 [(); ExtState::ELEN as usize]:,
235 [(); ExtState::VLEN as usize]:,
236 [(); ExtState::VLENB as usize]:,
237 CustomError: fmt::Debug,
238{
239 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
240 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
241 // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
242 let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
243 let mut found_first = false;
244 for i in 0..vl {
245 if !mask_bit(&mask_buf, i) {
246 continue;
247 }
248 let vs2_bit = mask_bit(&vs2_snap, i);
249 // vmsof: set only the first set bit position; clear all others (including after first)
250 let result = !found_first && vs2_bit;
251 if vs2_bit && !found_first {
252 found_first = true;
253 }
254 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
255 unsafe { write_mask_bit(ext_state.write_vreg(), vd, i, result) };
256 }
257 ext_state.mark_vs_dirty();
258 // vstart is already zero, doesn't need to be reset
259}
260
261/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
262///
263/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
264/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
265/// has been seen and the current element is past it.
266///
267/// # Safety
268/// Same as [`execute_vmsbf`].
269#[inline(always)]
270#[doc(hidden)]
271pub unsafe fn execute_vmsif<Reg, ExtState, CustomError>(
272 ext_state: &mut ExtState,
273 vd: VReg,
274 vs2: VReg,
275 vm: bool,
276 vl: u32,
277) where
278 Reg: Register,
279 ExtState: VectorRegistersExt<Reg, CustomError>,
280 [(); ExtState::ELEN as usize]:,
281 [(); ExtState::VLEN as usize]:,
282 [(); ExtState::VLENB as usize]:,
283 CustomError: fmt::Debug,
284{
285 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
286 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
287 // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
288 let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
289 let mut found_first = false;
290 for i in 0..vl {
291 if !mask_bit(&mask_buf, i) {
292 continue;
293 }
294 let vs2_bit = mask_bit(&vs2_snap, i);
295 // vmsif: set bits up to *and including* the first set bit; clear elements past it
296 let result = !found_first;
297 if vs2_bit {
298 found_first = true;
299 }
300 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
301 unsafe { write_mask_bit(ext_state.write_vreg(), vd, i, result) };
302 }
303 ext_state.mark_vs_dirty();
304 // vstart is already zero, doesn't need to be reset
305}
306
307/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2 at
308/// positions `0..i` (strictly before `i`) as a SEW-wide integer into `vd[i]`.
309///
310/// Per spec §16.8: this instruction honors the source mask; inactive mask elements of vs2 are
311/// treated as zero for the prefix sum. Inactive destination elements follow the mask-agnostic
312/// policy (here implemented as undisturbed, which is a permitted realisation).
313///
314/// The caller must reject `vstart != 0` before invocation (spec §16.8 mandatory trap).
315///
316/// # Safety
317/// - `vd` does not overlap `vs2` (checked by caller)
318/// - `vm=false` implies `vd != v0` (checked by caller)
319/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (checked by caller)
320/// - SEW wide enough to hold VLMAX-1 (checked by caller)
321/// - `vl <= VLMAX`; `vl <= VLEN`
322#[inline(always)]
323#[doc(hidden)]
324pub unsafe fn execute_viota<Reg, ExtState, CustomError>(
325 ext_state: &mut ExtState,
326 vd: VReg,
327 vs2: VReg,
328 vm: bool,
329 vl: u32,
330 sew: Vsew,
331) where
332 Reg: Register,
333 ExtState: VectorRegistersExt<Reg, CustomError>,
334 [(); ExtState::ELEN as usize]:,
335 [(); ExtState::VLEN as usize]:,
336 [(); ExtState::VLENB as usize]:,
337 CustomError: fmt::Debug,
338{
339 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
340 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
341 // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
342 let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
343 // Per spec §16.8: inactive vs2 elements are treated as zero for the prefix sum.
344 // The prefix count advances only when the execution mask is active AND the
345 // corresponding vs2 bit is set.
346 let mut prefix_count = 0u64;
347 for i in 0..vl {
348 if !mask_bit(&mask_buf, i) {
349 continue;
350 }
351 // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
352 unsafe {
353 write_element_u64(ext_state.write_vreg(), vd.bits(), i, sew, prefix_count);
354 }
355 if mask_bit(&vs2_snap, i) {
356 prefix_count += 1;
357 }
358 }
359 ext_state.mark_vs_dirty();
360 ext_state.reset_vstart();
361}
362
363/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
364/// active element in `vstart..vl`.
365///
366/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
367///
368/// # Safety
369/// - `vm=false` implies `vd != v0` (checked by caller)
370/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (checked by caller)
371/// - `vl <= group_regs * VLENB / sew_bytes`
372/// - `vl <= VLEN`
373#[inline(always)]
374#[doc(hidden)]
375pub unsafe fn execute_vid<Reg, ExtState, CustomError>(
376 ext_state: &mut ExtState,
377 vd: VReg,
378 vm: bool,
379 vl: u32,
380 vstart: u32,
381 sew: Vsew,
382) where
383 Reg: Register,
384 ExtState: VectorRegistersExt<Reg, CustomError>,
385 [(); ExtState::ELEN as usize]:,
386 [(); ExtState::VLEN as usize]:,
387 [(); ExtState::VLENB as usize]:,
388 CustomError: fmt::Debug,
389{
390 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
391 let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
392 for i in vstart..vl {
393 if !mask_bit(&mask_buf, i) {
394 continue;
395 }
396 // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
397 unsafe {
398 write_element_u64(ext_state.write_vreg(), vd.bits(), i, sew, u64::from(i));
399 }
400 }
401 ext_state.mark_vs_dirty();
402 ext_state.reset_vstart();
403}