ab_riscv_interpreter/v/zvexx/mask/zvexx_mask_helpers.rs
1//! Opaque helpers for ZveXx extension
2
3use crate::RegisterFile;
4use crate::v::vector_registers::VectorRegistersExt;
5use crate::v::zvexx::arith::zvexx_arith_helpers::{write_element_u64, write_mask_bit};
6use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute a mask-register logical operation (§16.1).
11///
12/// Computes the result for the body elements `[vstart, vl)` only. Prestart bits `[0, vstart)`
13/// are left undisturbed, and tail bits `[vl, VLEN)` follow the tail-agnostic policy, realised
14/// here as undisturbed (a permitted agnostic implementation and the one the reference model
15/// produces). `op` receives `(vs2_bit: bool, vs1_bit: bool) -> bool`.
16///
17/// # Safety
18/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg`).
19/// `vl <= VLEN`, so `(vl - 1) / 8 < VLENB`; `vstart <= vl` by the architectural invariant.
20/// The operation snapshots both sources before writing, so `vd` may safely overlap either source.
21#[inline(always)]
22#[doc(hidden)]
23pub unsafe fn execute_mask_logical_op<Reg, ExtState, CustomError, F>(
24 ext_state: &mut ExtState,
25 vd: VReg,
26 vs2: VReg,
27 vs1: VReg,
28 op: F,
29) where
30 Reg: Register,
31 ExtState: VectorRegistersExt<Reg, CustomError>,
32 [(); ExtState::ELEN as usize]:,
33 [(); ExtState::VLEN as usize]:,
34 [(); ExtState::VLENB as usize]:,
35 CustomError: fmt::Debug,
36 F: Fn(bool, bool) -> bool,
37{
38 let vl = ext_state.vl();
39 let vstart = ext_state.vstart();
40 // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
41 let vs2_snap = *ext_state.read_vregs().get(vs2);
42 let vs1_snap = *ext_state.read_vregs().get(vs1);
43 // Body elements [vstart, vl): compute the logical operation bit-by-bit. Prestart bits
44 // [0, vstart) and tail bits [vl, VLEN) are left undisturbed.
45 for i in u32::from(vstart)..vl {
46 let a = mask_bit(&vs2_snap, i);
47 let b = mask_bit(&vs1_snap, i);
48 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
49 unsafe {
50 write_mask_bit(ext_state.write_vregs(), vd, i, op(a, b));
51 }
52 }
53 ext_state.mark_vs_dirty();
54 ext_state.reset_vstart();
55}
56
57/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
58///
59/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
60/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
61///
62/// # Safety
63/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
64/// - `vstart <= vl`
65#[inline(always)]
66#[doc(hidden)]
67pub unsafe fn execute_vcpop<Reg, Regs, ExtState, CustomError>(
68 regs: &mut Regs,
69 ext_state: &mut ExtState,
70 rd: Reg,
71 vs2: VReg,
72 vm: bool,
73) where
74 Reg: Register,
75 Regs: RegisterFile<Reg>,
76 ExtState: VectorRegistersExt<Reg, CustomError>,
77 [(); ExtState::ELEN as usize]:,
78 [(); ExtState::VLEN as usize]:,
79 [(); ExtState::VLENB as usize]:,
80 CustomError: fmt::Debug,
81{
82 let vl = ext_state.vl();
83 let vstart = ext_state.vstart();
84 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
85 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
86 let vs2_reg = *ext_state.read_vregs().get(vs2);
87 let mut count = 0u32;
88 for i in u32::from(vstart)..vl {
89 if !mask_bit(&mask_buf, i) {
90 continue;
91 }
92 if mask_bit(&vs2_reg, i) {
93 count += 1;
94 }
95 }
96 regs.write(rd, Reg::Type::from(count));
97 ext_state.mark_vs_dirty();
98 ext_state.reset_vstart();
99}
100
101/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
102/// write result (or -1 if none) to `rd`.
103///
104/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
105/// `-1` (all-ones) if no active element of vs2 is set.
106///
107/// # Safety
108/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
109/// - `vstart <= vl`
110#[inline(always)]
111#[doc(hidden)]
112pub unsafe fn execute_vfirst<Reg, Regs, ExtState, CustomError>(
113 regs: &mut Regs,
114 ext_state: &mut ExtState,
115 rd: Reg,
116 vs2: VReg,
117 vm: bool,
118) where
119 Reg: Register,
120 Regs: RegisterFile<Reg>,
121 ExtState: VectorRegistersExt<Reg, CustomError>,
122 [(); ExtState::ELEN as usize]:,
123 [(); ExtState::VLEN as usize]:,
124 [(); ExtState::VLENB as usize]:,
125 CustomError: fmt::Debug,
126{
127 let vl = ext_state.vl();
128 let vstart = ext_state.vstart();
129 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
130 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
131 let vs2_reg = *ext_state.read_vregs().get(vs2);
132 // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
133 let not_found = u64::MAX;
134 let mut result = not_found;
135 for i in u32::from(vstart)..vl {
136 if !mask_bit(&mask_buf, i) {
137 continue;
138 }
139 if mask_bit(&vs2_reg, i) {
140 result = u64::from(i);
141 break;
142 }
143 }
144 // Write -1 (all-ones for XLEN bits) or the found index.
145 // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
146 // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
147 // without depending on `From<u64>` (which is not in the `Register` trait bounds).
148 // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
149 let reg_value = if result == not_found {
150 !Reg::Type::from(0u8)
151 } else {
152 Reg::Type::from(result as u32)
153 };
154 regs.write(rd, reg_value);
155 ext_state.mark_vs_dirty();
156 ext_state.reset_vstart();
157}
158
159/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
160///
161/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
162/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
163/// destination bits are cleared.
164///
165/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
166///
167/// # Safety
168/// - `vd` does not overlap `vs2` (checked by caller)
169/// - `vm=false` implies `vd != v0` (checked by caller)
170/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
171#[inline(always)]
172#[doc(hidden)]
173pub unsafe fn execute_vmsbf<Reg, ExtState, CustomError>(
174 ext_state: &mut ExtState,
175 vd: VReg,
176 vs2: VReg,
177 vm: bool,
178 vl: u32,
179) where
180 Reg: Register,
181 ExtState: VectorRegistersExt<Reg, CustomError>,
182 [(); ExtState::ELEN as usize]:,
183 [(); ExtState::VLEN as usize]:,
184 [(); ExtState::VLENB as usize]:,
185 CustomError: fmt::Debug,
186{
187 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
188 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
189 let vs2_snap = *ext_state.read_vregs().get(vs2);
190 let mut found_first = false;
191 for i in 0..vl {
192 // Inactive elements: undisturbed
193 if !mask_bit(&mask_buf, i) {
194 continue;
195 }
196 let vs2_bit = mask_bit(&vs2_snap, i);
197 // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
198 let result = !found_first && !vs2_bit;
199 if vs2_bit {
200 found_first = true;
201 }
202 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
203 unsafe {
204 write_mask_bit(ext_state.write_vregs(), vd, i, result);
205 }
206 }
207 ext_state.mark_vs_dirty();
208 // vstart is already zero, doesn't need to be reset
209}
210
211/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
212///
213/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
214/// vs2 has a set bit. All other active destination bits are cleared.
215///
216/// # Safety
217/// Same as [`execute_vmsbf`].
218#[inline(always)]
219#[doc(hidden)]
220pub unsafe fn execute_vmsof<Reg, ExtState, CustomError>(
221 ext_state: &mut ExtState,
222 vd: VReg,
223 vs2: VReg,
224 vm: bool,
225 vl: u32,
226) where
227 Reg: Register,
228 ExtState: VectorRegistersExt<Reg, CustomError>,
229 [(); ExtState::ELEN as usize]:,
230 [(); ExtState::VLEN as usize]:,
231 [(); ExtState::VLENB as usize]:,
232 CustomError: fmt::Debug,
233{
234 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
235 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
236 let vs2_snap = *ext_state.read_vregs().get(vs2);
237 let mut found_first = false;
238 for i in 0..vl {
239 if !mask_bit(&mask_buf, i) {
240 continue;
241 }
242 let vs2_bit = mask_bit(&vs2_snap, i);
243 // vmsof: set only the first set bit position; clear all others (including after first)
244 let result = !found_first && vs2_bit;
245 if vs2_bit && !found_first {
246 found_first = true;
247 }
248 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
249 unsafe {
250 write_mask_bit(ext_state.write_vregs(), vd, i, result);
251 }
252 }
253 ext_state.mark_vs_dirty();
254 // vstart is already zero, doesn't need to be reset
255}
256
257/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
258///
259/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
260/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
261/// has been seen and the current element is past it.
262///
263/// # Safety
264/// Same as [`execute_vmsbf`].
265#[inline(always)]
266#[doc(hidden)]
267pub unsafe fn execute_vmsif<Reg, ExtState, CustomError>(
268 ext_state: &mut ExtState,
269 vd: VReg,
270 vs2: VReg,
271 vm: bool,
272 vl: u32,
273) where
274 Reg: Register,
275 ExtState: VectorRegistersExt<Reg, CustomError>,
276 [(); ExtState::ELEN as usize]:,
277 [(); ExtState::VLEN as usize]:,
278 [(); ExtState::VLENB as usize]:,
279 CustomError: fmt::Debug,
280{
281 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
282 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
283 let vs2_snap = *ext_state.read_vregs().get(vs2);
284 let mut found_first = false;
285 for i in 0..vl {
286 if !mask_bit(&mask_buf, i) {
287 continue;
288 }
289 let vs2_bit = mask_bit(&vs2_snap, i);
290 // vmsif: set bits up to *and including* the first set bit; clear elements past it
291 let result = !found_first;
292 if vs2_bit {
293 found_first = true;
294 }
295 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
296 unsafe {
297 write_mask_bit(ext_state.write_vregs(), vd, i, result);
298 }
299 }
300 ext_state.mark_vs_dirty();
301 // vstart is already zero, doesn't need to be reset
302}
303
304/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2 at
305/// positions `0..i` (strictly before `i`) as a SEW-wide integer into `vd[i]`.
306///
307/// Per spec §16.8: this instruction honors the source mask; inactive mask elements of vs2 are
308/// treated as zero for the prefix sum. Inactive destination elements follow the mask-agnostic
309/// policy (here implemented as undisturbed, which is a permitted realisation).
310///
311/// If SEW is too narrow to hold the prefix count, the value wraps (truncates to SEW) via
312/// [`write_element_u64()`]; the spec does not raise an exception for this case.
313///
314/// The caller must reject `vstart != 0` before invocation (spec §16.8 mandatory trap).
315///
316/// # Safety
317/// - `vd` does not overlap `vs2` (checked by caller)
318/// - `vm=false` implies `vd != v0` (checked by caller)
319/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
320/// - `vl <= VLMAX`; `vl <= VLEN`
321#[inline(always)]
322#[doc(hidden)]
323pub unsafe fn execute_viota<Reg, ExtState, CustomError>(
324 ext_state: &mut ExtState,
325 vd: VReg,
326 vs2: VReg,
327 vm: bool,
328 vl: u32,
329 sew: Vsew,
330) where
331 Reg: Register,
332 ExtState: VectorRegistersExt<Reg, CustomError>,
333 [(); ExtState::ELEN as usize]:,
334 [(); ExtState::VLEN as usize]:,
335 [(); ExtState::VLENB as usize]:,
336 CustomError: fmt::Debug,
337{
338 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
339 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
340 let vs2_snap = *ext_state.read_vregs().get(vs2);
341 // Per spec §16.8: inactive vs2 elements are treated as zero for the prefix sum.
342 // The prefix count advances only when the execution mask is active AND the
343 // corresponding vs2 bit is set.
344 let mut prefix_count = 0u64;
345 for i in 0..vl {
346 if !mask_bit(&mask_buf, i) {
347 continue;
348 }
349 // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
350 unsafe {
351 write_element_u64(ext_state.write_vregs(), vd, i, sew, prefix_count);
352 }
353 if mask_bit(&vs2_snap, i) {
354 prefix_count += 1;
355 }
356 }
357 ext_state.mark_vs_dirty();
358 ext_state.reset_vstart();
359}
360
361/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
362/// active element in `vstart..vl`.
363///
364/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
365///
366/// # Safety
367/// - `vm=false` implies `vd != v0` (checked by caller)
368/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
369/// - `vl <= group_regs * VLENB / sew_bytes`
370/// - `vl <= VLEN`
371#[inline(always)]
372#[doc(hidden)]
373pub unsafe fn execute_vid<Reg, ExtState, CustomError>(
374 ext_state: &mut ExtState,
375 vd: VReg,
376 vm: bool,
377 sew: Vsew,
378) where
379 Reg: Register,
380 ExtState: VectorRegistersExt<Reg, CustomError>,
381 [(); ExtState::ELEN as usize]:,
382 [(); ExtState::VLEN as usize]:,
383 [(); ExtState::VLENB as usize]:,
384 CustomError: fmt::Debug,
385{
386 let vl = ext_state.vl();
387 let vstart = ext_state.vstart();
388 // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
389 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
390 for i in u32::from(vstart)..vl {
391 if !mask_bit(&mask_buf, i) {
392 continue;
393 }
394 // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
395 unsafe {
396 write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(i));
397 }
398 }
399 ext_state.mark_vs_dirty();
400 ext_state.reset_vstart();
401}