Skip to main content

ab_riscv_interpreter/v/zve64x/mask/
zve64x_mask_helpers.rs

1//! Opaque helpers for Zve64x extension
2
3use crate::RegisterFile;
4use crate::v::vector_registers::VectorRegistersExt;
5use crate::v::zve64x::arith::zve64x_arith_helpers::{write_element_u64, write_mask_bit};
6use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Execute a mask-register logical operation on the full `VLENB`-byte mask registers.
11///
12/// Operates on the entire register width independent of `vl` or `vtype`, per spec §16.1.
13/// `op` receives `(vs2_byte: u8, vs1_byte: u8) -> u8`.
14///
15/// # Safety
16/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg` type).
17/// The operation snaps both sources before writing, so `vd` may safely overlap either source.
18#[inline(always)]
19#[doc(hidden)]
20pub unsafe fn execute_mask_logical_op<Reg, ExtState, CustomError, F>(
21    ext_state: &mut ExtState,
22    vd: VReg,
23    vs2: VReg,
24    vs1: VReg,
25    op: F,
26) where
27    Reg: Register,
28    ExtState: VectorRegistersExt<Reg, CustomError>,
29    [(); ExtState::ELEN as usize]:,
30    [(); ExtState::VLEN as usize]:,
31    [(); ExtState::VLENB as usize]:,
32    CustomError: fmt::Debug,
33    F: Fn(u8, u8) -> u8,
34{
35    // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
36    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
37    let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
38    // SAFETY: `vs1.bits() < 32`; `VReg` values are always in `0..32`
39    let vs1_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs1.bits())) };
40    // SAFETY: `vd.bits() < 32`; `VReg` values are always in `0..32`
41    let vd_reg = unsafe {
42        ext_state
43            .write_vreg()
44            .get_unchecked_mut(usize::from(vd.bits()))
45    };
46    for byte_i in 0..ExtState::VLENB as usize {
47        // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
48        // `vs2_snap` / `vs1_snap` are `[u8; VLENB]` arrays
49        let a = unsafe { *vs2_snap.get_unchecked(byte_i) };
50        // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
51        // `vs2_snap` / `vs1_snap` are `[u8; VLENB]` arrays
52        let b = unsafe { *vs1_snap.get_unchecked(byte_i) };
53        // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
54        // `vd_reg` points to a `[u8; VLENB]` register row
55        *unsafe { vd_reg.get_unchecked_mut(byte_i) } = op(a, b);
56    }
57    ext_state.mark_vs_dirty();
58    ext_state.reset_vstart();
59}
60
61/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
62///
63/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
64/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
65///
66/// # Safety
67/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
68/// - `vstart <= vl`
69#[inline(always)]
70#[doc(hidden)]
71pub unsafe fn execute_vcpop<Reg, Regs, ExtState, CustomError>(
72    regs: &mut Regs,
73    ext_state: &mut ExtState,
74    rd: Reg,
75    vs2: VReg,
76    vm: bool,
77    vl: u32,
78    vstart: u32,
79) where
80    Reg: Register,
81    Regs: RegisterFile<Reg>,
82    ExtState: VectorRegistersExt<Reg, CustomError>,
83    [(); ExtState::ELEN as usize]:,
84    [(); ExtState::VLEN as usize]:,
85    [(); ExtState::VLENB as usize]:,
86    CustomError: fmt::Debug,
87{
88    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
89    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
90    // SAFETY: `vs2` is a valid VReg index (< 32)
91    let vs2_reg = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
92    let mut count = 0u32;
93    for i in vstart..vl {
94        if !mask_bit(&mask_buf, i) {
95            continue;
96        }
97        if mask_bit(&vs2_reg, i) {
98            count += 1;
99        }
100    }
101    regs.write(rd, Reg::Type::from(count));
102    ext_state.mark_vs_dirty();
103    ext_state.reset_vstart();
104}
105
106/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
107/// write result (or -1 if none) to `rd`.
108///
109/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
110/// `-1` (all-ones) if no active element of vs2 is set.
111///
112/// # Safety
113/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
114/// - `vstart <= vl`
115#[inline(always)]
116#[doc(hidden)]
117pub unsafe fn execute_vfirst<Reg, Regs, ExtState, CustomError>(
118    regs: &mut Regs,
119    ext_state: &mut ExtState,
120    rd: Reg,
121    vs2: VReg,
122    vm: bool,
123    vl: u32,
124    vstart: u32,
125) where
126    Reg: Register,
127    Regs: RegisterFile<Reg>,
128    ExtState: VectorRegistersExt<Reg, CustomError>,
129    [(); ExtState::ELEN as usize]:,
130    [(); ExtState::VLEN as usize]:,
131    [(); ExtState::VLENB as usize]:,
132    CustomError: fmt::Debug,
133{
134    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
135    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
136    // SAFETY: `vs2` is a valid VReg index (< 32)
137    let vs2_reg = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
138    // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
139    let not_found = u64::MAX;
140    let mut result = not_found;
141    for i in vstart..vl {
142        if !mask_bit(&mask_buf, i) {
143            continue;
144        }
145        if mask_bit(&vs2_reg, i) {
146            result = u64::from(i);
147            break;
148        }
149    }
150    // Write -1 (all-ones for XLEN bits) or the found index.
151    // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
152    // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
153    // without depending on `From<u64>` (which is not in the `Register` trait bounds).
154    // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
155    let reg_value = if result == not_found {
156        !Reg::Type::from(0u8)
157    } else {
158        Reg::Type::from(result as u32)
159    };
160    regs.write(rd, reg_value);
161    ext_state.mark_vs_dirty();
162    ext_state.reset_vstart();
163}
164
165/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
166///
167/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
168/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
169/// destination bits are cleared.
170///
171/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
172///
173/// # Safety
174/// - `vd` does not overlap `vs2` (checked by caller)
175/// - `vm=false` implies `vd != v0` (checked by caller)
176/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
177#[inline(always)]
178#[doc(hidden)]
179pub unsafe fn execute_vmsbf<Reg, ExtState, CustomError>(
180    ext_state: &mut ExtState,
181    vd: VReg,
182    vs2: VReg,
183    vm: bool,
184    vl: u32,
185) where
186    Reg: Register,
187    ExtState: VectorRegistersExt<Reg, CustomError>,
188    [(); ExtState::ELEN as usize]:,
189    [(); ExtState::VLEN as usize]:,
190    [(); ExtState::VLENB as usize]:,
191    CustomError: fmt::Debug,
192{
193    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
194    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
195    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
196    let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
197    let mut found_first = false;
198    for i in 0..vl {
199        // Inactive elements: undisturbed
200        if !mask_bit(&mask_buf, i) {
201            continue;
202        }
203        let vs2_bit = mask_bit(&vs2_snap, i);
204        // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
205        let result = !found_first && !vs2_bit;
206        if vs2_bit {
207            found_first = true;
208        }
209        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
210        unsafe { write_mask_bit(ext_state.write_vreg(), vd, i, result) };
211    }
212    ext_state.mark_vs_dirty();
213    // vstart is already zero, doesn't need to be reset
214}
215
216/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
217///
218/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
219/// vs2 has a set bit. All other active destination bits are cleared.
220///
221/// # Safety
222/// Same as [`execute_vmsbf`].
223#[inline(always)]
224#[doc(hidden)]
225pub unsafe fn execute_vmsof<Reg, ExtState, CustomError>(
226    ext_state: &mut ExtState,
227    vd: VReg,
228    vs2: VReg,
229    vm: bool,
230    vl: u32,
231) where
232    Reg: Register,
233    ExtState: VectorRegistersExt<Reg, CustomError>,
234    [(); ExtState::ELEN as usize]:,
235    [(); ExtState::VLEN as usize]:,
236    [(); ExtState::VLENB as usize]:,
237    CustomError: fmt::Debug,
238{
239    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
240    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
241    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
242    let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
243    let mut found_first = false;
244    for i in 0..vl {
245        if !mask_bit(&mask_buf, i) {
246            continue;
247        }
248        let vs2_bit = mask_bit(&vs2_snap, i);
249        // vmsof: set only the first set bit position; clear all others (including after first)
250        let result = !found_first && vs2_bit;
251        if vs2_bit && !found_first {
252            found_first = true;
253        }
254        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
255        unsafe { write_mask_bit(ext_state.write_vreg(), vd, i, result) };
256    }
257    ext_state.mark_vs_dirty();
258    // vstart is already zero, doesn't need to be reset
259}
260
261/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
262///
263/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
264/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
265/// has been seen and the current element is past it.
266///
267/// # Safety
268/// Same as [`execute_vmsbf`].
269#[inline(always)]
270#[doc(hidden)]
271pub unsafe fn execute_vmsif<Reg, ExtState, CustomError>(
272    ext_state: &mut ExtState,
273    vd: VReg,
274    vs2: VReg,
275    vm: bool,
276    vl: u32,
277) where
278    Reg: Register,
279    ExtState: VectorRegistersExt<Reg, CustomError>,
280    [(); ExtState::ELEN as usize]:,
281    [(); ExtState::VLEN as usize]:,
282    [(); ExtState::VLENB as usize]:,
283    CustomError: fmt::Debug,
284{
285    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
286    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
287    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
288    let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
289    let mut found_first = false;
290    for i in 0..vl {
291        if !mask_bit(&mask_buf, i) {
292            continue;
293        }
294        let vs2_bit = mask_bit(&vs2_snap, i);
295        // vmsif: set bits up to *and including* the first set bit; clear elements past it
296        let result = !found_first;
297        if vs2_bit {
298            found_first = true;
299        }
300        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
301        unsafe { write_mask_bit(ext_state.write_vreg(), vd, i, result) };
302    }
303    ext_state.mark_vs_dirty();
304    // vstart is already zero, doesn't need to be reset
305}
306
307/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2 at
308/// positions `0..i` (strictly before `i`) as a SEW-wide integer into `vd[i]`.
309///
310/// Per spec §16.8: this instruction honors the source mask; inactive mask elements of vs2 are
311/// treated as zero for the prefix sum. Inactive destination elements follow the mask-agnostic
312/// policy (here implemented as undisturbed, which is a permitted realisation).
313///
314/// The caller must reject `vstart != 0` before invocation (spec §16.8 mandatory trap).
315///
316/// # Safety
317/// - `vd` does not overlap `vs2` (checked by caller)
318/// - `vm=false` implies `vd != v0` (checked by caller)
319/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (checked by caller)
320/// - SEW wide enough to hold VLMAX-1 (checked by caller)
321/// - `vl <= VLMAX`; `vl <= VLEN`
322#[inline(always)]
323#[doc(hidden)]
324pub unsafe fn execute_viota<Reg, ExtState, CustomError>(
325    ext_state: &mut ExtState,
326    vd: VReg,
327    vs2: VReg,
328    vm: bool,
329    vl: u32,
330    sew: Vsew,
331) where
332    Reg: Register,
333    ExtState: VectorRegistersExt<Reg, CustomError>,
334    [(); ExtState::ELEN as usize]:,
335    [(); ExtState::VLEN as usize]:,
336    [(); ExtState::VLENB as usize]:,
337    CustomError: fmt::Debug,
338{
339    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
340    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
341    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
342    let vs2_snap = *unsafe { ext_state.read_vreg().get_unchecked(usize::from(vs2.bits())) };
343    // Per spec §16.8: inactive vs2 elements are treated as zero for the prefix sum.
344    // The prefix count advances only when the execution mask is active AND the
345    // corresponding vs2 bit is set.
346    let mut prefix_count = 0u64;
347    for i in 0..vl {
348        if !mask_bit(&mask_buf, i) {
349            continue;
350        }
351        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
352        unsafe {
353            write_element_u64(ext_state.write_vreg(), vd.bits(), i, sew, prefix_count);
354        }
355        if mask_bit(&vs2_snap, i) {
356            prefix_count += 1;
357        }
358    }
359    ext_state.mark_vs_dirty();
360    ext_state.reset_vstart();
361}
362
363/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
364/// active element in `vstart..vl`.
365///
366/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
367///
368/// # Safety
369/// - `vm=false` implies `vd != v0` (checked by caller)
370/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (checked by caller)
371/// - `vl <= group_regs * VLENB / sew_bytes`
372/// - `vl <= VLEN`
373#[inline(always)]
374#[doc(hidden)]
375pub unsafe fn execute_vid<Reg, ExtState, CustomError>(
376    ext_state: &mut ExtState,
377    vd: VReg,
378    vm: bool,
379    vl: u32,
380    vstart: u32,
381    sew: Vsew,
382) where
383    Reg: Register,
384    ExtState: VectorRegistersExt<Reg, CustomError>,
385    [(); ExtState::ELEN as usize]:,
386    [(); ExtState::VLEN as usize]:,
387    [(); ExtState::VLENB as usize]:,
388    CustomError: fmt::Debug,
389{
390    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
391    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
392    for i in vstart..vl {
393        if !mask_bit(&mask_buf, i) {
394            continue;
395        }
396        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
397        unsafe {
398            write_element_u64(ext_state.write_vreg(), vd.bits(), i, sew, u64::from(i));
399        }
400    }
401    ext_state.mark_vs_dirty();
402    ext_state.reset_vstart();
403}