Skip to main content

ab_riscv_interpreter/v/zvexx/mask/
zvexx_mask_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zvexx::arith::zvexx_arith_helpers::{write_element_u64, write_mask_bit};
5use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
6use ab_riscv_primitives::prelude::*;
7use core::fmt;
8
9/// Execute a mask-register logical operation (§16.1).
10///
11/// Computes the result for the body elements `[vstart, vl)` only. Prestart bits `[0, vstart)`
12/// are left undisturbed, and tail bits `[vl, VLEN)` follow the tail-agnostic policy, realised
13/// here as undisturbed (a permitted agnostic implementation and the one the reference model
14/// produces). `op` receives `(vs2_bit: bool, vs1_bit: bool) -> bool`.
15///
16/// # Safety
17/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg`).
18/// `vl <= VLEN`, so `(vl - 1) / 8 < VLENB`; `vstart <= vl` by the architectural invariant.
19/// The operation snapshots both sources before writing, so `vd` may safely overlap either source.
20#[inline(always)]
21#[doc(hidden)]
22pub unsafe fn execute_mask_logical_op<Reg, ExtState, CustomError, F>(
23    ext_state: &mut ExtState,
24    vd: VReg,
25    vs2: VReg,
26    vs1: VReg,
27    op: F,
28) where
29    Reg: Register,
30    ExtState: VectorRegistersExt<Reg, CustomError>,
31    [(); ExtState::ELEN as usize]:,
32    [(); ExtState::VLEN as usize]:,
33    [(); ExtState::VLENB as usize]:,
34    CustomError: fmt::Debug,
35    F: Fn(bool, bool) -> bool,
36{
37    let vl = ext_state.vl();
38    let vstart = ext_state.vstart();
39    // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
40    let vs2_snap = *ext_state.read_vregs().get(vs2);
41    let vs1_snap = *ext_state.read_vregs().get(vs1);
42    // Body elements [vstart, vl): compute the logical operation bit-by-bit. Prestart bits
43    // [0, vstart) and tail bits [vl, VLEN) are left undisturbed.
44    for i in u32::from(vstart)..vl {
45        let a = mask_bit(&vs2_snap, i);
46        let b = mask_bit(&vs1_snap, i);
47        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
48        unsafe {
49            write_mask_bit(ext_state.write_vregs(), vd, i, op(a, b));
50        }
51    }
52    ext_state.mark_vs_dirty();
53    ext_state.reset_vstart();
54}
55
56/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
57///
58/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
59/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
60///
61/// # Safety
62/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
63/// - `vstart <= vl`
64///
65/// Returns `rd_value`.
66#[inline(always)]
67#[doc(hidden)]
68pub unsafe fn execute_vcpop<Reg, ExtState, CustomError>(
69    ext_state: &mut ExtState,
70    vs2: VReg,
71    vm: bool,
72) -> Reg::Type
73where
74    Reg: Register,
75    ExtState: VectorRegistersExt<Reg, CustomError>,
76    [(); ExtState::ELEN as usize]:,
77    [(); ExtState::VLEN as usize]:,
78    [(); ExtState::VLENB as usize]:,
79    CustomError: fmt::Debug,
80{
81    let vl = ext_state.vl();
82    let vstart = ext_state.vstart();
83    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
84    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
85    let vs2_reg = *ext_state.read_vregs().get(vs2);
86    let mut count = 0u32;
87    for i in u32::from(vstart)..vl {
88        if !mask_bit(&mask_buf, i) {
89            continue;
90        }
91        if mask_bit(&vs2_reg, i) {
92            count += 1;
93        }
94    }
95
96    ext_state.mark_vs_dirty();
97    ext_state.reset_vstart();
98
99    Reg::Type::from(count)
100}
101
102/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
103/// write result (or -1 if none) to `rd`.
104///
105/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
106/// `-1` (all-ones) if no active element of vs2 is set.
107///
108/// # Safety
109/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
110/// - `vstart <= vl`
111///
112/// Returns `rd_value`.
113#[inline(always)]
114#[doc(hidden)]
115pub unsafe fn execute_vfirst<Reg, ExtState, CustomError>(
116    ext_state: &mut ExtState,
117    vs2: VReg,
118    vm: bool,
119) -> Reg::Type
120where
121    Reg: Register,
122    ExtState: VectorRegistersExt<Reg, CustomError>,
123    [(); ExtState::ELEN as usize]:,
124    [(); ExtState::VLEN as usize]:,
125    [(); ExtState::VLENB as usize]:,
126    CustomError: fmt::Debug,
127{
128    let vl = ext_state.vl();
129    let vstart = ext_state.vstart();
130    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
131    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
132    let vs2_reg = *ext_state.read_vregs().get(vs2);
133    // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
134    let not_found = u64::MAX;
135    let mut result = not_found;
136    for i in u32::from(vstart)..vl {
137        if !mask_bit(&mask_buf, i) {
138            continue;
139        }
140        if mask_bit(&vs2_reg, i) {
141            result = u64::from(i);
142            break;
143        }
144    }
145    // Write -1 (all-ones for XLEN bits) or the found index.
146    // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
147    // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
148    // without depending on `From<u64>` (which is not in the `Register` trait bounds).
149    // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
150    let rd_value = if result == not_found {
151        !Reg::Type::from(0u8)
152    } else {
153        Reg::Type::from(result as u32)
154    };
155    ext_state.mark_vs_dirty();
156    ext_state.reset_vstart();
157
158    rd_value
159}
160
161/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
162///
163/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
164/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
165/// destination bits are cleared.
166///
167/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
168///
169/// # Safety
170/// - `vd` does not overlap `vs2` (checked by caller)
171/// - `vm=false` implies `vd != v0` (checked by caller)
172/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
173#[inline(always)]
174#[doc(hidden)]
175pub unsafe fn execute_vmsbf<Reg, ExtState, CustomError>(
176    ext_state: &mut ExtState,
177    vd: VReg,
178    vs2: VReg,
179    vm: bool,
180    vl: u32,
181) where
182    Reg: Register,
183    ExtState: VectorRegistersExt<Reg, CustomError>,
184    [(); ExtState::ELEN as usize]:,
185    [(); ExtState::VLEN as usize]:,
186    [(); ExtState::VLENB as usize]:,
187    CustomError: fmt::Debug,
188{
189    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
190    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
191    let vs2_snap = *ext_state.read_vregs().get(vs2);
192    let mut found_first = false;
193    for i in 0..vl {
194        // Inactive elements: undisturbed
195        if !mask_bit(&mask_buf, i) {
196            continue;
197        }
198        let vs2_bit = mask_bit(&vs2_snap, i);
199        // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
200        let result = !found_first && !vs2_bit;
201        if vs2_bit {
202            found_first = true;
203        }
204        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
205        unsafe {
206            write_mask_bit(ext_state.write_vregs(), vd, i, result);
207        }
208    }
209    ext_state.mark_vs_dirty();
210    // vstart is already zero, doesn't need to be reset
211}
212
213/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
214///
215/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
216/// vs2 has a set bit. All other active destination bits are cleared.
217///
218/// # Safety
219/// Same as [`execute_vmsbf`].
220#[inline(always)]
221#[doc(hidden)]
222pub unsafe fn execute_vmsof<Reg, ExtState, CustomError>(
223    ext_state: &mut ExtState,
224    vd: VReg,
225    vs2: VReg,
226    vm: bool,
227    vl: u32,
228) where
229    Reg: Register,
230    ExtState: VectorRegistersExt<Reg, CustomError>,
231    [(); ExtState::ELEN as usize]:,
232    [(); ExtState::VLEN as usize]:,
233    [(); ExtState::VLENB as usize]:,
234    CustomError: fmt::Debug,
235{
236    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
237    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
238    let vs2_snap = *ext_state.read_vregs().get(vs2);
239    let mut found_first = false;
240    for i in 0..vl {
241        if !mask_bit(&mask_buf, i) {
242            continue;
243        }
244        let vs2_bit = mask_bit(&vs2_snap, i);
245        // vmsof: set only the first set bit position; clear all others (including after first)
246        let result = !found_first && vs2_bit;
247        if vs2_bit && !found_first {
248            found_first = true;
249        }
250        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
251        unsafe {
252            write_mask_bit(ext_state.write_vregs(), vd, i, result);
253        }
254    }
255    ext_state.mark_vs_dirty();
256    // vstart is already zero, doesn't need to be reset
257}
258
259/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
260///
261/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
262/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
263/// has been seen and the current element is past it.
264///
265/// # Safety
266/// Same as [`execute_vmsbf`].
267#[inline(always)]
268#[doc(hidden)]
269pub unsafe fn execute_vmsif<Reg, ExtState, CustomError>(
270    ext_state: &mut ExtState,
271    vd: VReg,
272    vs2: VReg,
273    vm: bool,
274    vl: u32,
275) where
276    Reg: Register,
277    ExtState: VectorRegistersExt<Reg, CustomError>,
278    [(); ExtState::ELEN as usize]:,
279    [(); ExtState::VLEN as usize]:,
280    [(); ExtState::VLENB as usize]:,
281    CustomError: fmt::Debug,
282{
283    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
284    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
285    let vs2_snap = *ext_state.read_vregs().get(vs2);
286    let mut found_first = false;
287    for i in 0..vl {
288        if !mask_bit(&mask_buf, i) {
289            continue;
290        }
291        let vs2_bit = mask_bit(&vs2_snap, i);
292        // vmsif: set bits up to *and including* the first set bit; clear elements past it
293        let result = !found_first;
294        if vs2_bit {
295            found_first = true;
296        }
297        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
298        unsafe {
299            write_mask_bit(ext_state.write_vregs(), vd, i, result);
300        }
301    }
302    ext_state.mark_vs_dirty();
303    // vstart is already zero, doesn't need to be reset
304}
305
306/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2 at
307/// positions `0..i` (strictly before `i`) as a SEW-wide integer into `vd[i]`.
308///
309/// Per spec §16.8: this instruction honors the source mask; inactive mask elements of vs2 are
310/// treated as zero for the prefix sum. Inactive destination elements follow the mask-agnostic
311/// policy (here implemented as undisturbed, which is a permitted realisation).
312///
313/// If SEW is too narrow to hold the prefix count, the value wraps (truncates to SEW) via
314/// [`write_element_u64()`]; the spec does not raise an exception for this case.
315///
316/// The caller must reject `vstart != 0` before invocation (spec §16.8 mandatory trap).
317///
318/// # Safety
319/// - `vd` does not overlap `vs2` (checked by caller)
320/// - `vm=false` implies `vd != v0` (checked by caller)
321/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
322/// - `vl <= VLMAX`; `vl <= VLEN`
323#[inline(always)]
324#[doc(hidden)]
325pub unsafe fn execute_viota<Reg, ExtState, CustomError>(
326    ext_state: &mut ExtState,
327    vd: VReg,
328    vs2: VReg,
329    vm: bool,
330    vl: u32,
331    sew: Vsew,
332) where
333    Reg: Register,
334    ExtState: VectorRegistersExt<Reg, CustomError>,
335    [(); ExtState::ELEN as usize]:,
336    [(); ExtState::VLEN as usize]:,
337    [(); ExtState::VLENB as usize]:,
338    CustomError: fmt::Debug,
339{
340    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
341    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
342    let vs2_snap = *ext_state.read_vregs().get(vs2);
343    // Per spec §16.8: inactive vs2 elements are treated as zero for the prefix sum.
344    // The prefix count advances only when the execution mask is active AND the
345    // corresponding vs2 bit is set.
346    let mut prefix_count = 0u64;
347    for i in 0..vl {
348        if !mask_bit(&mask_buf, i) {
349            continue;
350        }
351        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
352        unsafe {
353            write_element_u64(ext_state.write_vregs(), vd, i, sew, prefix_count);
354        }
355        if mask_bit(&vs2_snap, i) {
356            prefix_count += 1;
357        }
358    }
359    ext_state.mark_vs_dirty();
360    ext_state.reset_vstart();
361}
362
363/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
364/// active element in `vstart..vl`.
365///
366/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
367///
368/// # Safety
369/// - `vm=false` implies `vd != v0` (checked by caller)
370/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
371/// - `vl <= group_regs * VLENB / sew_bytes`
372/// - `vl <= VLEN`
373#[inline(always)]
374#[doc(hidden)]
375pub unsafe fn execute_vid<Reg, ExtState, CustomError>(
376    ext_state: &mut ExtState,
377    vd: VReg,
378    vm: bool,
379    sew: Vsew,
380) where
381    Reg: Register,
382    ExtState: VectorRegistersExt<Reg, CustomError>,
383    [(); ExtState::ELEN as usize]:,
384    [(); ExtState::VLEN as usize]:,
385    [(); ExtState::VLENB as usize]:,
386    CustomError: fmt::Debug,
387{
388    let vl = ext_state.vl();
389    let vstart = ext_state.vstart();
390    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
391    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
392    for i in u32::from(vstart)..vl {
393        if !mask_bit(&mask_buf, i) {
394            continue;
395        }
396        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
397        unsafe {
398            write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(i));
399        }
400    }
401    ext_state.mark_vs_dirty();
402    ext_state.reset_vstart();
403}