// ab_riscv_interpreter/v/zve64x/muldiv/zve64x_muldiv_helpers.rs
1//! Opaque helpers for Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4pub use crate::v::zve64x::arith::zve64x_arith_helpers::{
5    OpSrc, check_vreg_group_alignment, sew_mask, sign_extend,
6};
7use crate::v::zve64x::arith::zve64x_arith_helpers::{read_element_u64, write_element_u64};
8use crate::v::zve64x::fixed_point::zve64x_fixed_point_helpers::read_wide_element_u64;
9use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
10use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
11use crate::{ExecutionError, InterpreterState, ProgramCounter, VirtualMemory};
12use ab_riscv_primitives::instructions::v::{Vlmul, Vsew};
13use ab_riscv_primitives::registers::general_purpose::Register;
14use ab_riscv_primitives::registers::vector::VReg;
15use core::fmt;
16
17/// Compute the destination register count for a widening operation (`EMUL = 2 × LMUL`).
18///
19/// Returns `None` when the resulting EMUL falls outside the legal range `[1/8, 8]`, i.e. when
20/// `LMUL` is already `M8` (EMUL would be 16) or the caller asks for a multiplication factor that
21/// pushes the fraction past the legal lower bound.
22///
23/// The register count returned is `max(1, EMUL)`: fractional EMUL values (1/2, 1/4) still occupy
24/// exactly one physical register.
25#[inline(always)]
26#[doc(hidden)]
27pub fn widening_dest_register_count(vlmul: Vlmul) -> Option<u8> {
28    let (lmul_num, lmul_den) = vlmul.as_fraction();
29    // EMUL = 2 × LMUL = (2 * lmul_num) / lmul_den
30    let emul_num = 2u8.checked_mul(lmul_num)?;
31    let emul_den = lmul_den;
32    // Reduce the fraction by GCD (both are powers of two so min works as GCD)
33    let g = emul_num.min(emul_den);
34    let (n, d) = (emul_num / g, emul_den / g);
35    // Legal EMUL fractions: 1/8, 1/4, 1/2, 1, 2, 4, 8
36    let legal = matches!(
37        (n, d),
38        (1, 8) | (1, 4) | (1, 2) | (1, 1) | (2, 1) | (4, 1) | (8, 1)
39    );
40    if !legal {
41        return None;
42    }
43    // Register count: max(1, n/d) = n when d==1, else 1
44    Some(if d > 1 { 1 } else { n })
45}
46
47/// Check that a narrower source register group does not overlap the wider destination group.
48///
49/// For widening instructions `vd` occupies `dest_group_regs` registers (which is
50/// [`widening_dest_register_count()`] of the source LMUL); `vs` occupies `src_group_regs`.
51/// The spec prohibits any overlap between them.
52#[inline(always)]
53#[doc(hidden)]
54pub fn check_no_widening_overlap<Reg, ExtState, Memory, PC, IH, CustomError>(
55    state: &InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
56    vd: VReg,
57    vs: VReg,
58    dest_group_regs: u8,
59    src_group_regs: u8,
60) -> Result<(), ExecutionError<Reg::Type, CustomError>>
61where
62    Reg: Register,
63    [(); Reg::N]:,
64    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
65{
66    let vd_start = vd.bits();
67    let vd_end = vd_start + dest_group_regs;
68    let vs_start = vs.bits();
69    let vs_end = vs_start + src_group_regs;
70    // Overlap when the intervals intersect
71    if vs_start < vd_end && vd_start < vs_end {
72        return Err(ExecutionError::IllegalInstruction {
73            address: state.instruction_fetcher.old_pc(INSTRUCTION_SIZE),
74        });
75    }
76    Ok(())
77}
78
79/// Write a 2*SEW-wide element into the widened destination register group at element index
80/// `elem_i`.
81///
82/// # Safety
83/// `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32` must hold.
84#[inline(always)]
85unsafe fn write_wide_element_u64<const VLENB: usize>(
86    vreg: &mut [[u8; VLENB]; 32],
87    base_reg: u8,
88    elem_i: u32,
89    sew: Vsew,
90    value: u64,
91) {
92    let wide_bytes = usize::from(sew.bytes()) * 2;
93    let elems_per_reg = VLENB / wide_bytes;
94    let reg_off = elem_i as usize / elems_per_reg;
95    let byte_off = (elem_i as usize % elems_per_reg) * wide_bytes;
96    let buf = value.to_le_bytes();
97    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
98    let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
99    // SAFETY: `byte_off + wide_bytes <= VLENB`; `wide_bytes <= 8` for SEW < 64
100    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + wide_bytes) };
101    // SAFETY: `wide_bytes <= 8` because SEW < 64 is enforced before widening ops are called
102    dst.copy_from_slice(unsafe { buf.get_unchecked(..wide_bytes) });
103}
104
105/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
106///
107/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result.
108/// Only the low `sew.bytes()` of the result are written back.
109///
110/// # Safety
111/// - `vd` and source register alignment verified by caller
112/// - `vl <= group_regs * VLENB / sew_bytes`
113/// - When `vm=false`: `vd.bits() != 0`
114#[inline(always)]
115#[expect(clippy::too_many_arguments, reason = "Internal API")]
116#[doc(hidden)]
117pub unsafe fn execute_arith_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
118    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
119    vd: VReg,
120    vs2: VReg,
121    src: OpSrc,
122    vm: bool,
123    vl: u32,
124    vstart: u32,
125    sew: Vsew,
126    op: F,
127) where
128    Reg: Register,
129    [(); Reg::N]:,
130    ExtState: VectorRegistersExt<Reg, CustomError>,
131    [(); ExtState::ELEN as usize]:,
132    [(); ExtState::VLEN as usize]:,
133    [(); ExtState::VLENB as usize]:,
134    Memory: VirtualMemory,
135    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
136    CustomError: fmt::Debug,
137    F: Fn(u64, u64, Vsew) -> u64,
138{
139    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
140    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
141    let vd_base = vd.bits();
142    let vs2_base = vs2.bits();
143    for i in vstart..vl {
144        if !mask_bit(&mask_buf, i) {
145            continue;
146        }
147        // SAFETY: register bounds verified by caller
148        let a =
149            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
150        let b = match &src {
151            // SAFETY: register bounds verified by caller
152            OpSrc::Vreg(vs1_base) => unsafe {
153                read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
154            },
155            OpSrc::Scalar(val) => *val,
156        };
157        let result = op(a, b, sew);
158        // SAFETY: register bounds verified by caller
159        unsafe {
160            write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
161        }
162    }
163    state.ext_state.mark_vs_dirty();
164    state.ext_state.reset_vstart();
165}
166
167/// Execute a single-width widening operation over `vstart..vl`.
168///
169/// Reads SEW-wide elements from `vs2` and `src`, computes `op`, and writes a 2*SEW-wide result
170/// into `vd`.
171///
172/// # Safety
173/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
174///   and non-overlap verified by caller
175/// - `vl <= src_group_regs * VLENB / sew_bytes`
176/// - SEW < 64 verified by caller (so 2*SEW <= 64 and fits in u64)
177/// - When `vm=false`: `vd.bits() != 0`
178#[inline(always)]
179#[expect(clippy::too_many_arguments, reason = "Internal API")]
180#[doc(hidden)]
181pub unsafe fn execute_widening_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
182    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
183    vd: VReg,
184    vs2: VReg,
185    src: OpSrc,
186    vm: bool,
187    vl: u32,
188    vstart: u32,
189    sew: Vsew,
190    op: F,
191) where
192    Reg: Register,
193    [(); Reg::N]:,
194    ExtState: VectorRegistersExt<Reg, CustomError>,
195    [(); ExtState::ELEN as usize]:,
196    [(); ExtState::VLEN as usize]:,
197    [(); ExtState::VLENB as usize]:,
198    Memory: VirtualMemory,
199    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
200    CustomError: fmt::Debug,
201    F: Fn(u64, u64, Vsew) -> u64,
202{
203    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
204    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
205    let vd_base = vd.bits();
206    let vs2_base = vs2.bits();
207    for i in vstart..vl {
208        if !mask_bit(&mask_buf, i) {
209            continue;
210        }
211        // SAFETY: register bounds verified by caller
212        let a =
213            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
214        let b = match &src {
215            // SAFETY: register bounds verified by caller
216            OpSrc::Vreg(vs1_base) => unsafe {
217                read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
218            },
219            OpSrc::Scalar(val) => *val,
220        };
221        let result = op(a, b, sew);
222        // SAFETY: vd has dest_group_regs registers; element `i` fits within them because
223        // `vl <= src_group_regs * VLENB / sew_bytes` and dest stores at 2*SEW width so
224        // `i < dest_group_regs * VLENB / (2*sew_bytes)`
225        unsafe {
226            write_wide_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
227        }
228    }
229    state.ext_state.mark_vs_dirty();
230    state.ext_state.reset_vstart();
231}
232
233/// Execute a single-width multiply-add where the first multiplier is a vector register group.
234///
235/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)` where `acc` is the current `vd[i]`,
236/// `a` is the element from `a_reg`, and `b` is the element from `src`. Returns the new `vd[i]`.
237///
238/// # Safety
239/// - `vd`, `a_reg`, and `src` register alignment verified by caller
240/// - `vl <= group_regs * VLENB / sew_bytes`
241/// - When `vm=false`: `vd.bits() != 0`
242#[inline(always)]
243#[expect(clippy::too_many_arguments, reason = "Internal API")]
244#[doc(hidden)]
245pub unsafe fn execute_muladd_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
246    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
247    vd: VReg,
248    a_reg: u8,
249    src: OpSrc,
250    vm: bool,
251    vl: u32,
252    vstart: u32,
253    sew: Vsew,
254    op: F,
255) where
256    Reg: Register,
257    [(); Reg::N]:,
258    ExtState: VectorRegistersExt<Reg, CustomError>,
259    [(); ExtState::ELEN as usize]:,
260    [(); ExtState::VLEN as usize]:,
261    [(); ExtState::VLENB as usize]:,
262    Memory: VirtualMemory,
263    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
264    CustomError: fmt::Debug,
265    F: Fn(u64, u64, u64, Vsew) -> u64,
266{
267    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
268    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
269    let vd_base = vd.bits();
270    for i in vstart..vl {
271        if !mask_bit(&mask_buf, i) {
272            continue;
273        }
274        // SAFETY: register bounds verified by caller
275        let acc =
276            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vd_base), i, sew) };
277        // SAFETY: register bounds verified by caller
278        let a =
279            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(a_reg), i, sew) };
280        let b = match &src {
281            // SAFETY: register bounds verified by caller
282            OpSrc::Vreg(b_reg) => unsafe {
283                read_element_u64(state.ext_state.read_vreg(), usize::from(*b_reg), i, sew)
284            },
285            OpSrc::Scalar(val) => *val,
286        };
287        let result = op(acc, a, b, sew);
288        // SAFETY: register bounds verified by caller
289        unsafe {
290            write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
291        }
292    }
293    state.ext_state.mark_vs_dirty();
294    state.ext_state.reset_vstart();
295}
296
297/// Execute a single-width multiply-add where the first multiplier is a scalar.
298///
299/// Analogous to [`execute_muladd_op`] but `a` is a fixed scalar instead of a register element.
300///
301/// # Safety
302/// Same as [`execute_muladd_op`], minus constraints on `a_reg`.
303#[inline(always)]
304#[expect(clippy::too_many_arguments, reason = "Internal API")]
305#[doc(hidden)]
306pub unsafe fn execute_muladd_scalar_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
307    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
308    vd: VReg,
309    scalar: u64,
310    src: OpSrc,
311    vm: bool,
312    vl: u32,
313    vstart: u32,
314    sew: Vsew,
315    op: F,
316) where
317    Reg: Register,
318    [(); Reg::N]:,
319    ExtState: VectorRegistersExt<Reg, CustomError>,
320    [(); ExtState::ELEN as usize]:,
321    [(); ExtState::VLEN as usize]:,
322    [(); ExtState::VLENB as usize]:,
323    Memory: VirtualMemory,
324    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
325    CustomError: fmt::Debug,
326    F: Fn(u64, u64, u64, Vsew) -> u64,
327{
328    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
329    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
330    let vd_base = vd.bits();
331    for i in vstart..vl {
332        if !mask_bit(&mask_buf, i) {
333            continue;
334        }
335        // SAFETY: register bounds verified by caller
336        let acc =
337            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vd_base), i, sew) };
338        let b = match &src {
339            // SAFETY: register bounds verified by caller
340            OpSrc::Vreg(b_reg) => unsafe {
341                read_element_u64(state.ext_state.read_vreg(), usize::from(*b_reg), i, sew)
342            },
343            OpSrc::Scalar(val) => *val,
344        };
345        let result = op(acc, scalar, b, sew);
346        // SAFETY: register bounds verified by caller
347        unsafe {
348            write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
349        }
350    }
351    state.ext_state.mark_vs_dirty();
352    state.ext_state.reset_vstart();
353}
354
355/// Execute a widening multiply-add where the first multiplier is a vector register group.
356///
357/// Reads SEW-wide `acc` from the widened `vd` group, SEW-wide `a` from `a_reg`, and SEW-wide
358/// `b` from `src`. Writes a 2*SEW-wide result back into `vd`.
359///
360/// `op` receives `(acc: u64, a: u64, b: u64, sew: Vsew)`.
361///
362/// # Safety
363/// - `vd` uses `dest_group_regs` registers (result of `widening_dest_register_count()`); alignment
364///   and non-overlap verified by caller
365/// - SEW < 64 verified by caller
366/// - When `vm=false`: `vd.bits() != 0`
367#[inline(always)]
368#[expect(clippy::too_many_arguments, reason = "Internal API")]
369#[doc(hidden)]
370pub unsafe fn execute_widening_muladd_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
371    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
372    vd: VReg,
373    a_reg: u8,
374    src: OpSrc,
375    vm: bool,
376    vl: u32,
377    vstart: u32,
378    sew: Vsew,
379    op: F,
380) where
381    Reg: Register,
382    [(); Reg::N]:,
383    ExtState: VectorRegistersExt<Reg, CustomError>,
384    [(); ExtState::ELEN as usize]:,
385    [(); ExtState::VLEN as usize]:,
386    [(); ExtState::VLENB as usize]:,
387    Memory: VirtualMemory,
388    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
389    CustomError: fmt::Debug,
390    F: Fn(u64, u64, u64, Vsew) -> u64,
391{
392    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
393    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
394    let vd_base = vd.bits();
395    for i in vstart..vl {
396        if !mask_bit(&mask_buf, i) {
397            continue;
398        }
399        // Read the existing 2*SEW accumulator from vd
400        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
401        // `execute_widening_op` for the bound argument)
402        let acc = unsafe {
403            read_wide_element_u64(state.ext_state.read_vreg(), usize::from(vd_base), i, sew)
404        };
405        // SAFETY: register bounds verified by caller
406        let a =
407            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(a_reg), i, sew) };
408        let b = match &src {
409            // SAFETY: register bounds verified by caller
410            OpSrc::Vreg(b_reg) => unsafe {
411                read_element_u64(state.ext_state.read_vreg(), usize::from(*b_reg), i, sew)
412            },
413            OpSrc::Scalar(val) => *val,
414        };
415        let result = op(acc, a, b, sew);
416        // SAFETY: same as acc read above
417        unsafe {
418            write_wide_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
419        }
420    }
421    state.ext_state.mark_vs_dirty();
422    state.ext_state.reset_vstart();
423}
424
425/// Execute a widening multiply-add where the first multiplier is a scalar.
426///
427/// Analogous to [`execute_widening_muladd_op`] but `a` is a fixed scalar.
428///
429/// # Safety
430/// Same as [`execute_widening_muladd_op`], minus constraints on `a_reg`.
431#[inline(always)]
432#[expect(clippy::too_many_arguments, reason = "Internal API")]
433#[doc(hidden)]
434pub unsafe fn execute_widening_muladd_scalar_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
435    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
436    vd: VReg,
437    scalar: u64,
438    src: OpSrc,
439    vm: bool,
440    vl: u32,
441    vstart: u32,
442    sew: Vsew,
443    op: F,
444) where
445    Reg: Register,
446    [(); Reg::N]:,
447    ExtState: VectorRegistersExt<Reg, CustomError>,
448    [(); ExtState::ELEN as usize]:,
449    [(); ExtState::VLEN as usize]:,
450    [(); ExtState::VLENB as usize]:,
451    Memory: VirtualMemory,
452    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
453    CustomError: fmt::Debug,
454    F: Fn(u64, u64, u64, Vsew) -> u64,
455{
456    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
457    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
458    let vd_base = vd.bits();
459    for i in vstart..vl {
460        if !mask_bit(&mask_buf, i) {
461            continue;
462        }
463        // SAFETY: vd has dest_group_regs registers; element `i` fits within them (see
464        // `execute_widening_op` for the bound argument)
465        let acc = unsafe {
466            read_wide_element_u64(state.ext_state.read_vreg(), usize::from(vd_base), i, sew)
467        };
468        let b = match &src {
469            // SAFETY: register bounds verified by caller
470            OpSrc::Vreg(b_reg) => unsafe {
471                read_element_u64(state.ext_state.read_vreg(), usize::from(*b_reg), i, sew)
472            },
473            OpSrc::Scalar(val) => *val,
474        };
475        let result = op(acc, scalar, b, sew);
476        // SAFETY: same as acc read above
477        unsafe {
478            write_wide_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
479        }
480    }
481    state.ext_state.mark_vs_dirty();
482    state.ext_state.reset_vstart();
483}
484
485/// Signed × signed high half.
486///
487/// Both operands are sign-extended to i64, multiplied as i128, and the upper SEW bits of the
488/// 2*SEW product are returned (zero-extended to u64 for writeback into a SEW-wide element slot).
489#[inline(always)]
490#[doc(hidden)]
491pub fn mulh_ss(a: u64, b: u64, sew: Vsew) -> u64 {
492    let sa = i128::from(sign_extend(a, sew));
493    let sb = i128::from(sign_extend(b, sew));
494    let product = sa.wrapping_mul(sb);
495    // Extract bits [2*SEW-1 : SEW] of the product
496    let high = (product >> u32::from(sew.bits())).cast_unsigned() as u64;
497    high & sew_mask(sew)
498}
499
500/// Unsigned × unsigned high half
501#[inline(always)]
502#[doc(hidden)]
503pub fn mulhu_uu(a: u64, b: u64, sew: Vsew) -> u64 {
504    let ua = u128::from(a & sew_mask(sew));
505    let ub = u128::from(b & sew_mask(sew));
506    let product = ua.wrapping_mul(ub);
507    let high = (product >> u32::from(sew.bits())) as u64;
508    high & sew_mask(sew)
509}
510
511/// Signed × unsigned high half.
512///
513/// `a` (vs2) is the signed operand; `b` (vs1/rs1) is the unsigned operand.
514#[inline(always)]
515#[doc(hidden)]
516pub fn mulhsu_su(a: u64, b: u64, sew: Vsew) -> u64 {
517    let sa = i128::from(sign_extend(a, sew));
518    let ub = u128::from(b & sew_mask(sew));
519    // Compute signed × unsigned as i128 to preserve sign
520    let product = sa.wrapping_mul(ub.cast_signed());
521    let high = (product >> u32::from(sew.bits())).cast_unsigned() as u64;
522    high & sew_mask(sew)
523}
524
525/// Signed divide with division-by-zero and signed-overflow semantics from the RISC-V V spec §12.11.
526///
527/// - Division by zero: result = all-ones (i.e., −1 as signed SEW-wide integer)
528/// - Signed overflow (MIN / −1): result = MIN (i.e., `1 << (SEW-1)`)
529#[inline(always)]
530#[doc(hidden)]
531pub fn sdiv(a: u64, b: u64, sew: Vsew) -> u64 {
532    let sa = sign_extend(a, sew);
533    let sb = sign_extend(b, sew);
534    // Division by zero: return all-ones in the SEW-wide slot (= −1 signed)
535    if sb == 0 {
536        return sew_mask(sew);
537    }
538    // Signed overflow: MIN / -1 returns MIN
539    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits()));
540    if sa == sew_min && sb == -1 {
541        return sew_min.cast_unsigned() & sew_mask(sew);
542    }
543    (sa / sb).cast_unsigned() & sew_mask(sew)
544}
545
546/// Signed remainder with division-by-zero and signed-overflow semantics from the RISC-V V spec
547/// §12.11.
548///
549/// - Division by zero: remainder = dividend
550/// - Signed overflow (MIN % −1): remainder = 0
551#[inline(always)]
552#[doc(hidden)]
553pub fn srem(a: u64, b: u64, sew: Vsew) -> u64 {
554    let sa = sign_extend(a, sew);
555    let sb = sign_extend(b, sew);
556    // Division by zero: remainder = dividend
557    if sb == 0 {
558        return a & sew_mask(sew);
559    }
560    // Signed overflow: MIN % -1 = 0
561    let sew_min = i64::MIN >> (u64::BITS - u32::from(sew.bits()));
562    if sa == sew_min && sb == -1 {
563        return 0;
564    }
565    (sa % sb).cast_unsigned() & sew_mask(sew)
566}