Skip to main content

ab_riscv_interpreter/v/zve64x/mask/
zve64x_mask_helpers.rs

//! Mask-register instruction helpers for the Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::arith::zve64x_arith_helpers::{write_element_u64, write_mask_bit};
5use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
6use crate::{InterpreterState, ProgramCounter, VirtualMemory};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::registers::general_purpose::Register;
9use ab_riscv_primitives::registers::vector::VReg;
10use core::fmt;
11
12/// Execute a mask-register logical operation on the full `VLENB`-byte mask registers.
13///
14/// Operates on the entire register width independent of `vl` or `vtype`, per spec §16.1.
15/// `op` receives `(vs2_byte: u8, vs1_byte: u8) -> u8`.
16///
17/// # Safety
18/// `vd`, `vs2`, and `vs1` are valid register indices (guaranteed by `VReg` type).
19/// The operation snaps both sources before writing, so `vd` may safely overlap either source.
20#[inline(always)]
21#[doc(hidden)]
22pub unsafe fn execute_mask_logical_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
23    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
24    vd: VReg,
25    vs2: VReg,
26    vs1: VReg,
27    op: F,
28) where
29    Reg: Register,
30    [(); Reg::N]:,
31    ExtState: VectorRegistersExt<Reg, CustomError>,
32    [(); ExtState::ELEN as usize]:,
33    [(); ExtState::VLEN as usize]:,
34    [(); ExtState::VLENB as usize]:,
35    Memory: VirtualMemory,
36    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
37    CustomError: fmt::Debug,
38    F: Fn(u8, u8) -> u8,
39{
40    // Snapshot both sources before writing to handle vd overlapping vs2 or vs1
41    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
42    let vs2_snap = *unsafe {
43        state
44            .ext_state
45            .read_vreg()
46            .get_unchecked(usize::from(vs2.bits()))
47    };
48    // SAFETY: `vs1.bits() < 32`; `VReg` values are always in `0..32`
49    let vs1_snap = *unsafe {
50        state
51            .ext_state
52            .read_vreg()
53            .get_unchecked(usize::from(vs1.bits()))
54    };
55    // SAFETY: `vd.bits() < 32`; `VReg` values are always in `0..32`
56    let vd_reg = unsafe {
57        state
58            .ext_state
59            .write_vreg()
60            .get_unchecked_mut(usize::from(vd.bits()))
61    };
62    for byte_i in 0..ExtState::VLENB as usize {
63        // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
64        // `vs2_snap` / `vs1_snap` are `[u8; VLENB]` arrays
65        let a = unsafe { *vs2_snap.get_unchecked(byte_i) };
66        // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
67        // `vs2_snap` / `vs1_snap` are `[u8; VLENB]` arrays
68        let b = unsafe { *vs1_snap.get_unchecked(byte_i) };
69        // SAFETY: `byte_i < VLENB` because the loop bound is `ExtState::VLENB`, and
70        // `vd_reg` points to a `[u8; VLENB]` register row
71        *unsafe { vd_reg.get_unchecked_mut(byte_i) } = op(a, b);
72    }
73    state.ext_state.mark_vs_dirty();
74    state.ext_state.reset_vstart();
75}
76
77/// Execute `vcpop.m`: count set bits in vs2 for active elements `0..vl`, write result to `rd`.
78///
79/// Per spec §16.2: `rd` receives the number of mask bits set in `vs2`, considering only elements
80/// `vstart..vl` that are active under the mask. For elements `< vstart`, they are not counted.
81///
82/// # Safety
83/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
84/// - `vstart <= vl`
85#[inline(always)]
86#[doc(hidden)]
87pub unsafe fn execute_vcpop<Reg, ExtState, Memory, PC, IH, CustomError>(
88    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
89    rd: Reg,
90    vs2: VReg,
91    vm: bool,
92    vl: u32,
93    vstart: u32,
94) where
95    Reg: Register,
96    [(); Reg::N]:,
97    ExtState: VectorRegistersExt<Reg, CustomError>,
98    [(); ExtState::ELEN as usize]:,
99    [(); ExtState::VLEN as usize]:,
100    [(); ExtState::VLENB as usize]:,
101    Memory: VirtualMemory,
102    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
103    CustomError: fmt::Debug,
104{
105    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
106    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
107    // SAFETY: `vs2` is a valid VReg index (< 32)
108    let vs2_reg = *unsafe {
109        state
110            .ext_state
111            .read_vreg()
112            .get_unchecked(usize::from(vs2.bits()))
113    };
114    let mut count = 0u32;
115    for i in vstart..vl {
116        if !mask_bit(&mask_buf, i) {
117            continue;
118        }
119        if mask_bit(&vs2_reg, i) {
120            count += 1;
121        }
122    }
123    state.regs.write(rd, Reg::Type::from(count));
124    state.ext_state.mark_vs_dirty();
125    state.ext_state.reset_vstart();
126}
127
128/// Execute `vfirst.m`: find the index of the first set bit in vs2 for active elements `0..vl`,
129/// write result (or -1 if none) to `rd`.
130///
131/// Per spec §16.3: `rd` receives the element index of the lowest-numbered active set bit, or
132/// `-1` (all-ones) if no active element of vs2 is set.
133///
134/// # Safety
135/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
136/// - `vstart <= vl`
137#[inline(always)]
138#[doc(hidden)]
139pub unsafe fn execute_vfirst<Reg, ExtState, Memory, PC, IH, CustomError>(
140    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
141    rd: Reg,
142    vs2: VReg,
143    vm: bool,
144    vl: u32,
145    vstart: u32,
146) where
147    Reg: Register,
148    [(); Reg::N]:,
149    ExtState: VectorRegistersExt<Reg, CustomError>,
150    [(); ExtState::ELEN as usize]:,
151    [(); ExtState::VLEN as usize]:,
152    [(); ExtState::VLENB as usize]:,
153    Memory: VirtualMemory,
154    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
155    CustomError: fmt::Debug,
156{
157    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
158    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
159    // SAFETY: `vs2` is a valid VReg index (< 32)
160    let vs2_reg = *unsafe {
161        state
162            .ext_state
163            .read_vreg()
164            .get_unchecked(usize::from(vs2.bits()))
165    };
166    // -1 encoded as all-ones for the register width; `Into<u64>` on XLEN-wide type then back
167    let not_found = u64::MAX;
168    let mut result = not_found;
169    for i in vstart..vl {
170        if !mask_bit(&mask_buf, i) {
171            continue;
172        }
173        if mask_bit(&vs2_reg, i) {
174            result = u64::from(i);
175            break;
176        }
177    }
178    // Write -1 (all-ones for XLEN bits) or the found index.
179    // The spec requires -1 as a signed XLEN-wide value, meaning all bits set.
180    // `!Reg::Type::from(0)` produces all-ones for both u32 (RV32) and u64 (RV64)
181    // without depending on `From<u64>` (which is not in the `Register` trait bounds).
182    // For the found index, element indices fit in u32 since vl <= VLEN <= 2^32.
183    let reg_value = if result == not_found {
184        !Reg::Type::from(0u8)
185    } else {
186        Reg::Type::from(result as u32)
187    };
188    state.regs.write(rd, reg_value);
189    state.ext_state.mark_vs_dirty();
190    state.ext_state.reset_vstart();
191}
192
193/// Execute `vmsbf.m`: set all mask bits before (not including) the first set bit of vs2.
194///
195/// Per spec §16.4: for each element `i` in `vstart..vl`, if no prior active set bit exists in
196/// vs2, the destination bit is set; once the first set bit in vs2 is encountered, all subsequent
197/// destination bits are cleared.
198///
199/// Inactive elements (masked off) are left undisturbed. Tail elements are undisturbed.
200///
201/// # Safety
202/// - `vd` does not overlap `vs2` (checked by caller)
203/// - `vm=false` implies `vd != v0` (checked by caller)
204/// - `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
205#[inline(always)]
206#[doc(hidden)]
207pub unsafe fn execute_vmsbf<Reg, ExtState, Memory, PC, IH, CustomError>(
208    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
209    vd: VReg,
210    vs2: VReg,
211    vm: bool,
212    vl: u32,
213    vstart: u32,
214) where
215    Reg: Register,
216    [(); Reg::N]:,
217    ExtState: VectorRegistersExt<Reg, CustomError>,
218    [(); ExtState::ELEN as usize]:,
219    [(); ExtState::VLEN as usize]:,
220    [(); ExtState::VLENB as usize]:,
221    Memory: VirtualMemory,
222    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
223    CustomError: fmt::Debug,
224{
225    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
226    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
227    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
228    let vs2_snap = *unsafe {
229        state
230            .ext_state
231            .read_vreg()
232            .get_unchecked(usize::from(vs2.bits()))
233    };
234    let mut found_first = false;
235    for i in vstart..vl {
236        // Inactive elements: undisturbed
237        if !mask_bit(&mask_buf, i) {
238            continue;
239        }
240        let vs2_bit = mask_bit(&vs2_snap, i);
241        // vmsbf: set bits strictly *before* the first set bit; clear from first set bit onward
242        let result = !found_first && !vs2_bit;
243        if vs2_bit {
244            found_first = true;
245        }
246        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
247        unsafe { write_mask_bit(state.ext_state.write_vreg(), vd, i, result) };
248    }
249    state.ext_state.mark_vs_dirty();
250    state.ext_state.reset_vstart();
251}
252
253/// Execute `vmsof.m`: set only the first set bit position of vs2, clear all others.
254///
255/// Per spec §16.5: the destination bit is set only at the lowest-numbered active element where
256/// vs2 has a set bit. All other active destination bits are cleared.
257///
258/// # Safety
259/// Same as [`execute_vmsbf`].
260#[inline(always)]
261#[doc(hidden)]
262pub unsafe fn execute_vmsof<Reg, ExtState, Memory, PC, IH, CustomError>(
263    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
264    vd: VReg,
265    vs2: VReg,
266    vm: bool,
267    vl: u32,
268    vstart: u32,
269) where
270    Reg: Register,
271    [(); Reg::N]:,
272    ExtState: VectorRegistersExt<Reg, CustomError>,
273    [(); ExtState::ELEN as usize]:,
274    [(); ExtState::VLEN as usize]:,
275    [(); ExtState::VLENB as usize]:,
276    Memory: VirtualMemory,
277    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
278    CustomError: fmt::Debug,
279{
280    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
281    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
282    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
283    let vs2_snap = *unsafe {
284        state
285            .ext_state
286            .read_vreg()
287            .get_unchecked(usize::from(vs2.bits()))
288    };
289    let mut found_first = false;
290    for i in vstart..vl {
291        if !mask_bit(&mask_buf, i) {
292            continue;
293        }
294        let vs2_bit = mask_bit(&vs2_snap, i);
295        // vmsof: set only the first set bit position; clear all others (including after first)
296        let result = !found_first && vs2_bit;
297        if vs2_bit && !found_first {
298            found_first = true;
299        }
300        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
301        unsafe { write_mask_bit(state.ext_state.write_vreg(), vd, i, result) };
302    }
303    state.ext_state.mark_vs_dirty();
304    state.ext_state.reset_vstart();
305}
306
307/// Execute `vmsif.m`: set all mask bits up to and including the first set bit of vs2.
308///
309/// Per spec §16.6: for each active element, the destination bit is set if no prior active set bit
310/// in vs2 has been seen yet *or* the current element itself is set; it is cleared once a set bit
311/// has been seen and the current element is past it.
312///
313/// # Safety
314/// Same as [`execute_vmsbf`].
315#[inline(always)]
316#[doc(hidden)]
317pub unsafe fn execute_vmsif<Reg, ExtState, Memory, PC, IH, CustomError>(
318    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
319    vd: VReg,
320    vs2: VReg,
321    vm: bool,
322    vl: u32,
323    vstart: u32,
324) where
325    Reg: Register,
326    [(); Reg::N]:,
327    ExtState: VectorRegistersExt<Reg, CustomError>,
328    [(); ExtState::ELEN as usize]:,
329    [(); ExtState::VLEN as usize]:,
330    [(); ExtState::VLENB as usize]:,
331    Memory: VirtualMemory,
332    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
333    CustomError: fmt::Debug,
334{
335    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
336    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
337    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
338    let vs2_snap = *unsafe {
339        state
340            .ext_state
341            .read_vreg()
342            .get_unchecked(usize::from(vs2.bits()))
343    };
344    let mut found_first = false;
345    for i in vstart..vl {
346        if !mask_bit(&mask_buf, i) {
347            continue;
348        }
349        let vs2_bit = mask_bit(&vs2_snap, i);
350        // vmsif: set bits up to *and including* the first set bit; clear elements past it
351        let result = !found_first;
352        if vs2_bit {
353            found_first = true;
354        }
355        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
356        unsafe { write_mask_bit(state.ext_state.write_vreg(), vd, i, result) };
357    }
358    state.ext_state.mark_vs_dirty();
359    state.ext_state.reset_vstart();
360}
361
362/// Execute `viota.m`: for each active element `i`, write the popcount of set bits in vs2
363/// at positions `0..i` (i.e. strictly before `i`) as a SEW-wide integer into `vd[i]`.
364///
365/// Per spec §16.8: inactive elements are handled according to the mask/tail agnostic policy.
366/// Here we use mask-undisturbed (inactive elements left unchanged).
367///
368/// # Safety
369/// - `vd` does not overlap `vs2` (checked by caller)
370/// - `vm=false` implies `vd != v0` (checked by caller)
371/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (checked by caller)
372/// - `vl <= group_regs * VLENB / sew_bytes`
373/// - `vl <= VLEN`
374#[inline(always)]
375#[doc(hidden)]
376pub unsafe fn execute_viota<Reg, ExtState, Memory, PC, IH, CustomError>(
377    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
378    vd: VReg,
379    vs2: VReg,
380    vm: bool,
381    vl: u32,
382    vstart: u32,
383    sew: Vsew,
384) where
385    Reg: Register,
386    [(); Reg::N]:,
387    ExtState: VectorRegistersExt<Reg, CustomError>,
388    [(); ExtState::ELEN as usize]:,
389    [(); ExtState::VLEN as usize]:,
390    [(); ExtState::VLENB as usize]:,
391    Memory: VirtualMemory,
392    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
393    CustomError: fmt::Debug,
394{
395    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
396    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
397    // SAFETY: `vs2.bits() < 32`; `VReg` values are always in `0..32`
398    let vs2_snap = *unsafe {
399        state
400            .ext_state
401            .read_vreg()
402            .get_unchecked(usize::from(vs2.bits()))
403    };
404    // Prefix popcount over the full vs2 mask, regardless of the execution mask (per spec §16.8:
405    // the prefix sum counts *all* preceding vs2 bits, not just active ones).
406    let mut prefix_count = 0;
407    // We need to compute prefix counts for *all* positions up to vl, updating as we go.
408    // For elements before vstart, we still need to advance prefix_count.
409    for i in 0..vl {
410        let is_active = mask_bit(&mask_buf, i);
411        if i >= vstart && is_active {
412            // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
413            unsafe {
414                write_element_u64(
415                    state.ext_state.write_vreg(),
416                    vd.bits(),
417                    i,
418                    sew,
419                    prefix_count,
420                );
421            }
422        }
423        // Advance prefix count unconditionally for all elements (including inactive ones):
424        // the prefix sum counts set bits in vs2 regardless of masking, per spec.
425        if mask_bit(&vs2_snap, i) {
426            prefix_count += 1;
427        }
428    }
429    state.ext_state.mark_vs_dirty();
430    state.ext_state.reset_vstart();
431}
432
433/// Execute `vid.v`: write the element index `i` as a SEW-wide integer into `vd[i]` for each
434/// active element in `vstart..vl`.
435///
436/// Per spec §16.9: inactive elements are left undisturbed (mask-undisturbed policy).
437///
438/// # Safety
439/// - `vm=false` implies `vd != v0` (checked by caller)
440/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (checked by caller)
441/// - `vl <= group_regs * VLENB / sew_bytes`
442/// - `vl <= VLEN`
443#[inline(always)]
444#[doc(hidden)]
445pub unsafe fn execute_vid<Reg, ExtState, Memory, PC, IH, CustomError>(
446    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
447    vd: VReg,
448    vm: bool,
449    vl: u32,
450    vstart: u32,
451    sew: Vsew,
452) where
453    Reg: Register,
454    [(); Reg::N]:,
455    ExtState: VectorRegistersExt<Reg, CustomError>,
456    [(); ExtState::ELEN as usize]:,
457    [(); ExtState::VLEN as usize]:,
458    [(); ExtState::VLENB as usize]:,
459    Memory: VirtualMemory,
460    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
461    CustomError: fmt::Debug,
462{
463    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
464    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
465    for i in vstart..vl {
466        if !mask_bit(&mask_buf, i) {
467            continue;
468        }
469        // SAFETY: `vd + i / elems_per_reg < 32` by caller's alignment + vl preconditions
470        unsafe {
471            write_element_u64(
472                state.ext_state.write_vreg(),
473                vd.bits(),
474                i,
475                sew,
476                u64::from(i),
477            );
478        }
479    }
480    state.ext_state.mark_vs_dirty();
481    state.ext_state.reset_vstart();
482}