ab_riscv_interpreter/v/zve64x/muldiv/zve64x_muldiv_helpers.rs

//! Opaque multiply/divide helpers for the Zve64x extension

use crate::v::vector_registers::VectorRegistersExt;
pub use crate::v::zve64x::arith::zve64x_arith_helpers::{
    OpSrc, check_vreg_group_alignment, sew_mask, sign_extend,
};
use crate::v::zve64x::arith::zve64x_arith_helpers::{read_element_u64, write_element_u64};
use crate::v::zve64x::fixed_point::zve64x_fixed_point_helpers::read_wide_element_u64;
use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, ProgramCounter};
use ab_riscv_primitives::prelude::*;
use core::fmt;

/// Compute the destination register count for a widening operation (`EMUL = 2 × LMUL`).
///
/// Returns `None` when the resulting EMUL falls outside the legal range `[1/8, 8]`, i.e. when
/// `LMUL` is already `M8` (doubling it would produce an illegal EMUL of 16).
///
/// The register count returned is `max(1, EMUL)`: fractional EMUL values (1/2, 1/4) still occupy
/// exactly one physical register.
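///
/// # Example
///
/// A minimal sketch; `Vlmul::M1` and `Vlmul::MF2` are assumed variant names alongside the `M8`
/// mentioned above:
///
/// ```ignore
/// assert_eq!(widening_dest_register_count(Vlmul::M1), Some(2)); // EMUL = 2
/// assert_eq!(widening_dest_register_count(Vlmul::MF2), Some(1)); // EMUL = 1, one register
/// assert_eq!(widening_dest_register_count(Vlmul::M8), None); // EMUL = 16 is illegal
/// ```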
#[inline(always)]
#[doc(hidden)]
pub fn widening_dest_register_count(vlmul: Vlmul) -> Option<u8> {
    let (lmul_num, lmul_den) = vlmul.as_fraction();
    // EMUL = 2 × LMUL = (2 * lmul_num) / lmul_den
    let emul_num = 2u8.checked_mul(lmul_num)?;
    let emul_den = lmul_den;
    // Reduce the fraction by GCD (both are powers of two so min works as GCD)
    let g = emul_num.min(emul_den);
    let (n, d) = (emul_num / g, emul_den / g);
    // Legal EMUL fractions: 1/8, 1/4, 1/2, 1, 2, 4, 8
    let legal = matches!(
        (n, d),
        (1, 8) | (1, 4) | (1, 2) | (1, 1) | (2, 1) | (4, 1) | (8, 1)
    );
    if !legal {
        return None;
    }
    // Register count: max(1, n/d) = n when d==1, else 1
    Some(if d > 1 { 1 } else { n })
}

/// Check that a narrower source register group does not overlap the wider destination group.
///
/// For widening instructions `vd` occupies `dest_group_regs` registers (which is
/// [`widening_dest_register_count()`] of the source LMUL); `vs` occupies `src_group_regs`.
/// The spec constrains overlap between groups of different EEW; this check conservatively
/// rejects any overlap between the two groups.
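///
/// For example, `vd = v4` with `dest_group_regs = 4` occupies `v4..v8`, and `vs = v6` with
/// `src_group_regs = 2` occupies `v6..v8`; the intervals intersect, so the call returns
/// `IllegalInstruction`.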
#[inline(always)]
#[doc(hidden)]
pub fn check_no_widening_overlap<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    vd: VReg,
    vs: VReg,
    dest_group_regs: u8,
    src_group_regs: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    let vd_start = vd.bits();
    let vd_end = vd_start + dest_group_regs;
    let vs_start = vs.bits();
    let vs_end = vs_start + src_group_regs;
    // Overlap when the intervals intersect
    if vs_start < vd_end && vd_start < vs_end {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

/// Write a 2*SEW-wide element into the widened destination register group at element index
/// `elem_i`.
///
/// # Safety
/// `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32` must hold.
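///
/// For example, with `VLENB = 16` and `SEW = 16` (`wide_bytes = 4`, `elems_per_reg = 4`),
/// element 5 lands in register `base_reg + 1` at byte offset 4.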
#[inline(always)]
unsafe fn write_wide_element_u64<const VLENB: usize>(
    vreg: &mut [[u8; VLENB]; 32],
    base_reg: u8,
    elem_i: u32,
    sew: Vsew,
    value: u64,
) {
    let wide_bytes = usize::from(sew.bytes()) * 2;
    let elems_per_reg = VLENB / wide_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * wide_bytes;
    let buf = value.to_le_bytes();
    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
    let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
    // SAFETY: `byte_off + wide_bytes <= VLENB`; `wide_bytes <= 8` for SEW < 64
    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + wide_bytes) };
    // SAFETY: `wide_bytes <= 8` because SEW < 64 is enforced before widening ops are called
    dst.copy_from_slice(unsafe { buf.get_unchecked(..wide_bytes) });
}

/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
///
/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result.
/// Only the low `sew.bytes()` bytes of the result are written back.
///
/// # Safety
/// - `vd` and source register alignment verified by caller
/// - `vl <= group_regs * VLENB / sew_bytes`
/// - When `vm=false`: `vd.bits() != 0`
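///
/// # Example
///
/// A sketch of a `vmul.vv`-style call; `ext_state` and the decoded operands (`vd`, `vs2`,
/// `vs1`, `vm`, `vl`, `vstart`, `sew`) are assumed to be in scope and already validated:
///
/// ```ignore
/// // Low half of the product; the writeback keeps only the low SEW bytes anyway.
/// unsafe {
///     execute_arith_op(
///         ext_state, vd, vs2, OpSrc::Vreg(vs1.bits()), vm, vl, vstart, sew,
///         |a, b, _sew| a.wrapping_mul(b),
///     );
/// }
/// ```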
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, Vsew) -> u64,
{
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: register bounds verified by caller
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let b = match &src {
            // SAFETY: register bounds verified by caller
            OpSrc::Vreg(vs1_base) => unsafe {
                read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
            },
            OpSrc::Scalar(val) => *val,
        };
        let result = op(a, b, sew);
        // SAFETY: register bounds verified by caller
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

/// Execute a widening element-wise operation over `vstart..vl`.
///
/// Reads SEW-wide elements from `vs2` and `src`, computes `op`, and writes a 2*SEW-wide result
/// into `vd`.
///
/// # Safety
/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
///   and non-overlap verified by caller
/// - `vl <= src_group_regs * VLENB / sew_bytes`
/// - SEW < 64 verified by caller (so 2*SEW <= 64 and fits in u64)
/// - When `vm=false`: `vd.bits() != 0`
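///
/// # Example
///
/// A sketch of a `vwmulu.vv`-style closure (operands assumed validated as above); since
/// SEW < 64, the full 2*SEW product fits in `u64`:
///
/// ```ignore
/// unsafe {
///     execute_widening_op(
///         ext_state, vd, vs2, OpSrc::Vreg(vs1.bits()), vm, vl, vstart, sew,
///         |a, b, sew| (a & sew_mask(sew)).wrapping_mul(b & sew_mask(sew)),
///     );
/// }
/// ```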
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_widening_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, Vsew) -> u64,
{
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: register bounds verified by caller
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let b = match &src {
            // SAFETY: register bounds verified by caller
            OpSrc::Vreg(vs1_base) => unsafe {
                read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
            },
            OpSrc::Scalar(val) => *val,
        };
        let result = op(a, b, sew);
        // SAFETY: vd has dest_group_regs registers; element `i` fits within them because
        // `vl <= src_group_regs * VLENB / sew_bytes` and dest stores at 2*SEW width so
        // `i < dest_group_regs * VLENB / (2*sew_bytes)`
        unsafe {
            write_wide_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

/// Execute a single-width multiply-add where the first multiplier is a vector register group.
///
/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)` where `acc` is the current `vd[i]`,
/// `a` is the element from `a_reg`, and `b` is the element from `src`. Returns the new `vd[i]`.
///
/// # Safety
/// - `vd`, `a_reg`, and `src` register alignment verified by caller
/// - `vl <= group_regs * VLENB / sew_bytes`
/// - When `vm=false`: `vd.bits() != 0`
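///
/// # Example
///
/// A sketch of a `vmacc.vv`-style closure (`vd[i] = vs1[i] * vs2[i] + vd[i]`, low SEW bits;
/// operands assumed validated):
///
/// ```ignore
/// unsafe {
///     execute_muladd_op(
///         ext_state, vd, vs1.bits(), OpSrc::Vreg(vs2.bits()), vm, vl, vstart, sew,
///         |acc, a, b, _sew| a.wrapping_mul(b).wrapping_add(acc),
///     );
/// }
/// ```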
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_muladd_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    a_reg: u8,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, u64, Vsew) -> u64,
{
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: register bounds verified by caller
        let acc = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vd_base), i, sew) };
        // SAFETY: register bounds verified by caller
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(a_reg), i, sew) };
        let b = match &src {
            // SAFETY: register bounds verified by caller
            OpSrc::Vreg(b_reg) => unsafe {
                read_element_u64(ext_state.read_vreg(), usize::from(*b_reg), i, sew)
            },
            OpSrc::Scalar(val) => *val,
        };
        let result = op(acc, a, b, sew);
        // SAFETY: register bounds verified by caller
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

/// Execute a single-width multiply-add where the first multiplier is a scalar.
///
/// Analogous to [`execute_muladd_op`] but `a` is a fixed scalar instead of a register element.
///
/// # Safety
/// Same as [`execute_muladd_op`], minus constraints on `a_reg`.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_muladd_scalar_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    scalar: u64,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, u64, Vsew) -> u64,
{
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: register bounds verified by caller
        let acc = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vd_base), i, sew) };
        let b = match &src {
            // SAFETY: register bounds verified by caller
            OpSrc::Vreg(b_reg) => unsafe {
                read_element_u64(ext_state.read_vreg(), usize::from(*b_reg), i, sew)
            },
            OpSrc::Scalar(val) => *val,
        };
        let result = op(acc, scalar, b, sew);
        // SAFETY: register bounds verified by caller
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

/// Execute a widening multiply-add where the first multiplier is a vector register group.
///
/// Reads the existing 2*SEW-wide `acc` from the widened `vd` group, SEW-wide `a` from `a_reg`,
/// and SEW-wide `b` from `src`. Writes a 2*SEW-wide result back into `vd`.
///
/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)`.
///
/// # Safety
/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
///   and non-overlap verified by caller
/// - `vl <= src_group_regs * VLENB / sew_bytes`
/// - SEW < 64 verified by caller
/// - When `vm=false`: `vd.bits() != 0`
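///
/// # Example
///
/// A sketch of a `vwmaccu.vv`-style closure (operands assumed validated); the `u64` sum
/// truncates correctly because only the low 2*SEW bytes are written back:
///
/// ```ignore
/// unsafe {
///     execute_widening_muladd_op(
///         ext_state, vd, vs1.bits(), OpSrc::Vreg(vs2.bits()), vm, vl, vstart, sew,
///         |acc, a, b, sew| acc.wrapping_add((a & sew_mask(sew)).wrapping_mul(b & sew_mask(sew))),
///     );
/// }
/// ```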
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_widening_muladd_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    a_reg: u8,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, u64, Vsew) -> u64,
{
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // Read the existing 2*SEW accumulator from vd
        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
        // `execute_widening_op` for the bound argument)
        let acc =
            unsafe { read_wide_element_u64(ext_state.read_vreg(), usize::from(vd_base), i, sew) };
        // SAFETY: register bounds verified by caller
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(a_reg), i, sew) };
        let b = match &src {
            // SAFETY: register bounds verified by caller
            OpSrc::Vreg(b_reg) => unsafe {
                read_element_u64(ext_state.read_vreg(), usize::from(*b_reg), i, sew)
            },
            OpSrc::Scalar(val) => *val,
        };
        let result = op(acc, a, b, sew);
        // SAFETY: same as acc read above
        unsafe {
            write_wide_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

/// Execute a widening multiply-add where the first multiplier is a scalar.
///
/// Analogous to [`execute_widening_muladd_op`] but `a` is a fixed scalar.
///
/// # Safety
/// Same as [`execute_widening_muladd_op`], minus constraints on `a_reg`.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_widening_muladd_scalar_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    scalar: u64,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, u64, Vsew) -> u64,
{
    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
        // `execute_widening_op` for the bound argument)
        let acc =
            unsafe { read_wide_element_u64(ext_state.read_vreg(), usize::from(vd_base), i, sew) };
        let b = match &src {
            // SAFETY: register bounds verified by caller
            OpSrc::Vreg(b_reg) => unsafe {
                read_element_u64(ext_state.read_vreg(), usize::from(*b_reg), i, sew)
            },
            OpSrc::Scalar(val) => *val,
        };
        let result = op(acc, scalar, b, sew);
        // SAFETY: same as acc read above
        unsafe {
            write_wide_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

/// Signed × signed high half.
///
/// Both operands are sign-extended to i64, multiplied as i128, and the upper SEW bits of the
/// 2*SEW product are returned (zero-extended to u64 for writeback into a SEW-wide element slot).
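///
/// For example, at `SEW = 8`, `a = 0x80` (−128) and `b = 0x02` give the 16-bit product
/// `0xFF00` (−256), so the helper returns the high byte `0xFF`.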
#[inline(always)]
#[doc(hidden)]
pub fn mulh_ss(a: u64, b: u64, sew: Vsew) -> u64 {
    let sa = i128::from(sign_extend(a, sew));
    let sb = i128::from(sign_extend(b, sew));
    let product = sa.wrapping_mul(sb);
    // Extract bits [2*SEW-1 : SEW] of the product
    let high = (product >> u32::from(sew.bits())).cast_unsigned() as u64;
    high & sew_mask(sew)
}

/// Unsigned × unsigned high half.
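///
/// For example, at `SEW = 8`, `0xFF * 0xFF = 0xFE01`, so the helper returns `0xFE`.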
#[inline(always)]
#[doc(hidden)]
pub fn mulhu_uu(a: u64, b: u64, sew: Vsew) -> u64 {
    let ua = u128::from(a & sew_mask(sew));
    let ub = u128::from(b & sew_mask(sew));
    let product = ua.wrapping_mul(ub);
    let high = (product >> u32::from(sew.bits())) as u64;
    high & sew_mask(sew)
}

/// Signed × unsigned high half.
///
/// `a` (vs2) is the signed operand; `b` (vs1/rs1) is the unsigned operand.
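///
/// For example, at `SEW = 8`, `a = 0xFF` (signed −1) times `b = 0xFF` (unsigned 255) gives the
/// 16-bit product `0xFF01` (−255), so the helper returns `0xFF`.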
#[inline(always)]
#[doc(hidden)]
pub fn mulhsu_su(a: u64, b: u64, sew: Vsew) -> u64 {
    let sa = i128::from(sign_extend(a, sew));
    let ub = u128::from(b & sew_mask(sew));
    // Compute signed × unsigned as i128 to preserve sign
    let product = sa.wrapping_mul(ub.cast_signed());
    let high = (product >> u32::from(sew.bits())).cast_unsigned() as u64;
    high & sew_mask(sew)
}

/// Signed divide with division-by-zero and signed-overflow semantics from the RISC-V V spec
/// §11.11.
///
/// - Division by zero: result = all-ones (i.e., −1 as signed SEW-wide integer)
/// - Signed overflow (MIN / −1): result = MIN (i.e., `1 << (SEW-1)`)
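///
/// For example, at `SEW = 8`, `sdiv(0x80, 0xFF)` is −128 / −1, which overflows and returns
/// `0x80`; any `sdiv(x, 0)` returns `0xFF`.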
#[inline(always)]
#[doc(hidden)]
pub fn sdiv(a: u64, b: u64, sew: Vsew) -> u64 {
    let sa = sign_extend(a, sew);
    let sb = sign_extend(b, sew);
    // Division by zero: return all-ones in the SEW-wide slot (= −1 signed)
    if sb == 0 {
        return sew_mask(sew);
    }
    // Signed overflow: MIN / -1 returns MIN
    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits()));
    if sa == sew_min && sb == -1 {
        return sew_min.cast_unsigned() & sew_mask(sew);
    }
    (sa / sb).cast_unsigned() & sew_mask(sew)
}

/// Signed remainder with division-by-zero and signed-overflow semantics from the RISC-V V spec
/// §11.11.
///
/// - Division by zero: remainder = dividend
/// - Signed overflow (MIN % −1): remainder = 0
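///
/// For example, at `SEW = 8`, `srem(0x80, 0xFF)` (−128 % −1) returns `0`, and `srem(0x07, 0)`
/// returns `0x07`.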
#[inline(always)]
#[doc(hidden)]
pub fn srem(a: u64, b: u64, sew: Vsew) -> u64 {
    let sa = sign_extend(a, sew);
    let sb = sign_extend(b, sew);
    // Division by zero: remainder = dividend
    if sb == 0 {
        return a & sew_mask(sew);
    }
    // Signed overflow: MIN % -1 = 0
    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits()));
    if sa == sew_min && sb == -1 {
        return 0;
    }
    (sa % sb).cast_unsigned() & sew_mask(sew)
}