Skip to main content

ab_riscv_interpreter/v/zvexx/muldiv/
zvexx_muldiv_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5    OpSrc, check_vreg_group_alignment, sew_mask, sign_extend,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{read_element_u64, write_element_u64};
8use crate::v::zvexx::fixed_point::zvexx_fixed_point_helpers::read_wide_element_u64;
9use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
10use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
11use crate::{ExecutionError, ProgramCounter};
12use ab_riscv_primitives::prelude::*;
13use core::fmt;
14
15/// Compute the destination register count for a widening operation (`EMUL = 2 × LMUL`).
16///
17/// Returns `None` when the resulting EMUL falls outside the legal range `[1/8, 8]`, i.e. when
18/// `LMUL` is already `M8` (EMUL would be 16) or the caller asks for a multiplication factor that
19/// pushes the fraction past the legal lower bound.
20///
21/// The register count returned is `max(1, EMUL)`: fractional EMUL values (1/2, 1/4) still occupy
22/// exactly one physical register.
23#[inline(always)]
24#[doc(hidden)]
25pub fn widening_dest_register_count(vlmul: Vlmul) -> Option<u8> {
26    let (lmul_num, lmul_den) = vlmul.as_fraction();
27    // EMUL = 2 × LMUL = (2 * lmul_num) / lmul_den
28    let emul_num = 2u8.checked_mul(lmul_num)?;
29    let emul_den = lmul_den;
30    // Reduce the fraction by GCD (both are powers of two so min works as GCD)
31    let g = emul_num.min(emul_den);
32    let (n, d) = (emul_num / g, emul_den / g);
33    // Legal EMUL fractions: 1/8, 1/4, 1/2, 1, 2, 4, 8
34    let legal = matches!(
35        (n, d),
36        (1, 8) | (1, 4) | (1, 2) | (1, 1) | (2, 1) | (4, 1) | (8, 1)
37    );
38    if !legal {
39        return None;
40    }
41    // Register count: max(1, n/d) = n when d==1, else 1
42    Some(if d > 1 { 1 } else { n })
43}
44
45/// Check that a narrower source register group does not *illegally* overlap the wider destination
46/// group of a widening instruction.
47///
48/// For widening instructions `vd` occupies `dest_group_regs` registers (which is
49/// [`widening_dest_register_count()`] of the source LMUL); `vs` occupies `src_group_regs`.
50///
51/// Per the RISC-V "V" spec §5.2, because the destination EEW (`2*SEW`) is greater than the source
52/// EEW (`SEW`), the source group *may* overlap the destination group, but only when both of the
53/// following hold:
54/// - the source EMUL is at least 1, and
55/// - the overlap is in the highest-numbered part of the destination register group, i.e. the source
56///   occupies exactly the top `src_group_regs` registers of the destination group.
57///
58/// When the source EMUL is at least 1, `dest_group_regs == 2 * src_group_regs`, so
59/// `dest_group_regs > src_group_regs` is an equivalent test for "source EMUL >= 1": a fractional
60/// source EMUL (`< 1`) yields `dest_group_regs == src_group_regs == 1`, in which case no overlap is
61/// ever legal. Any overlap that is not the legal "source in the highest-numbered part" form is
62/// rejected as an illegal instruction.
63#[inline(always)]
64#[doc(hidden)]
65pub fn check_no_widening_overlap<Reg, Memory, PC, CustomError>(
66    program_counter: &PC,
67    vd: VReg,
68    vs: VReg,
69    dest_group_regs: u8,
70    src_group_regs: u8,
71) -> Result<(), ExecutionError<Reg::Type, CustomError>>
72where
73    Reg: Register,
74    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
75{
76    let vd_start = vd.to_bits();
77    let vd_end = vd_start + dest_group_regs;
78    let vs_start = vs.to_bits();
79    let vs_end = vs_start + src_group_regs;
80    // Disjoint register groups are always fine
81    if vs_start >= vd_end || vd_start >= vs_end {
82        return Ok(());
83    }
84    // The groups overlap. This is legal only when the source EMUL is at least 1
85    // (`dest_group_regs > src_group_regs`) and the source occupies exactly the highest-numbered
86    // part of the destination group (`vs_start == vd_end - src_group_regs`).
87    if dest_group_regs > src_group_regs && vs_start == vd_end - src_group_regs {
88        return Ok(());
89    }
90    Err(ExecutionError::IllegalInstruction {
91        address: program_counter.old_pc(INSTRUCTION_SIZE),
92    })
93}
94
95/// Write a 2*SEW-wide element into the widened destination register group at element index
96/// `elem_i`.
97///
98/// # Safety
99/// `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32` must hold.
100#[inline(always)]
101unsafe fn write_wide_element_u64<const VLENB: usize>(
102    vregs: &mut VectorRegisterFile<VLENB>,
103    base_reg: VReg,
104    elem_i: u32,
105    sew: Vsew,
106    value: u64,
107) {
108    let wide_bytes = usize::from(sew.bytes_width()) * 2;
109    let elems_per_reg = VLENB / wide_bytes;
110    let reg_off = elem_i as usize / elems_per_reg;
111    let byte_off = (elem_i as usize % elems_per_reg) * wide_bytes;
112    let buf = value.to_le_bytes();
113    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
114    let reg = unsafe {
115        vregs.get_mut(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
116    };
117    // SAFETY: `byte_off + wide_bytes <= VLENB`; `wide_bytes <= 8` for SEW < 64
118    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + wide_bytes) };
119    // SAFETY: `wide_bytes <= 8` because SEW < 64 is enforced before widening ops are called
120    dst.copy_from_slice(unsafe { buf.get_unchecked(..wide_bytes) });
121}
122
123/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
124///
125/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result.
126/// Only the low `sew.bytes()` of the result are written back.
127///
128/// # Safety
129/// - `vd` and source register alignment verified by caller
130/// - `vl <= group_regs * VLENB / sew_bytes`
131/// - When `vm=false`: `vd.to_bits() != 0`
132#[inline(always)]
133#[doc(hidden)]
134pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
135    ext_state: &mut ExtState,
136    vd: VReg,
137    vs2: VReg,
138    src: OpSrc,
139    vm: bool,
140    sew: Vsew,
141    op: F,
142) where
143    Reg: Register,
144    ExtState: VectorRegistersExt<Reg, CustomError>,
145    [(); ExtState::ELEN as usize]:,
146    [(); ExtState::VLEN as usize]:,
147    [(); ExtState::VLENB as usize]:,
148    CustomError: fmt::Debug,
149    F: Fn(u64, u64, Vsew) -> u64,
150{
151    let vl = ext_state.vl();
152    let vstart = ext_state.vstart();
153    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
154    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
155    for i in u32::from(vstart)..vl {
156        if !mask_bit(&mask_buf, i) {
157            continue;
158        }
159        // SAFETY: register bounds verified by caller
160        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
161        let b = match src {
162            // SAFETY: register bounds verified by caller
163            OpSrc::Vreg(vs1_base) => unsafe {
164                read_element_u64(ext_state.read_vregs(), vs1_base, i, sew)
165            },
166            OpSrc::Scalar(val) => val,
167        };
168        let result = op(a, b, sew);
169        // SAFETY: register bounds verified by caller
170        unsafe {
171            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
172        }
173    }
174    ext_state.mark_vs_dirty();
175    ext_state.reset_vstart();
176}
177
178/// Execute a single-width widening operation over `vstart..vl`.
179///
180/// Reads SEW-wide elements from `vs2` and `src`, computes `op`, and writes a 2*SEW-wide result
181/// into `vd`.
182///
183/// # Safety
184/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
185///   and non-overlap verified by caller
186/// - `vl <= src_group_regs * VLENB / sew_bytes`
187/// - SEW < 64 verified by caller (so 2*SEW <= 64 and fits in u64)
188/// - When `vm=false`: `vd.to_bits() != 0`
189#[inline(always)]
190#[doc(hidden)]
191pub unsafe fn execute_widening_op<Reg, ExtState, CustomError, F>(
192    ext_state: &mut ExtState,
193    vd: VReg,
194    vs2: VReg,
195    src: OpSrc,
196    vm: bool,
197    sew: Vsew,
198    op: F,
199) where
200    Reg: Register,
201    ExtState: VectorRegistersExt<Reg, CustomError>,
202    [(); ExtState::ELEN as usize]:,
203    [(); ExtState::VLEN as usize]:,
204    [(); ExtState::VLENB as usize]:,
205    CustomError: fmt::Debug,
206    F: Fn(u64, u64, Vsew) -> u64,
207{
208    let vl = ext_state.vl();
209    let vstart = ext_state.vstart();
210    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
211    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
212    for i in u32::from(vstart)..vl {
213        if !mask_bit(&mask_buf, i) {
214            continue;
215        }
216        // SAFETY: register bounds verified by caller
217        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
218        let b = match src {
219            // SAFETY: register bounds verified by caller
220            OpSrc::Vreg(vs1_base) => unsafe {
221                read_element_u64(ext_state.read_vregs(), vs1_base, i, sew)
222            },
223            OpSrc::Scalar(val) => val,
224        };
225        let result = op(a, b, sew);
226        // SAFETY: vd has dest_group_regs registers; element `i` fits within them because
227        // `vl <= src_group_regs * VLENB / sew_bytes` and dest stores at 2*SEW width so
228        // `i < dest_group_regs * VLENB / (2*sew_bytes)`
229        unsafe {
230            write_wide_element_u64(ext_state.write_vregs(), vd, i, sew, result);
231        }
232    }
233    ext_state.mark_vs_dirty();
234    ext_state.reset_vstart();
235}
236
237/// Execute a single-width multiply-add where the first multiplier is a vector register group.
238///
239/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)` where `acc` is the current `vd[i]`,
240/// `a` is the element from `a_reg`, and `b` is the element from `src`. Returns the new `vd[i]`.
241///
242/// # Safety
243/// - `vd`, `a_reg`, and `src` register alignment verified by caller
244/// - `vl <= group_regs * VLENB / sew_bytes`
245/// - When `vm=false`: `vd.to_bits() != 0`
246#[inline(always)]
247#[doc(hidden)]
248pub unsafe fn execute_muladd_op<Reg, ExtState, CustomError, F>(
249    ext_state: &mut ExtState,
250    vd: VReg,
251    a_reg: VReg,
252    src: OpSrc,
253    vm: bool,
254    sew: Vsew,
255    op: F,
256) where
257    Reg: Register,
258    ExtState: VectorRegistersExt<Reg, CustomError>,
259    [(); ExtState::ELEN as usize]:,
260    [(); ExtState::VLEN as usize]:,
261    [(); ExtState::VLENB as usize]:,
262    CustomError: fmt::Debug,
263    F: Fn(u64, u64, u64, Vsew) -> u64,
264{
265    let vl = ext_state.vl();
266    let vstart = ext_state.vstart();
267    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
268    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
269    for i in u32::from(vstart)..vl {
270        if !mask_bit(&mask_buf, i) {
271            continue;
272        }
273        // SAFETY: register bounds verified by caller
274        let acc = unsafe { read_element_u64(ext_state.read_vregs(), vd, i, sew) };
275        // SAFETY: register bounds verified by caller
276        let a = unsafe { read_element_u64(ext_state.read_vregs(), a_reg, i, sew) };
277        let b = match src {
278            // SAFETY: register bounds verified by caller
279            OpSrc::Vreg(b_reg) => unsafe {
280                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
281            },
282            OpSrc::Scalar(val) => val,
283        };
284        let result = op(acc, a, b, sew);
285        // SAFETY: register bounds verified by caller
286        unsafe {
287            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
288        }
289    }
290    ext_state.mark_vs_dirty();
291    ext_state.reset_vstart();
292}
293
294/// Execute a single-width multiply-add where the first multiplier is a scalar.
295///
296/// Analogous to [`execute_muladd_op`] but `a` is a fixed scalar instead of a register element.
297///
298/// # Safety
299/// Same as [`execute_muladd_op`], minus constraints on `a_reg`.
300#[inline(always)]
301#[doc(hidden)]
302pub unsafe fn execute_muladd_scalar_op<Reg, ExtState, CustomError, F>(
303    ext_state: &mut ExtState,
304    vd: VReg,
305    scalar: u64,
306    src: OpSrc,
307    vm: bool,
308    sew: Vsew,
309    op: F,
310) where
311    Reg: Register,
312    ExtState: VectorRegistersExt<Reg, CustomError>,
313    [(); ExtState::ELEN as usize]:,
314    [(); ExtState::VLEN as usize]:,
315    [(); ExtState::VLENB as usize]:,
316    CustomError: fmt::Debug,
317    F: Fn(u64, u64, u64, Vsew) -> u64,
318{
319    let vl = ext_state.vl();
320    let vstart = ext_state.vstart();
321    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
322    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
323    for i in u32::from(vstart)..vl {
324        if !mask_bit(&mask_buf, i) {
325            continue;
326        }
327        // SAFETY: register bounds verified by caller
328        let acc = unsafe { read_element_u64(ext_state.read_vregs(), vd, i, sew) };
329        let b = match src {
330            // SAFETY: register bounds verified by caller
331            OpSrc::Vreg(b_reg) => unsafe {
332                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
333            },
334            OpSrc::Scalar(val) => val,
335        };
336        let result = op(acc, scalar, b, sew);
337        // SAFETY: register bounds verified by caller
338        unsafe {
339            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
340        }
341    }
342    ext_state.mark_vs_dirty();
343    ext_state.reset_vstart();
344}
345
346/// Execute a widening multiply-add where the first multiplier is a vector register group.
347///
348/// Reads SEW-wide `acc` from the widened `vd` group, SEW-wide `a` from `a_reg`, and SEW-wide
349/// `b` from `src`. Writes a 2*SEW-wide result back into `vd`.
350///
351/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)`.
352///
353/// # Safety
354/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
355///   and non-overlap verified by caller
356/// - SEW < 64 verified by caller
357/// - When `vm=false`: `vd.to_bits() != 0`
358#[inline(always)]
359#[doc(hidden)]
360pub unsafe fn execute_widening_muladd_op<Reg, ExtState, CustomError, F>(
361    ext_state: &mut ExtState,
362    vd: VReg,
363    a_reg: VReg,
364    src: OpSrc,
365    vm: bool,
366    sew: Vsew,
367    op: F,
368) where
369    Reg: Register,
370    ExtState: VectorRegistersExt<Reg, CustomError>,
371    [(); ExtState::ELEN as usize]:,
372    [(); ExtState::VLEN as usize]:,
373    [(); ExtState::VLENB as usize]:,
374    CustomError: fmt::Debug,
375    F: Fn(u64, u64, u64, Vsew) -> u64,
376{
377    let vl = ext_state.vl();
378    let vstart = ext_state.vstart();
379    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
380    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
381    for i in u32::from(vstart)..vl {
382        if !mask_bit(&mask_buf, i) {
383            continue;
384        }
385        // Read the existing 2*SEW accumulator from vd
386        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
387        // `execute_widening_op` for the bound argument)
388        let acc = unsafe { read_wide_element_u64(ext_state.read_vregs(), vd, i, sew) };
389        // SAFETY: register bounds verified by caller
390        let a = unsafe { read_element_u64(ext_state.read_vregs(), a_reg, i, sew) };
391        let b = match src {
392            // SAFETY: register bounds verified by caller
393            OpSrc::Vreg(b_reg) => unsafe {
394                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
395            },
396            OpSrc::Scalar(val) => val,
397        };
398        let result = op(acc, a, b, sew);
399        // SAFETY: same as acc read above
400        unsafe {
401            write_wide_element_u64(ext_state.write_vregs(), vd, i, sew, result);
402        }
403    }
404    ext_state.mark_vs_dirty();
405    ext_state.reset_vstart();
406}
407
408/// Execute a widening multiply-add where the first multiplier is a scalar.
409///
410/// Analogous to [`execute_widening_muladd_op`] but `a` is a fixed scalar.
411///
412/// # Safety
413/// Same as [`execute_widening_muladd_op`], minus constraints on `a_reg`.
414#[inline(always)]
415#[doc(hidden)]
416pub unsafe fn execute_widening_muladd_scalar_op<Reg, ExtState, CustomError, F>(
417    ext_state: &mut ExtState,
418    vd: VReg,
419    scalar: u64,
420    src: OpSrc,
421    vm: bool,
422    sew: Vsew,
423    op: F,
424) where
425    Reg: Register,
426    ExtState: VectorRegistersExt<Reg, CustomError>,
427    [(); ExtState::ELEN as usize]:,
428    [(); ExtState::VLEN as usize]:,
429    [(); ExtState::VLENB as usize]:,
430    CustomError: fmt::Debug,
431    F: Fn(u64, u64, u64, Vsew) -> u64,
432{
433    let vl = ext_state.vl();
434    let vstart = ext_state.vstart();
435    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
436    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
437    for i in u32::from(vstart)..vl {
438        if !mask_bit(&mask_buf, i) {
439            continue;
440        }
441        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
442        // `execute_widening_op` for the bound argument)
443        let acc = unsafe { read_wide_element_u64(ext_state.read_vregs(), vd, i, sew) };
444        let b = match src {
445            // SAFETY: register bounds verified by caller
446            OpSrc::Vreg(b_reg) => unsafe {
447                read_element_u64(ext_state.read_vregs(), b_reg, i, sew)
448            },
449            OpSrc::Scalar(val) => val,
450        };
451        let result = op(acc, scalar, b, sew);
452        // SAFETY: same as acc read above
453        unsafe {
454            write_wide_element_u64(ext_state.write_vregs(), vd, i, sew, result);
455        }
456    }
457    ext_state.mark_vs_dirty();
458    ext_state.reset_vstart();
459}
460
461/// Signed × signed high half.
462///
463/// Both operands are sign-extended to i64, multiplied as i128, and the upper SEW bits of the
464/// 2*SEW product are returned (zero-extended to u64 for writeback into a SEW-wide element slot).
465#[inline(always)]
466#[doc(hidden)]
467pub fn mulh_ss(a: u64, b: u64, sew: Vsew) -> u64 {
468    let sa = i128::from(sign_extend(a, sew));
469    let sb = i128::from(sign_extend(b, sew));
470    let product = sa.wrapping_mul(sb);
471    // Extract bits [2*SEW-1 : SEW] of the product
472    let high = (product >> u32::from(sew.bits_width())).cast_unsigned() as u64;
473    high & sew_mask(sew)
474}
475
476/// Unsigned × unsigned high half
477#[inline(always)]
478#[doc(hidden)]
479pub fn mulhu_uu(a: u64, b: u64, sew: Vsew) -> u64 {
480    let ua = u128::from(a & sew_mask(sew));
481    let ub = u128::from(b & sew_mask(sew));
482    let product = ua.wrapping_mul(ub);
483    let high = (product >> u32::from(sew.bits_width())) as u64;
484    high & sew_mask(sew)
485}
486
487/// Signed × unsigned high half.
488///
489/// `a` (vs2) is the signed operand; `b` (vs1/rs1) is the unsigned operand.
490#[inline(always)]
491#[doc(hidden)]
492pub fn mulhsu_su(a: u64, b: u64, sew: Vsew) -> u64 {
493    let sa = i128::from(sign_extend(a, sew));
494    let ub = u128::from(b & sew_mask(sew));
495    // Compute signed × unsigned as i128 to preserve sign
496    let product = sa.wrapping_mul(ub.cast_signed());
497    let high = (product >> u32::from(sew.bits_width())).cast_unsigned() as u64;
498    high & sew_mask(sew)
499}
500
501/// Signed divide with division-by-zero and signed-overflow semantics from the RISC-V V spec §12.11.
502///
503/// - Division by zero: result = all-ones (i.e., −1 as signed SEW-wide integer)
504/// - Signed overflow (MIN / −1): result = MIN (i.e., `1 << (SEW-1)`)
505#[inline(always)]
506#[doc(hidden)]
507pub fn sdiv(a: u64, b: u64, sew: Vsew) -> u64 {
508    let sa = sign_extend(a, sew);
509    let sb = sign_extend(b, sew);
510    // Division by zero: return all-ones in the SEW-wide slot (= −1 signed)
511    if sb == 0 {
512        return sew_mask(sew);
513    }
514    // Signed overflow: MIN / -1 returns MIN
515    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits_width()));
516    if sa == sew_min && sb == -1 {
517        return sew_min.cast_unsigned() & sew_mask(sew);
518    }
519    (sa / sb).cast_unsigned() & sew_mask(sew)
520}
521
522/// Signed remainder with division-by-zero and signed-overflow semantics from the RISC-V V spec
523/// §12.11.
524///
525/// - Division by zero: remainder = dividend
526/// - Signed overflow (MIN % −1): remainder = 0
527#[inline(always)]
528#[doc(hidden)]
529pub fn srem(a: u64, b: u64, sew: Vsew) -> u64 {
530    let sa = sign_extend(a, sew);
531    let sb = sign_extend(b, sew);
532    // Division by zero: remainder = dividend
533    if sb == 0 {
534        return a & sew_mask(sew);
535    }
536    // Signed overflow: MIN % -1 = 0
537    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits_width()));
538    if sa == sew_min && sb == -1 {
539        return 0;
540    }
541    (sa % sb).cast_unsigned() & sew_mask(sew)
542}