ab_riscv_interpreter/v/zvexx/mask/zvexx_mask_helpers.rs
1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zvexx::arith::zvexx_arith_helpers::{write_element_u64, write_mask_bit};
5use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
6use ab_riscv_primitives::prelude::*;
7use core::fmt;
8
9/// Execute a mask-register logical operation (§16.1).
10///
11/// Computes the result for the body elements `[vstart, vl)` only. Prestart bits `[0, vstart)`
12/// are left undisturbed, and tail bits `[vl, VLEN)` follow the tail-agnostic policy, realised
13/// here as undisturbed (a permitted agnostic implementation and the one the reference model
14/// produces). `op` receives `(vs2_bit: bool, vs1_bit: bool) -> bool`.
15///
16/// # Safety
17/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg`).
18/// `vl <= VLEN`, so `(vl - 1) / 8 < VLENB`; `vstart <= vl` by the architectural invariant.
19/// The operation snapshots both sources before writing, so `vd` may safely overlap either source.
20#[inline(always)]
21#[doc(hidden)]
22pub unsafe fn execute_mask_logical_op<Reg, ExtState, CustomError, F>(
23 ext_state: &mut ExtState,
24 vd: VReg,
25 vs2: VReg,
26 vs1: VReg,
27 op: F,
28) where
29 Reg: Register,
30 ExtState: VectorRegistersExt<Reg, CustomError>,
31 [(); ExtState::ELEN as usize]:,
32 [(); ExtState::VLEN as usize]:,
33 [(); ExtState::VLENB as usize]:,
34 CustomError: fmt::Debug,
35 F: Fn(bool, bool) -> bool,
36{
37 let vl = ext_state.vl();
38 let vstart = ext_state.vstart();
39 // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
40 let vs2_snap = *ext_state.read_vregs().get(vs2);
41 let vs1_snap = *ext_state.read_vregs().get(vs1);
42 // Body elements [vstart, vl): compute the logical operation bit-by-bit. Prestart bits
43 // [0, vstart) and tail bits [vl, VLEN) are left undisturbed.
44 for i in u32::from(vstart)..vl {
45 let a = mask_bit(&vs2_snap, i);
46 let b = mask_bit(&vs1_snap, i);
47 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
48 unsafe {
49 write_mask_bit(ext_state.write_vregs(), vd, i, op(a, b));
50 }
51 }
52 ext_state.mark_vs_dirty();
53 ext_state.reset_vstart();
54}
55
56/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
57///
58/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
59/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
60///
61/// # Safety
62/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
63/// - `vstart <= vl`
64///
65/// Returns `rd_value`.
66#[inline(always)]
67#[doc(hidden)]
68pub unsafe fn execute_vcpop<Reg, ExtState, CustomError>(
69 ext_state: &mut ExtState,
70 vs2: VReg,
71 vm: bool,
72) -> Reg::Type
73where
74 Reg: Register,
75 ExtState: VectorRegistersExt<Reg, CustomError>,
76 [(); ExtState::ELEN as usize]:,
77 [(); ExtState::VLEN as usize]:,
78 [(); ExtState::VLENB as usize]:,
79 CustomError: fmt::Debug,
80{
81 let vl = ext_state.vl();
82 let vstart = ext_state.vstart();
83 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
84 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
85 let vs2_reg = *ext_state.read_vregs().get(vs2);
86 let mut count = 0u32;
87 for i in u32::from(vstart)..vl {
88 if !mask_bit(&mask_buf, i) {
89 continue;
90 }
91 if mask_bit(&vs2_reg, i) {
92 count += 1;
93 }
94 }
95
96 ext_state.mark_vs_dirty();
97 ext_state.reset_vstart();
98
99 Reg::Type::from(count)
100}
101
102/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
103/// write result (or -1 if none) to `rd`.
104///
105/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
106/// `-1` (all-ones) if no active element of vs2 is set.
107///
108/// # Safety
109/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
110/// - `vstart <= vl`
111///
112/// Returns `rd_value`.
113#[inline(always)]
114#[doc(hidden)]
115pub unsafe fn execute_vfirst<Reg, ExtState, CustomError>(
116 ext_state: &mut ExtState,
117 vs2: VReg,
118 vm: bool,
119) -> Reg::Type
120where
121 Reg: Register,
122 ExtState: VectorRegistersExt<Reg, CustomError>,
123 [(); ExtState::ELEN as usize]:,
124 [(); ExtState::VLEN as usize]:,
125 [(); ExtState::VLENB as usize]:,
126 CustomError: fmt::Debug,
127{
128 let vl = ext_state.vl();
129 let vstart = ext_state.vstart();
130 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
131 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
132 let vs2_reg = *ext_state.read_vregs().get(vs2);
133 // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
134 let not_found = u64::MAX;
135 let mut result = not_found;
136 for i in u32::from(vstart)..vl {
137 if !mask_bit(&mask_buf, i) {
138 continue;
139 }
140 if mask_bit(&vs2_reg, i) {
141 result = u64::from(i);
142 break;
143 }
144 }
145 // Write -1 (all-ones for XLEN bits) or the found index.
146 // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
147 // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
148 // without depending on `From<u64>` (which is not in the `Register` trait bounds).
149 // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
150 let rd_value = if result == not_found {
151 !Reg::Type::from(0u8)
152 } else {
153 Reg::Type::from(result as u32)
154 };
155 ext_state.mark_vs_dirty();
156 ext_state.reset_vstart();
157
158 rd_value
159}
160
161/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
162///
163/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
164/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
165/// destination bits are cleared.
166///
167/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
168///
169/// # Safety
170/// - `vd` does not overlap `vs2` (checked by caller)
171/// - `vm=false` implies `vd != v0` (checked by caller)
172/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
173#[inline(always)]
174#[doc(hidden)]
175pub unsafe fn execute_vmsbf<Reg, ExtState, CustomError>(
176 ext_state: &mut ExtState,
177 vd: VReg,
178 vs2: VReg,
179 vm: bool,
180 vl: u32,
181) where
182 Reg: Register,
183 ExtState: VectorRegistersExt<Reg, CustomError>,
184 [(); ExtState::ELEN as usize]:,
185 [(); ExtState::VLEN as usize]:,
186 [(); ExtState::VLENB as usize]:,
187 CustomError: fmt::Debug,
188{
189 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
190 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
191 let vs2_snap = *ext_state.read_vregs().get(vs2);
192 let mut found_first = false;
193 for i in 0..vl {
194 // Inactive elements: undisturbed
195 if !mask_bit(&mask_buf, i) {
196 continue;
197 }
198 let vs2_bit = mask_bit(&vs2_snap, i);
199 // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
200 let result = !found_first && !vs2_bit;
201 if vs2_bit {
202 found_first = true;
203 }
204 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
205 unsafe {
206 write_mask_bit(ext_state.write_vregs(), vd, i, result);
207 }
208 }
209 ext_state.mark_vs_dirty();
210 // vstart is already zero, doesn't need to be reset
211}
212
213/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
214///
215/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
216/// vs2 has a set bit. All other active destination bits are cleared.
217///
218/// # Safety
219/// Same as [`execute_vmsbf`].
220#[inline(always)]
221#[doc(hidden)]
222pub unsafe fn execute_vmsof<Reg, ExtState, CustomError>(
223 ext_state: &mut ExtState,
224 vd: VReg,
225 vs2: VReg,
226 vm: bool,
227 vl: u32,
228) where
229 Reg: Register,
230 ExtState: VectorRegistersExt<Reg, CustomError>,
231 [(); ExtState::ELEN as usize]:,
232 [(); ExtState::VLEN as usize]:,
233 [(); ExtState::VLENB as usize]:,
234 CustomError: fmt::Debug,
235{
236 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
237 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
238 let vs2_snap = *ext_state.read_vregs().get(vs2);
239 let mut found_first = false;
240 for i in 0..vl {
241 if !mask_bit(&mask_buf, i) {
242 continue;
243 }
244 let vs2_bit = mask_bit(&vs2_snap, i);
245 // vmsof: set only the first set bit position; clear all others (including after first)
246 let result = !found_first && vs2_bit;
247 if vs2_bit && !found_first {
248 found_first = true;
249 }
250 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
251 unsafe {
252 write_mask_bit(ext_state.write_vregs(), vd, i, result);
253 }
254 }
255 ext_state.mark_vs_dirty();
256 // vstart is already zero, doesn't need to be reset
257}
258
259/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
260///
261/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
262/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
263/// has been seen and the current element is past it.
264///
265/// # Safety
266/// Same as [`execute_vmsbf`].
267#[inline(always)]
268#[doc(hidden)]
269pub unsafe fn execute_vmsif<Reg, ExtState, CustomError>(
270 ext_state: &mut ExtState,
271 vd: VReg,
272 vs2: VReg,
273 vm: bool,
274 vl: u32,
275) where
276 Reg: Register,
277 ExtState: VectorRegistersExt<Reg, CustomError>,
278 [(); ExtState::ELEN as usize]:,
279 [(); ExtState::VLEN as usize]:,
280 [(); ExtState::VLENB as usize]:,
281 CustomError: fmt::Debug,
282{
283 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
284 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
285 let vs2_snap = *ext_state.read_vregs().get(vs2);
286 let mut found_first = false;
287 for i in 0..vl {
288 if !mask_bit(&mask_buf, i) {
289 continue;
290 }
291 let vs2_bit = mask_bit(&vs2_snap, i);
292 // vmsif: set bits up to *and including* the first set bit; clear elements past it
293 let result = !found_first;
294 if vs2_bit {
295 found_first = true;
296 }
297 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
298 unsafe {
299 write_mask_bit(ext_state.write_vregs(), vd, i, result);
300 }
301 }
302 ext_state.mark_vs_dirty();
303 // vstart is already zero, doesn't need to be reset
304}
305
306/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2 at
307/// positions `0..i` (strictly before `i`) as a SEW-wide integer into `vd[i]`.
308///
309/// Per spec §16.8: this instruction honors the source mask; inactive mask elements of vs2 are
310/// treated as zero for the prefix sum. Inactive destination elements follow the mask-agnostic
311/// policy (here implemented as undisturbed, which is a permitted realisation).
312///
313/// If SEW is too narrow to hold the prefix count, the value wraps (truncates to SEW) via
314/// [`write_element_u64()`]; the spec does not raise an exception for this case.
315///
316/// The caller must reject `vstart != 0` before invocation (spec §16.8 mandatory trap).
317///
318/// # Safety
319/// - `vd` does not overlap `vs2` (checked by caller)
320/// - `vm=false` implies `vd != v0` (checked by caller)
321/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
322/// - `vl <= VLMAX`; `vl <= VLEN`
323#[inline(always)]
324#[doc(hidden)]
325pub unsafe fn execute_viota<Reg, ExtState, CustomError>(
326 ext_state: &mut ExtState,
327 vd: VReg,
328 vs2: VReg,
329 vm: bool,
330 vl: u32,
331 sew: Vsew,
332) where
333 Reg: Register,
334 ExtState: VectorRegistersExt<Reg, CustomError>,
335 [(); ExtState::ELEN as usize]:,
336 [(); ExtState::VLEN as usize]:,
337 [(); ExtState::VLENB as usize]:,
338 CustomError: fmt::Debug,
339{
340 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
341 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
342 let vs2_snap = *ext_state.read_vregs().get(vs2);
343 // Per spec §16.8: inactive vs2 elements are treated as zero for the prefix sum.
344 // The prefix count advances only when the execution mask is active AND the
345 // corresponding vs2 bit is set.
346 let mut prefix_count = 0u64;
347 for i in 0..vl {
348 if !mask_bit(&mask_buf, i) {
349 continue;
350 }
351 // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
352 unsafe {
353 write_element_u64(ext_state.write_vregs(), vd, i, sew, prefix_count);
354 }
355 if mask_bit(&vs2_snap, i) {
356 prefix_count += 1;
357 }
358 }
359 ext_state.mark_vs_dirty();
360 ext_state.reset_vstart();
361}
362
363/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
364/// active element in `vstart..vl`.
365///
366/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
367///
368/// # Safety
369/// - `vm=false` implies `vd != v0` (checked by caller)
370/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
371/// - `vl <= group_regs * VLENB / sew_bytes`
372/// - `vl <= VLEN`
373#[inline(always)]
374#[doc(hidden)]
375pub unsafe fn execute_vid<Reg, ExtState, CustomError>(
376 ext_state: &mut ExtState,
377 vd: VReg,
378 vm: bool,
379 sew: Vsew,
380) where
381 Reg: Register,
382 ExtState: VectorRegistersExt<Reg, CustomError>,
383 [(); ExtState::ELEN as usize]:,
384 [(); ExtState::VLEN as usize]:,
385 [(); ExtState::VLENB as usize]:,
386 CustomError: fmt::Debug,
387{
388 let vl = ext_state.vl();
389 let vstart = ext_state.vstart();
390 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
391 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
392 for i in u32::from(vstart)..vl {
393 if !mask_bit(&mask_buf, i) {
394 continue;
395 }
396 // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
397 unsafe {
398 write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(i));
399 }
400 }
401 ext_state.mark_vs_dirty();
402 ext_state.reset_vstart();
403}