Skip to main content

ab_riscv_interpreter/v/zvexx/muldiv/
zvexx_muldiv_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5    OpSrc, check_vreg_group_alignment, sew_mask, sign_extend,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{read_element_u64, write_element_u64};
8use crate::v::zvexx::fixed_point::zvexx_fixed_point_helpers::read_wide_element_u64;
9use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
10use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
11use crate::{ExecutionError, ProgramCounter};
12use ab_riscv_primitives::prelude::*;
13use core::fmt;
14use core::hint::cold_path;
15
16/// Compute the destination register count for a widening operation (`EMUL = 2 × LMUL`).
17///
18/// Returns `None` when the resulting EMUL falls outside the legal range `[1/8, 8]`, i.e. when
19/// `LMUL` is already `M8` (EMUL would be 16) or the caller asks for a multiplication factor that
20/// pushes the fraction past the legal lower bound.
21///
22/// The register count returned is `max(1, EMUL)`: fractional EMUL values (1/2, 1/4) still occupy
23/// exactly one physical register.
24#[inline(always)]
25#[doc(hidden)]
26pub fn widening_dest_register_count(vlmul: Vlmul) -> Option<u8> {
27    let (lmul_num, lmul_den) = vlmul.as_fraction();
28    // EMUL = 2 × LMUL = (2 * lmul_num) / lmul_den
29    let Some(emul_num) = 2u8.checked_mul(lmul_num) else {
30        cold_path();
31        return None;
32    };
33    let emul_den = lmul_den;
34    // Reduce the fraction by GCD (both are powers of two so min works as GCD)
35    let g = emul_num.min(emul_den);
36    let (n, d) = (emul_num / g, emul_den / g);
37    // Legal EMUL fractions: 1/8, 1/4, 1/2, 1, 2, 4, 8
38    let legal = matches!(
39        (n, d),
40        (1, 8) | (1, 4) | (1, 2) | (1, 1) | (2, 1) | (4, 1) | (8, 1)
41    );
42    if !legal {
43        cold_path();
44        return None;
45    }
46    // Register count: max(1, n/d) = n when d==1, else 1
47    Some(if d > 1 { 1 } else { n })
48}
49
50/// Check that a narrower source register group does not *illegally* overlap the wider destination
51/// group of a widening instruction.
52///
53/// For widening instructions `vd` occupies `dest_group_regs` registers (which is
54/// [`widening_dest_register_count()`] of the source LMUL); `vs` occupies `src_group_regs`.
55///
56/// Per the RISC-V "V" spec §5.2, because the destination EEW (`2*SEW`) is greater than the source
57/// EEW (`SEW`), the source group *may* overlap the destination group, but only when both of the
58/// following hold:
59/// - the source EMUL is at least 1, and
60/// - the overlap is in the highest-numbered part of the destination register group, i.e. the source
61///   occupies exactly the top `src_group_regs` registers of the destination group.
62///
63/// When the source EMUL is at least 1, `dest_group_regs == 2 * src_group_regs`, so
64/// `dest_group_regs > src_group_regs` is an equivalent test for "source EMUL >= 1": a fractional
65/// source EMUL (`< 1`) yields `dest_group_regs == src_group_regs == 1`, in which case no overlap is
66/// ever legal. Any overlap that is not the legal "source in the highest-numbered part" form is
67/// rejected as an illegal instruction.
68#[inline(always)]
69#[doc(hidden)]
70pub fn check_no_widening_overlap<Reg, Memory, PC, CustomError>(
71    program_counter: &PC,
72    vd: VReg,
73    vs: VReg,
74    dest_group_regs: u8,
75    src_group_regs: u8,
76) -> Result<(), ExecutionError<Reg::Type, CustomError>>
77where
78    Reg: Register,
79    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
80{
81    let vd_start = vd.to_bits();
82    let vd_end = vd_start + dest_group_regs;
83    let vs_start = vs.to_bits();
84    let vs_end = vs_start + src_group_regs;
85    // Disjoint register groups are always fine
86    if vs_start >= vd_end || vd_start >= vs_end {
87        return Ok(());
88    }
89    // The groups overlap. This is legal only when the source EMUL is at least 1
90    // (`dest_group_regs > src_group_regs`) and the source occupies exactly the highest-numbered
91    // part of the destination group (`vs_start == vd_end - src_group_regs`).
92    if dest_group_regs > src_group_regs && vs_start == vd_end - src_group_regs {
93        return Ok(());
94    }
95
96    cold_path();
97    Err(ExecutionError::IllegalInstruction {
98        address: program_counter.old_pc(INSTRUCTION_SIZE),
99    })
100}
101
102/// Write a 2*SEW-wide element into the widened destination register group at element index
103/// `elem_i`.
104///
105/// # Safety
106/// `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32` must hold.
107#[inline(always)]
108unsafe fn write_wide_element_u64<const VLENB: usize>(
109    vregs: &mut VectorRegisterFile<VLENB>,
110    base_reg: VReg,
111    elem_i: u32,
112    sew: Vsew,
113    value: u64,
114) {
115    let wide_bytes = usize::from(sew.bytes_width()) * 2;
116    let elems_per_reg = VLENB / wide_bytes;
117    let reg_off = elem_i as usize / elems_per_reg;
118    let byte_off = (elem_i as usize % elems_per_reg) * wide_bytes;
119    let buf = value.to_le_bytes();
120    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
121    let reg = unsafe {
122        vregs.get_mut(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
123    };
124    // SAFETY: `byte_off + wide_bytes <= VLENB`; `wide_bytes <= 8` for SEW < 64
125    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + wide_bytes) };
126    // SAFETY: `wide_bytes <= 8` because SEW < 64 is enforced before widening ops are called
127    dst.copy_from_slice(unsafe { buf.get_unchecked(..wide_bytes) });
128}
129
130/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
131///
132/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result.
133/// Only the low `sew.bytes()` of the result are written back.
134///
135/// # Safety
136/// - `vd` and source register alignment verified by caller
137/// - `vl <= group_regs * VLENB / sew_bytes`
138/// - When `vm=false`: `vd.to_bits() != 0`
139#[inline(always)]
140#[doc(hidden)]
141pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
142    ext_state: &mut ExtState,
143    vd: VReg,
144    vs2: VReg,
145    src: OpSrc,
146    vm: bool,
147    sew: Vsew,
148    op: F,
149) where
150    Reg: Register,
151    ExtState: VectorRegistersExt<Reg, CustomError>,
152    [(); ExtState::ELEN as usize]:,
153    [(); ExtState::VLEN as usize]:,
154    [(); ExtState::VLENB as usize]:,
155    CustomError: fmt::Debug,
156    F: Fn(u64, u64, Vsew) -> u64,
157{
158    let vl = ext_state.vl();
159    let vstart = ext_state.vstart();
160    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
161    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
162    for i in u32::from(vstart)..vl {
163        if !mask_bit(&mask_buf, i) {
164            continue;
165        }
166        // SAFETY: register bounds verified by caller
167        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
168        let b = match src {
169            // SAFETY: register bounds verified by caller
170            OpSrc::Vreg(vs1_base) => unsafe {
171                read_element_u64(ext_state.read_vregs(), vs1_base, i, sew)
172            },
173            OpSrc::Scalar(val) => val,
174        };
175        let result = op(a, b, sew);
176        // SAFETY: register bounds verified by caller
177        unsafe {
178            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
179        }
180    }
181    ext_state.mark_vs_dirty();
182    ext_state.reset_vstart();
183}
184
185/// Execute a single-width widening operation over `vstart..vl`.
186///
187/// Reads SEW-wide elements from `vs2` and `src`, computes `op`, and writes a 2*SEW-wide result
188/// into `vd`.
189///
190/// # Safety
191/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
192///   and non-overlap verified by caller
193/// - `vl <= src_group_regs * VLENB / sew_bytes`
194/// - SEW < 64 verified by caller (so 2*SEW <= 64 and fits in u64)
195/// - When `vm=false`: `vd.to_bits() != 0`
196#[inline(always)]
197#[doc(hidden)]
198pub unsafe fn execute_widening_op<Reg, ExtState, CustomError, F>(
199    ext_state: &mut ExtState,
200    vd: VReg,
201    vs2: VReg,
202    src: OpSrc,
203    vm: bool,
204    sew: Vsew,
205    op: F,
206) where
207    Reg: Register,
208    ExtState: VectorRegistersExt<Reg, CustomError>,
209    [(); ExtState::ELEN as usize]:,
210    [(); ExtState::VLEN as usize]:,
211    [(); ExtState::VLENB as usize]:,
212    CustomError: fmt::Debug,
213    F: Fn(u64, u64, Vsew) -> u64,
214{
215    let vl = ext_state.vl();
216    let vstart = ext_state.vstart();
217    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
218    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
219    for i in u32::from(vstart)..vl {
220        if !mask_bit(&mask_buf, i) {
221            continue;
222        }
223        // SAFETY: register bounds verified by caller
224        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
225        let b = match src {
226            // SAFETY: register bounds verified by caller
227            OpSrc::Vreg(vs1_base) => unsafe {
228                read_element_u64(ext_state.read_vregs(), vs1_base, i, sew)
229            },
230            OpSrc::Scalar(val) => val,
231        };
232        let result = op(a, b, sew);
233        // SAFETY: vd has dest_group_regs registers; element `i` fits within them because
234        // `vl <= src_group_regs * VLENB / sew_bytes` and dest stores at 2*SEW width so
235        // `i < dest_group_regs * VLENB / (2*sew_bytes)`
236        unsafe {
237            write_wide_element_u64(ext_state.write_vregs(), vd, i, sew, result);
238        }
239    }
240    ext_state.mark_vs_dirty();
241    ext_state.reset_vstart();
242}
243
244/// Execute a single-width multiply-add where the first multiplier is a vector register group.
245///
246/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)` where `acc` is the current `vd[i]`,
247/// `a` is the element from `a_reg`, and `b` is the element from `src`. Returns the new `vd[i]`.
248///
249/// # Safety
250/// - `vd`, `a_reg`, and `src` register alignment verified by caller
251/// - `vl <= group_regs * VLENB / sew_bytes`
252/// - When `vm=false`: `vd.to_bits() != 0`
253#[inline(always)]
254#[doc(hidden)]
255pub unsafe fn execute_muladd_op<Reg, ExtState, CustomError, F>(
256    ext_state: &mut ExtState,
257    vd: VReg,
258    a_reg: VReg,
259    src: OpSrc,
260    vm: bool,
261    sew: Vsew,
262    op: F,
263) where
264    Reg: Register,
265    ExtState: VectorRegistersExt<Reg, CustomError>,
266    [(); ExtState::ELEN as usize]:,
267    [(); ExtState::VLEN as usize]:,
268    [(); ExtState::VLENB as usize]:,
269    CustomError: fmt::Debug,
270    F: Fn(u64, u64, u64, Vsew) -> u64,
271{
272    let vl = ext_state.vl();
273    let vstart = ext_state.vstart();
274    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
275    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
276    for i in u32::from(vstart)..vl {
277        if !mask_bit(&mask_buf, i) {
278            continue;
279        }
280        // SAFETY: register bounds verified by caller
281        let acc = unsafe { read_element_u64(ext_state.read_vregs(), vd, i, sew) };
282        // SAFETY: register bounds verified by caller
283        let a = unsafe { read_element_u64(ext_state.read_vregs(), a_reg, i, sew) };
284        let b = match src {
285            // SAFETY: register bounds verified by caller
286            OpSrc::Vreg(b_reg) => unsafe {
287                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
288            },
289            OpSrc::Scalar(val) => val,
290        };
291        let result = op(acc, a, b, sew);
292        // SAFETY: register bounds verified by caller
293        unsafe {
294            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
295        }
296    }
297    ext_state.mark_vs_dirty();
298    ext_state.reset_vstart();
299}
300
301/// Execute a single-width multiply-add where the first multiplier is a scalar.
302///
303/// Analogous to [`execute_muladd_op`] but `a` is a fixed scalar instead of a register element.
304///
305/// # Safety
306/// Same as [`execute_muladd_op`], minus constraints on `a_reg`.
307#[inline(always)]
308#[doc(hidden)]
309pub unsafe fn execute_muladd_scalar_op<Reg, ExtState, CustomError, F>(
310    ext_state: &mut ExtState,
311    vd: VReg,
312    scalar: u64,
313    src: OpSrc,
314    vm: bool,
315    sew: Vsew,
316    op: F,
317) where
318    Reg: Register,
319    ExtState: VectorRegistersExt<Reg, CustomError>,
320    [(); ExtState::ELEN as usize]:,
321    [(); ExtState::VLEN as usize]:,
322    [(); ExtState::VLENB as usize]:,
323    CustomError: fmt::Debug,
324    F: Fn(u64, u64, u64, Vsew) -> u64,
325{
326    let vl = ext_state.vl();
327    let vstart = ext_state.vstart();
328    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
329    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
330    for i in u32::from(vstart)..vl {
331        if !mask_bit(&mask_buf, i) {
332            continue;
333        }
334        // SAFETY: register bounds verified by caller
335        let acc = unsafe { read_element_u64(ext_state.read_vregs(), vd, i, sew) };
336        let b = match src {
337            // SAFETY: register bounds verified by caller
338            OpSrc::Vreg(b_reg) => unsafe {
339                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
340            },
341            OpSrc::Scalar(val) => val,
342        };
343        let result = op(acc, scalar, b, sew);
344        // SAFETY: register bounds verified by caller
345        unsafe {
346            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
347        }
348    }
349    ext_state.mark_vs_dirty();
350    ext_state.reset_vstart();
351}
352
353/// Execute a widening multiply-add where the first multiplier is a vector register group.
354///
355/// Reads SEW-wide `acc` from the widened `vd` group, SEW-wide `a` from `a_reg`, and SEW-wide
356/// `b` from `src`. Writes a 2*SEW-wide result back into `vd`.
357///
358/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)`.
359///
360/// # Safety
361/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
362///   and non-overlap verified by caller
363/// - SEW < 64 verified by caller
364/// - When `vm=false`: `vd.to_bits() != 0`
365#[inline(always)]
366#[doc(hidden)]
367pub unsafe fn execute_widening_muladd_op<Reg, ExtState, CustomError, F>(
368    ext_state: &mut ExtState,
369    vd: VReg,
370    a_reg: VReg,
371    src: OpSrc,
372    vm: bool,
373    sew: Vsew,
374    op: F,
375) where
376    Reg: Register,
377    ExtState: VectorRegistersExt<Reg, CustomError>,
378    [(); ExtState::ELEN as usize]:,
379    [(); ExtState::VLEN as usize]:,
380    [(); ExtState::VLENB as usize]:,
381    CustomError: fmt::Debug,
382    F: Fn(u64, u64, u64, Vsew) -> u64,
383{
384    let vl = ext_state.vl();
385    let vstart = ext_state.vstart();
386    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
387    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
388    for i in u32::from(vstart)..vl {
389        if !mask_bit(&mask_buf, i) {
390            continue;
391        }
392        // Read the existing 2*SEW accumulator from vd
393        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
394        // `execute_widening_op` for the bound argument)
395        let acc = unsafe { read_wide_element_u64(ext_state.read_vregs(), vd, i, sew) };
396        // SAFETY: register bounds verified by caller
397        let a = unsafe { read_element_u64(ext_state.read_vregs(), a_reg, i, sew) };
398        let b = match src {
399            // SAFETY: register bounds verified by caller
400            OpSrc::Vreg(b_reg) => unsafe {
401                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
402            },
403            OpSrc::Scalar(val) => val,
404        };
405        let result = op(acc, a, b, sew);
406        // SAFETY: same as acc read above
407        unsafe {
408            write_wide_element_u64(ext_state.write_vregs(), vd, i, sew, result);
409        }
410    }
411    ext_state.mark_vs_dirty();
412    ext_state.reset_vstart();
413}
414
415/// Execute a widening multiply-add where the first multiplier is a scalar.
416///
417/// Analogous to [`execute_widening_muladd_op`] but `a` is a fixed scalar.
418///
419/// # Safety
420/// Same as [`execute_widening_muladd_op`], minus constraints on `a_reg`.
421#[inline(always)]
422#[doc(hidden)]
423pub unsafe fn execute_widening_muladd_scalar_op<Reg, ExtState, CustomError, F>(
424    ext_state: &mut ExtState,
425    vd: VReg,
426    scalar: u64,
427    src: OpSrc,
428    vm: bool,
429    sew: Vsew,
430    op: F,
431) where
432    Reg: Register,
433    ExtState: VectorRegistersExt<Reg, CustomError>,
434    [(); ExtState::ELEN as usize]:,
435    [(); ExtState::VLEN as usize]:,
436    [(); ExtState::VLENB as usize]:,
437    CustomError: fmt::Debug,
438    F: Fn(u64, u64, u64, Vsew) -> u64,
439{
440    let vl = ext_state.vl();
441    let vstart = ext_state.vstart();
442    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
443    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
444    for i in u32::from(vstart)..vl {
445        if !mask_bit(&mask_buf, i) {
446            continue;
447        }
448        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
449        // `execute_widening_op` for the bound argument)
450        let acc = unsafe { read_wide_element_u64(ext_state.read_vregs(), vd, i, sew) };
451        let b = match src {
452            // SAFETY: register bounds verified by caller
453            OpSrc::Vreg(b_reg) => unsafe {
454                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
455            },
456            OpSrc::Scalar(val) => val,
457        };
458        let result = op(acc, scalar, b, sew);
459        // SAFETY: same as acc read above
460        unsafe {
461            write_wide_element_u64(ext_state.write_vregs(), vd, i, sew, result);
462        }
463    }
464    ext_state.mark_vs_dirty();
465    ext_state.reset_vstart();
466}
467
468/// Signed × signed high half.
469///
470/// Both operands are sign-extended to i64, multiplied as i128, and the upper SEW bits of the
471/// 2*SEW product are returned (zero-extended to u64 for writeback into a SEW-wide element slot).
472#[inline(always)]
473#[doc(hidden)]
474pub fn mulh_ss(a: u64, b: u64, sew: Vsew) -> u64 {
475    let sa = i128::from(sign_extend(a, sew));
476    let sb = i128::from(sign_extend(b, sew));
477    let product = sa.wrapping_mul(sb);
478    // Extract bits [2*SEW-1 : SEW] of the product
479    let high = (product >> u32::from(sew.bits_width())).cast_unsigned() as u64;
480    high & sew_mask(sew)
481}
482
483/// Unsigned × unsigned high half
484#[inline(always)]
485#[doc(hidden)]
486pub fn mulhu_uu(a: u64, b: u64, sew: Vsew) -> u64 {
487    let ua = u128::from(a & sew_mask(sew));
488    let ub = u128::from(b & sew_mask(sew));
489    let product = ua.wrapping_mul(ub);
490    let high = (product >> u32::from(sew.bits_width())) as u64;
491    high & sew_mask(sew)
492}
493
494/// Signed × unsigned high half.
495///
496/// `a` (vs2) is the signed operand; `b` (vs1/rs1) is the unsigned operand.
497#[inline(always)]
498#[doc(hidden)]
499pub fn mulhsu_su(a: u64, b: u64, sew: Vsew) -> u64 {
500    let sa = i128::from(sign_extend(a, sew));
501    let ub = u128::from(b & sew_mask(sew));
502    // Compute signed × unsigned as i128 to preserve sign
503    let product = sa.wrapping_mul(ub.cast_signed());
504    let high = (product >> u32::from(sew.bits_width())).cast_unsigned() as u64;
505    high & sew_mask(sew)
506}
507
508/// Signed divide with division-by-zero and signed-overflow semantics from the RISC-V V spec §12.11.
509///
510/// - Division by zero: result = all-ones (i.e., −1 as signed SEW-wide integer)
511/// - Signed overflow (MIN / −1): result = MIN (i.e., `1 << (SEW-1)`)
512#[inline(always)]
513#[doc(hidden)]
514pub fn sdiv(a: u64, b: u64, sew: Vsew) -> u64 {
515    let sa = sign_extend(a, sew);
516    let sb = sign_extend(b, sew);
517    // Division by zero: return all-ones in the SEW-wide slot (= −1 signed)
518    if sb == 0 {
519        return sew_mask(sew);
520    }
521    // Signed overflow: MIN / -1 returns MIN
522    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits_width()));
523    if sa == sew_min && sb == -1 {
524        return sew_min.cast_unsigned() & sew_mask(sew);
525    }
526    (sa / sb).cast_unsigned() & sew_mask(sew)
527}
528
529/// Signed remainder with division-by-zero and signed-overflow semantics from the RISC-V V spec
530/// §12.11.
531///
532/// - Division by zero: remainder = dividend
533/// - Signed overflow (MIN % −1): remainder = 0
534#[inline(always)]
535#[doc(hidden)]
536#[expect(
537    clippy::modulo_arithmetic,
538    reason = "This is what the code is supposed to do"
539)]
540pub fn srem(a: u64, b: u64, sew: Vsew) -> u64 {
541    let sa = sign_extend(a, sew);
542    let sb = sign_extend(b, sew);
543    // Division by zero: remainder = dividend
544    if sb == 0 {
545        return a & sew_mask(sew);
546    }
547    // Signed overflow: MIN % -1 = 0
548    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits_width()));
549    if sa == sew_min && sb == -1 {
550        return 0;
551    }
552    (sa % sb).cast_unsigned() & sew_mask(sew)
553}