Skip to main content

ab_riscv_interpreter/v/zvexx/mask/
zvexx_mask_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::RegisterFile;
4use crate::v::vector_registers::VectorRegistersExt;
5use crate::v::zvexx::arith::zvexx_arith_helpers::{write_element_u64, write_mask_bit};
6use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute a mask-register logical operation (§16.1).
11///
12/// Computes the result for the body elements `[vstart, vl)` only. Prestart bits `[0, vstart)`
13/// are left undisturbed, and tail bits `[vl, VLEN)` follow the tail-agnostic policy, realised
14/// here as undisturbed (a permitted agnostic implementation and the one the reference model
15/// produces). `op` receives `(vs2_bit: bool, vs1_bit: bool) -> bool`.
16///
17/// # Safety
18/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg`).
19/// `vl <= VLEN`, so `(vl - 1) / 8 < VLENB`; `vstart <= vl` by the architectural invariant.
20/// The operation snapshots both sources before writing, so `vd` may safely overlap either source.
21#[inline(always)]
22#[doc(hidden)]
23pub unsafe fn execute_mask_logical_op<Reg, ExtState, CustomError, F>(
24    ext_state: &mut ExtState,
25    vd: VReg,
26    vs2: VReg,
27    vs1: VReg,
28    op: F,
29) where
30    Reg: Register,
31    ExtState: VectorRegistersExt<Reg, CustomError>,
32    [(); ExtState::ELEN as usize]:,
33    [(); ExtState::VLEN as usize]:,
34    [(); ExtState::VLENB as usize]:,
35    CustomError: fmt::Debug,
36    F: Fn(bool, bool) -> bool,
37{
38    let vl = ext_state.vl();
39    let vstart = ext_state.vstart();
40    // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
41    let vs2_snap = *ext_state.read_vregs().get(vs2);
42    let vs1_snap = *ext_state.read_vregs().get(vs1);
43    // Body elements [vstart, vl): compute the logical operation bit-by-bit. Prestart bits
44    // [0, vstart) and tail bits [vl, VLEN) are left undisturbed.
45    for i in u32::from(vstart)..vl {
46        let a = mask_bit(&vs2_snap, i);
47        let b = mask_bit(&vs1_snap, i);
48        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
49        unsafe {
50            write_mask_bit(ext_state.write_vregs(), vd, i, op(a, b));
51        }
52    }
53    ext_state.mark_vs_dirty();
54    ext_state.reset_vstart();
55}
56
57/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
58///
59/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
60/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
61///
62/// # Safety
63/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
64/// - `vstart <= vl`
65#[inline(always)]
66#[doc(hidden)]
67pub unsafe fn execute_vcpop<Reg, Regs, ExtState, CustomError>(
68    regs: &mut Regs,
69    ext_state: &mut ExtState,
70    rd: Reg,
71    vs2: VReg,
72    vm: bool,
73) where
74    Reg: Register,
75    Regs: RegisterFile<Reg>,
76    ExtState: VectorRegistersExt<Reg, CustomError>,
77    [(); ExtState::ELEN as usize]:,
78    [(); ExtState::VLEN as usize]:,
79    [(); ExtState::VLENB as usize]:,
80    CustomError: fmt::Debug,
81{
82    let vl = ext_state.vl();
83    let vstart = ext_state.vstart();
84    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
85    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
86    let vs2_reg = *ext_state.read_vregs().get(vs2);
87    let mut count = 0u32;
88    for i in u32::from(vstart)..vl {
89        if !mask_bit(&mask_buf, i) {
90            continue;
91        }
92        if mask_bit(&vs2_reg, i) {
93            count += 1;
94        }
95    }
96    regs.write(rd, Reg::Type::from(count));
97    ext_state.mark_vs_dirty();
98    ext_state.reset_vstart();
99}
100
101/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
102/// write result (or -1 if none) to `rd`.
103///
104/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
105/// `-1` (all-ones) if no active element of vs2 is set.
106///
107/// # Safety
108/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
109/// - `vstart <= vl`
110#[inline(always)]
111#[doc(hidden)]
112pub unsafe fn execute_vfirst<Reg, Regs, ExtState, CustomError>(
113    regs: &mut Regs,
114    ext_state: &mut ExtState,
115    rd: Reg,
116    vs2: VReg,
117    vm: bool,
118) where
119    Reg: Register,
120    Regs: RegisterFile<Reg>,
121    ExtState: VectorRegistersExt<Reg, CustomError>,
122    [(); ExtState::ELEN as usize]:,
123    [(); ExtState::VLEN as usize]:,
124    [(); ExtState::VLENB as usize]:,
125    CustomError: fmt::Debug,
126{
127    let vl = ext_state.vl();
128    let vstart = ext_state.vstart();
129    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
130    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
131    let vs2_reg = *ext_state.read_vregs().get(vs2);
132    // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
133    let not_found = u64::MAX;
134    let mut result = not_found;
135    for i in u32::from(vstart)..vl {
136        if !mask_bit(&mask_buf, i) {
137            continue;
138        }
139        if mask_bit(&vs2_reg, i) {
140            result = u64::from(i);
141            break;
142        }
143    }
144    // Write -1 (all-ones for XLEN bits) or the found index.
145    // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
146    // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
147    // without depending on `From<u64>` (which is not in the `Register` trait bounds).
148    // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
149    let reg_value = if result == not_found {
150        !Reg::Type::from(0u8)
151    } else {
152        Reg::Type::from(result as u32)
153    };
154    regs.write(rd, reg_value);
155    ext_state.mark_vs_dirty();
156    ext_state.reset_vstart();
157}
158
159/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
160///
161/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
162/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
163/// destination bits are cleared.
164///
165/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
166///
167/// # Safety
168/// - `vd` does not overlap `vs2` (checked by caller)
169/// - `vm=false` implies `vd != v0` (checked by caller)
170/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
171#[inline(always)]
172#[doc(hidden)]
173pub unsafe fn execute_vmsbf<Reg, ExtState, CustomError>(
174    ext_state: &mut ExtState,
175    vd: VReg,
176    vs2: VReg,
177    vm: bool,
178    vl: u32,
179) where
180    Reg: Register,
181    ExtState: VectorRegistersExt<Reg, CustomError>,
182    [(); ExtState::ELEN as usize]:,
183    [(); ExtState::VLEN as usize]:,
184    [(); ExtState::VLENB as usize]:,
185    CustomError: fmt::Debug,
186{
187    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
188    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
189    let vs2_snap = *ext_state.read_vregs().get(vs2);
190    let mut found_first = false;
191    for i in 0..vl {
192        // Inactive elements: undisturbed
193        if !mask_bit(&mask_buf, i) {
194            continue;
195        }
196        let vs2_bit = mask_bit(&vs2_snap, i);
197        // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
198        let result = !found_first && !vs2_bit;
199        if vs2_bit {
200            found_first = true;
201        }
202        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
203        unsafe {
204            write_mask_bit(ext_state.write_vregs(), vd, i, result);
205        }
206    }
207    ext_state.mark_vs_dirty();
208    // vstart is already zero, doesn't need to be reset
209}
210
211/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
212///
213/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
214/// vs2 has a set bit. All other active destination bits are cleared.
215///
216/// # Safety
217/// Same as [`execute_vmsbf`].
218#[inline(always)]
219#[doc(hidden)]
220pub unsafe fn execute_vmsof<Reg, ExtState, CustomError>(
221    ext_state: &mut ExtState,
222    vd: VReg,
223    vs2: VReg,
224    vm: bool,
225    vl: u32,
226) where
227    Reg: Register,
228    ExtState: VectorRegistersExt<Reg, CustomError>,
229    [(); ExtState::ELEN as usize]:,
230    [(); ExtState::VLEN as usize]:,
231    [(); ExtState::VLENB as usize]:,
232    CustomError: fmt::Debug,
233{
234    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
235    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
236    let vs2_snap = *ext_state.read_vregs().get(vs2);
237    let mut found_first = false;
238    for i in 0..vl {
239        if !mask_bit(&mask_buf, i) {
240            continue;
241        }
242        let vs2_bit = mask_bit(&vs2_snap, i);
243        // vmsof: set only the first set bit position; clear all others (including after first)
244        let result = !found_first && vs2_bit;
245        if vs2_bit && !found_first {
246            found_first = true;
247        }
248        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
249        unsafe {
250            write_mask_bit(ext_state.write_vregs(), vd, i, result);
251        }
252    }
253    ext_state.mark_vs_dirty();
254    // vstart is already zero, doesn't need to be reset
255}
256
257/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
258///
259/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
260/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
261/// has been seen and the current element is past it.
262///
263/// # Safety
264/// Same as [`execute_vmsbf`].
265#[inline(always)]
266#[doc(hidden)]
267pub unsafe fn execute_vmsif<Reg, ExtState, CustomError>(
268    ext_state: &mut ExtState,
269    vd: VReg,
270    vs2: VReg,
271    vm: bool,
272    vl: u32,
273) where
274    Reg: Register,
275    ExtState: VectorRegistersExt<Reg, CustomError>,
276    [(); ExtState::ELEN as usize]:,
277    [(); ExtState::VLEN as usize]:,
278    [(); ExtState::VLENB as usize]:,
279    CustomError: fmt::Debug,
280{
281    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
282    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
283    let vs2_snap = *ext_state.read_vregs().get(vs2);
284    let mut found_first = false;
285    for i in 0..vl {
286        if !mask_bit(&mask_buf, i) {
287            continue;
288        }
289        let vs2_bit = mask_bit(&vs2_snap, i);
290        // vmsif: set bits up to *and including* the first set bit; clear elements past it
291        let result = !found_first;
292        if vs2_bit {
293            found_first = true;
294        }
295        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
296        unsafe {
297            write_mask_bit(ext_state.write_vregs(), vd, i, result);
298        }
299    }
300    ext_state.mark_vs_dirty();
301    // vstart is already zero, doesn't need to be reset
302}
303
304/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2 at
305/// positions `0..i` (strictly before `i`) as a SEW-wide integer into `vd[i]`.
306///
307/// Per spec §16.8: this instruction honors the source mask; inactive mask elements of vs2 are
308/// treated as zero for the prefix sum. Inactive destination elements follow the mask-agnostic
309/// policy (here implemented as undisturbed, which is a permitted realisation).
310///
311/// If SEW is too narrow to hold the prefix count, the value wraps (truncates to SEW) via
312/// [`write_element_u64()`]; the spec does not raise an exception for this case.
313///
314/// The caller must reject `vstart != 0` before invocation (spec §16.8 mandatory trap).
315///
316/// # Safety
317/// - `vd` does not overlap `vs2` (checked by caller)
318/// - `vm=false` implies `vd != v0` (checked by caller)
319/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
320/// - `vl <= VLMAX`; `vl <= VLEN`
321#[inline(always)]
322#[doc(hidden)]
323pub unsafe fn execute_viota<Reg, ExtState, CustomError>(
324    ext_state: &mut ExtState,
325    vd: VReg,
326    vs2: VReg,
327    vm: bool,
328    vl: u32,
329    sew: Vsew,
330) where
331    Reg: Register,
332    ExtState: VectorRegistersExt<Reg, CustomError>,
333    [(); ExtState::ELEN as usize]:,
334    [(); ExtState::VLEN as usize]:,
335    [(); ExtState::VLENB as usize]:,
336    CustomError: fmt::Debug,
337{
338    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
339    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
340    let vs2_snap = *ext_state.read_vregs().get(vs2);
341    // Per spec §16.8: inactive vs2 elements are treated as zero for the prefix sum.
342    // The prefix count advances only when the execution mask is active AND the
343    // corresponding vs2 bit is set.
344    let mut prefix_count = 0u64;
345    for i in 0..vl {
346        if !mask_bit(&mask_buf, i) {
347            continue;
348        }
349        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
350        unsafe {
351            write_element_u64(ext_state.write_vregs(), vd, i, sew, prefix_count);
352        }
353        if mask_bit(&vs2_snap, i) {
354            prefix_count += 1;
355        }
356    }
357    ext_state.mark_vs_dirty();
358    ext_state.reset_vstart();
359}
360
361/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
362/// active element in `vstart..vl`.
363///
364/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
365///
366/// # Safety
367/// - `vm=false` implies `vd != v0` (checked by caller)
368/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (checked by caller)
369/// - `vl <= group_regs * VLENB / sew_bytes`
370/// - `vl <= VLEN`
371#[inline(always)]
372#[doc(hidden)]
373pub unsafe fn execute_vid<Reg, ExtState, CustomError>(
374    ext_state: &mut ExtState,
375    vd: VReg,
376    vm: bool,
377    sew: Vsew,
378) where
379    Reg: Register,
380    ExtState: VectorRegistersExt<Reg, CustomError>,
381    [(); ExtState::ELEN as usize]:,
382    [(); ExtState::VLEN as usize]:,
383    [(); ExtState::VLENB as usize]:,
384    CustomError: fmt::Debug,
385{
386    let vl = ext_state.vl();
387    let vstart = ext_state.vstart();
388    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
389    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
390    for i in u32::from(vstart)..vl {
391        if !mask_bit(&mask_buf, i) {
392            continue;
393        }
394        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
395        unsafe {
396            write_element_u64(ext_state.write_vregs(), vd, i, sew, u64::from(i));
397        }
398    }
399    ext_state.mark_vs_dirty();
400    ext_state.reset_vstart();
401}