// ab_riscv_interpreter/v/zve64x/fixed_point/zve64x_fixed_point_helpers.rs

//! Fixed-point arithmetic helpers for the Zve64x extension.
2
3use crate::v::vector_registers::VectorRegistersExt;
4pub use crate::v::zve64x::arith::zve64x_arith_helpers::{
5    OpSrc, check_vreg_group_alignment as check_vd, check_vreg_group_alignment as check_vs, sew_mask,
6};
7use crate::v::zve64x::arith::zve64x_arith_helpers::{
8    read_element_u64, sign_extend, write_element_u64,
9};
10use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
11use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
12use crate::{ExecutionError, ProgramCounter};
13use ab_riscv_primitives::prelude::*;
14use core::fmt;
15
16/// Compute the rounding increment for a right shift of `val` by `shift` bits.
17///
18/// When `shift == 0` there are no fractional bits so the increment is always zero.
19/// `current_result_lsb` is the LSB of the truncated result, required for `Rne` and `Rod`.
20#[inline(always)]
21fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
22    if shift == 0 {
23        return 0;
24    }
25    // `d_minus1_bit`: the most-significant discarded bit (bit position `shift - 1`)
26    let d_minus1_bit = (val >> (shift - 1)) & 1;
27    // `sticky`: OR of all bits below position `shift - 1`
28    let sticky = if shift >= 2 {
29        // Any of bits [shift-2 : 0] set?
30        (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
31    } else {
32        false
33    };
34    match mode {
35        // Round nearest up: increment = v[d-1]
36        Vxrm::Rnu => d_minus1_bit,
37        // Round nearest even: increment = v[d-1] & (sticky | result_lsb)
38        Vxrm::Rne => {
39            d_minus1_bit
40                & (if sticky || current_result_lsb != 0 {
41                    1
42                } else {
43                    0
44                })
45        }
46        // Round down / truncate: never increment
47        Vxrm::Rdn => 0,
48        // Round to odd: set result LSB if any discarded bit was non-zero
49        Vxrm::Rod => {
50            if current_result_lsb == 0 && (d_minus1_bit != 0 || sticky) {
51                1
52            } else {
53                0
54            }
55        }
56    }
57}
58
59/// Perform a rounded right shift of `val` by `shift` bits (logical / unsigned).
60///
61/// Returns `(val >> shift) + round_increment`.
62#[inline(always)]
63#[doc(hidden)]
64pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
65    let truncated = val >> shift;
66    let r = round_increment(val, shift, mode, truncated & 1);
67    truncated.wrapping_add(r)
68}
69
70/// Perform a rounded arithmetic right shift of `val` (sign-extended to SEW) by `shift` bits.
71///
72/// Returns the SEW-wide signed result as `u64` (sign bits above SEW are meaningful).
73#[inline(always)]
74#[doc(hidden)]
75pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
76    let signed = sign_extend(val, sew);
77    // Treat the raw bits for rounding purposes: rounding uses the unsigned representation of the
78    // SEW-wide value (only bits below `shift` matter, so masking is not needed here since the
79    // discarded bits are the same regardless of sign extension).
80    let truncated_signed = signed >> shift;
81    let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
82    truncated_signed.cast_unsigned().wrapping_add(r)
83}
84
85/// Saturating unsigned add: `vs2 + src`, clamped to `[0, 2^SEW - 1]`.
86///
87/// Sets `vxsat` to `true` on overflow.
88#[inline(always)]
89#[doc(hidden)]
90pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
91    let mask = sew_mask(sew);
92    let a_w = a & mask;
93    let b_w = b & mask;
94    let result = a_w.wrapping_add(b_w);
95    if result & mask < a_w {
96        // Overflow: wrapped around
97        *vxsat = true;
98        mask
99    } else {
100        result & mask
101    }
102}
103
104/// Saturating signed add: `vs2 + src`, clamped to `[-(2^(SEW-1)), 2^(SEW-1) - 1]`.
105///
106/// Sets `vxsat` to `true` on overflow.
107#[inline(always)]
108#[doc(hidden)]
109pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
110    let sa = sign_extend(a, sew) as i128;
111    let sb = sign_extend(b, sew) as i128;
112    let result = sa.wrapping_add(sb);
113    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits()));
114    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits()));
115    if result < min_val {
116        *vxsat = true;
117        (min_val as i64).cast_unsigned() & sew_mask(sew)
118    } else if result > max_val {
119        *vxsat = true;
120        (max_val as i64).cast_unsigned() & sew_mask(sew)
121    } else {
122        (result as i64).cast_unsigned() & sew_mask(sew)
123    }
124}
125
126/// Saturating unsigned subtract: `vs2 - src`, clamped to `[0, 2^SEW - 1]`.
127///
128/// Sets `vxsat` to `true` on overflow (underflow to negative).
129#[inline(always)]
130#[doc(hidden)]
131pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
132    let mask = sew_mask(sew);
133    let a_w = a & mask;
134    let b_w = b & mask;
135    if a_w < b_w {
136        *vxsat = true;
137        0
138    } else {
139        (a_w - b_w) & mask
140    }
141}
142
143/// Saturating signed subtract: `vs2 - src`, clamped to `[-(2^(SEW-1)), 2^(SEW-1) - 1]`.
144///
145/// Sets `vxsat` to `true` on overflow.
146#[inline(always)]
147#[doc(hidden)]
148pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
149    let sa = sign_extend(a, sew) as i128;
150    let sb = sign_extend(b, sew) as i128;
151    let result = sa.wrapping_sub(sb);
152    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits()));
153    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits()));
154    if result < min_val {
155        *vxsat = true;
156        (min_val as i64).cast_unsigned() & sew_mask(sew)
157    } else if result > max_val {
158        *vxsat = true;
159        (max_val as i64).cast_unsigned() & sew_mask(sew)
160    } else {
161        (result as i64).cast_unsigned() & sew_mask(sew)
162    }
163}
164
165/// Averaging unsigned add: `(vs2 + src) >> 1` with rounding per `vxrm`.
166///
167/// Uses a 1-bit wider intermediate to avoid overflow; no saturation, no `vxsat`.
168#[inline(always)]
169#[doc(hidden)]
170pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
171    let mask = sew_mask(sew);
172    let a_w = a & mask;
173    let b_w = b & mask;
174    // Compute full sum in one extra bit by using u128 or by widening trick.
175    // Since SEW <= 64 and both operands are SEW-bit values, the sum fits in SEW+1 bits.
176    // Use wrapping_add: the carry out of bit SEW-1 is the extra bit.
177    let sum = a_w.wrapping_add(b_w);
178    // Carry: set if unsigned sum overflowed SEW bits
179    let carry = if sum & mask < a_w { 1u64 } else { 0u64 };
180    // Full (SEW+1)-bit value: `carry` is at bit position SEW, `sum & mask` are low SEW bits.
181    // We need `(carry:sum) >> 1` with rounding.
182    // Bit 0 of `sum & mask` is the rounding bit for the truncated division.
183    let r = round_increment(sum & mask, 1, mode, (sum >> 1) & 1);
184    // Shift the (SEW+1)-bit quantity right by 1: result = (carry << (SEW-1)) | ((sum & mask) >> 1)
185    let shifted = (carry << (sew.bits() as u32 - 1)) | ((sum & mask) >> 1);
186    (shifted.wrapping_add(r)) & mask
187}
188
189/// Averaging signed add: `(vs2 + src) >> 1` with rounding per `vxrm`.
190///
191/// No saturation, no `vxsat`.
192#[inline(always)]
193#[doc(hidden)]
194pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
195    let sa = sign_extend(a, sew);
196    let sb = sign_extend(b, sew);
197    // Full sum as i128 to avoid overflow
198    let sum = (sa as i128).wrapping_add(sb as i128);
199    // The low bit is the fractional bit for rounding
200    let r = match mode {
201        Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
202        Vxrm::Rne => {
203            // round-to-nearest-even: increment if fractional bit set AND (result LSB or sticky)
204            // For a single bit shift there are no lower sticky bits, so only check result LSB
205            let result_lsb = ((sum >> 1) & 1).cast_unsigned() as u64;
206            ((sum & 1).cast_unsigned() as u64) & result_lsb
207        }
208        Vxrm::Rdn => 0,
209        Vxrm::Rod => {
210            // Set result LSB if it would be 0 and the fractional bit is nonzero
211            let result_lsb = (sum >> 1) & 1;
212            if result_lsb == 0 && (sum & 1) != 0 {
213                1
214            } else {
215                0
216            }
217        }
218    };
219    let result = (sum >> 1) + r as i128;
220    (result as i64).cast_unsigned() & sew_mask(sew)
221}
222
223/// Averaging unsigned subtract: `(vs2 - src) >> 1` with rounding per `vxrm`.
224///
225/// No saturation, no `vxsat`.
226#[inline(always)]
227#[doc(hidden)]
228pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
229    let mask = sew_mask(sew);
230    let a_w = a & mask;
231    let b_w = b & mask;
232    // Compute difference with borrow using wrapping sub; borrow extends to SEW+1 bit.
233    let diff = a_w.wrapping_sub(b_w);
234    // Borrow: set if a < b (unsigned)
235    let borrow = if a_w < b_w { 1u64 } else { 0u64 };
236    // Full (SEW+1)-bit two's-complement difference:
237    // If borrow: the SEW-bit `diff` is correct (it wrapped), and the sign extension bit is 1.
238    // Rounding: bit 0 of diff is the fractional bit.
239    let r = round_increment(diff & mask, 1, mode, (diff >> 1) & 1);
240    // Arithmetic right shift by 1 of the (SEW+1)-bit signed value.
241    // For unsigned averaging subtract: result = ((SEW+1)-bit diff) / 2 with rounding.
242    // The (SEW+1)-bit value is: borrow is the sign bit. If borrow set, value is negative.
243    // Result = (borrow << SEW | diff) >> 1 (arithmetic) + r
244    // Arithmetic shift: sign bit (`borrow`) propagates.
245    let sign_fill = borrow.wrapping_neg(); // all ones if borrow set, zero otherwise
246    let shifted = (sign_fill << (sew.bits() as u32 - 1)) | ((diff & mask) >> 1);
247    (shifted.wrapping_add(r)) & mask
248}
249
250/// Averaging signed subtract: `(vs2 - src) >> 1` with rounding per `vxrm`.
251///
252/// No saturation, no `vxsat`.
253#[inline(always)]
254#[doc(hidden)]
255pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
256    let sa = sign_extend(a, sew);
257    let sb = sign_extend(b, sew);
258    let diff = (sa as i128).wrapping_sub(sb as i128);
259    let r = match mode {
260        Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
261        Vxrm::Rne => {
262            let result_lsb = ((diff >> 1) & 1).cast_unsigned() as u64;
263            ((diff & 1).cast_unsigned() as u64) & result_lsb
264        }
265        Vxrm::Rdn => 0,
266        Vxrm::Rod => {
267            let result_lsb = (diff >> 1) & 1;
268            if result_lsb == 0 && (diff & 1) != 0 {
269                1
270            } else {
271                0
272            }
273        }
274    };
275    let result = (diff >> 1) + r as i128;
276    (result as i64).cast_unsigned() & sew_mask(sew)
277}
278
/// Fractional multiply with rounding and saturation: `vsmul`.
///
/// Computes `(a * b * 2 + rounding) >> SEW`, saturating at the signed maximum when the
/// product of two minimum signed values overflows (`INT_MIN * INT_MIN`).
///
/// Per spec §12.4: `vd[i] = clip(roundoff_signed(vs2[i] * vs1[i] * 2, SEW))`.
/// Sets `vxsat` on overflow.
#[inline(always)]
#[doc(hidden)]
pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    // SEW-wide signed min and max in i64 (valid for all SEW <= 64)
    let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits()));
    let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits()));
    let sa = i128::from(sign_extend(a, sew));
    let sb = i128::from(sign_extend(b, sew));
    // The only case where `product * 2` overflows a 2*SEW signed result is INT_MIN * INT_MIN.
    // Detect this before any multiply: for SEW=64 INT64_MIN^2 = 2^126 and <<1 would overflow i128.
    if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
        *vxsat = true;
        return max_sew.cast_unsigned() & sew_mask(sew);
    }
    // Full 2*SEW-bit product; no overflow possible because at least one operand != INT_MIN,
    // so |product| < INT_MIN^2 and the value fits in i128 for SEW <= 64.
    let product = sa * sb;
    // Left shift by 1 for the Q-format fractional interpretation; safe because
    // |product| < INT_MIN^2, so after <<1 the result still fits in i128 for SEW <= 64.
    let doubled = product << 1;
    // Extract the low SEW bits (the discarded portion) for rounding.
    // Cast to u128 first to avoid sign-extension contaminating the mask.
    let shift = u32::from(sew.bits());
    let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
    // Arithmetic right shift by SEW gives the truncated signed result in SEW-wide range.
    let truncated = doubled >> shift;
    // `shift <= 64` already holds (`sew.bits() <= 64`); `.min(64)` is a defensive no-op
    // guaranteeing `round_increment` never sees a shift whose `1 << (shift - 1)` could overflow.
    let r = round_increment(
        low_bits,
        shift.min(64),
        mode,
        (truncated.cast_unsigned() as u64) & 1,
    );
    // `truncated` fits in i64 after the SEW-bit shift (it is a SEW-wide signed value).
    // `wrapping_add` cannot wrap in practice: `truncated == i64::MAX` (SEW=64) only occurs
    // for `doubled == (2^63 - 1) << 64`, whose discarded low bits are all zero, so `r == 0`
    // in every rounding mode.
    let result = (truncated as i64).wrapping_add(r.cast_signed());
    // Clamp to SEW-wide signed range (only reachable if rounding pushed the value over)
    if result < min_sew {
        *vxsat = true;
        min_sew.cast_unsigned() & sew_mask(sew)
    } else if result > max_sew {
        *vxsat = true;
        max_sew.cast_unsigned() & sew_mask(sew)
    } else {
        result.cast_unsigned() & sew_mask(sew)
    }
}
331
332/// Narrowing unsigned clip: read a 2*SEW element from `vs2`, shift right by `shamt` with
333/// rounding, saturate to unsigned SEW range, set `vxsat` on clamp.
334///
335/// `vs2_elem` is the 2*SEW-bit element (zero-extended to u64 for SEW <= 32;
336/// for SEW = 64 the doubled width would be 128 bits, but Zve64x only supports SEW up to 64 and
337/// the narrowing destination is at most 64 bits wide, so 2*SEW = 128 - however the spec requires
338/// `ELEN >= 2*SEW` for narrowing instructions. Since `ELEN = 64` in Zve64x, narrowing is only
339/// valid for SEW <= 32 (`2*SEW <= 64`).  The caller must enforce this constraint by checking
340/// `vsew` before invoking narrowing operations.
341///
342/// `vs2_elem` is passed as `u64`; for SEW = 32 it holds a 64-bit (2*SEW) value.
343#[inline(always)]
344#[doc(hidden)]
345pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
346    // Shift right with rounding
347    let shifted = rounded_srl(vs2_elem, shamt, mode);
348    // Saturate to destination SEW unsigned range [0, 2^SEW - 1]
349    let max_dst = sew_mask(sew);
350    if shifted > max_dst {
351        *vxsat = true;
352        max_dst
353    } else {
354        shifted & max_dst
355    }
356}
357
/// Narrowing signed clip: read a 2*SEW signed element from `vs2`, shift right arithmetically
/// with rounding, saturate to signed SEW range.
///
/// Sets `vxsat` to `true` when the value is clamped.
/// Same SEW constraint as [`nclipu`]: valid only for SEW <= 32 (`2*SEW <= 64`).
#[inline(always)]
#[doc(hidden)]
pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    // Sign-extend vs2_elem to full i64 treating it as a 2*SEW-bit signed value.
    // For SEW=8 the source is 16-bit, for SEW=16 it is 32-bit, for SEW=32 it is 64-bit.
    let double_sew_bits = sew.bits() * 2;
    // For SEW=32 this is 0 and the shift-up/shift-down pair below is a no-op.
    let shift_amt = i64::BITS - u32::from(double_sew_bits);
    let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
    // Arithmetic right shift with rounding
    // For rounding we need the raw low bits of the wide value before shifting
    let low_bits = signed_wide.cast_unsigned()
        & if double_sew_bits == 64 {
            // `1u64 << 64` would overflow, so the full-width mask is special-cased.
            u64::MAX
        } else {
            (1u64 << double_sew_bits) - 1
        };
    let truncated = signed_wide >> shamt;
    let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
    let rounded = truncated.wrapping_add(r.cast_signed());
    // Saturate to signed SEW range
    let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits()));
    let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits()));
    if rounded < min_dst {
        *vxsat = true;
        min_dst.cast_unsigned() & sew_mask(sew)
    } else if rounded > max_dst {
        *vxsat = true;
        max_dst.cast_unsigned() & sew_mask(sew)
    } else {
        rounded.cast_unsigned() & sew_mask(sew)
    }
}
394
/// Read a 2*SEW-wide element as `u64` from the double-width source register group of a narrowing
/// instruction.
///
/// For narrowing instructions `vs2` holds elements of width `2*SEW`. The register group size is
/// `2 * group_regs`. Element `i` of width `2*SEW` is located in the same way as a SEW-wide
/// element of width `2*SEW` (i.e., treating `2*SEW` as the element width). For `SEW = 32` this
/// reads 64-bit elements; for `SEW <= 16` it reads narrower elements but zero-extends to `u64`.
///
/// Elements are stored little-endian; the returned value is zero-extended to 64 bits.
///
/// # Safety
/// - `2*SEW <= 64` (Zve64x constraint: only valid for SEW <= 32; caller must verify)
/// - `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32`
#[inline(always)]
pub unsafe fn read_wide_element_u64<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    base_reg: usize,
    elem_i: u32,
    sew: Vsew,
) -> u64 {
    // Locate the register within the group and the byte offset of element `elem_i`,
    // treating `2*SEW` as the element width.
    let double_sew_bytes = usize::from(sew.bytes()) * 2;
    let elems_per_reg = VLENB / double_sew_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
    // SAFETY: caller guarantees bounds
    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
    // SAFETY: `byte_off + double_sew_bytes <= VLENB`
    let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
    // Zero-extend the little-endian element bytes into a u64.
    let mut buf = [0u8; 8];
    // SAFETY: `double_sew_bytes <= 8` (SEW <= 32 for Zve64x narrowing)
    unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
    u64::from_le_bytes(buf)
}
426
/// Execute a single-width fixed-point arithmetic operation that may set `vxsat`.
///
/// `op` receives `(vs2_elem, src_elem, sew, vxrm, vxsat)` and returns the result element;
/// it reports saturation by setting the `&mut bool` it is handed.
/// The helper ORs any saturation flag into `vxsat` after the loop (`vxsat` is sticky).
///
/// # Safety
/// Same preconditions as `execute_arith_op` in the arithmetic helpers.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_fixed_point_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    // op: (vs2_elem, src_elem, sew, vxrm, vxsat) -> result
    F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
{
    // vxrm is read once up front; the rounding mode is constant for the whole loop.
    let vxrm = ext_state.vxrm();
    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
    // NOTE(review): the mask appears to be snapshotted before the loop so that writes to
    // `vd` cannot perturb masking if `vd` overlaps `v0` — confirm `snapshot_mask` semantics.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    // Saturation is accumulated locally and committed once after the loop.
    let mut any_sat = false;
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // SAFETY: alignment and bounds checked by caller
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let b = match &src {
            OpSrc::Vreg(vs1_base) => {
                // SAFETY: same argument as vs2
                unsafe { read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew) }
            }
            OpSrc::Scalar(val) => *val,
        };
        let result = op(a, b, sew, vxrm, &mut any_sat);
        // SAFETY: alignment and bounds checked by caller
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    if any_sat {
        // vxsat is sticky: OR in the new saturation flag
        ext_state.set_vxsat(true);
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
489
/// Execute a narrowing fixed-point clip operation.
///
/// `vs2` holds a double-width register group (2x `group_regs` registers). `vd` holds the
/// single-width destination. `src` provides the shift amount (Vreg or Scalar).
///
/// For Zve64x narrowing instructions, `SEW` must be at most 32 because `2*SEW` must fit in 64
/// bits. The caller must verify this constraint before invoking this function.
///
/// `op` receives `(vs2_wide_elem, shamt, sew, vxrm, vxsat)` and returns the narrowed result;
/// saturation flags are accumulated and ORed into the sticky `vxsat` after the loop.
///
/// # Safety
/// - `sew.bits() <= 32` (Zve64x ELEN = 64 constraint for narrowing)
/// - `vs2.bits() % (2 * group_regs) == 0` and `vs2.bits() + 2 * group_regs <= 32`
/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32`
/// - `vl <= group_regs * VLENB / sew_bytes`
/// - When `vm=false`: `vd.bits() != 0`
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    // op: (vs2_wide_elem, shamt, sew, vxrm, vxsat) -> result
    F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
{
    // vxrm is read once up front; the rounding mode is constant for the whole loop.
    let vxrm = ext_state.vxrm();
    // SAFETY: `vl <= VLEN`
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    // Saturation is accumulated locally and committed once after the loop.
    let mut any_sat = false;
    // Mask shift amount to log2(2*SEW) bits per spec §12.11
    let shamt_mask = u64::from(sew.bits() * 2 - 1);
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // Read 2*SEW-wide source element
        // SAFETY: `vs2` double-width alignment checked by caller
        let wide_a =
            unsafe { read_wide_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let shamt = match &src {
            OpSrc::Vreg(vs1_base) => {
                // SAFETY: vs1 SEW-wide alignment checked by caller
                let raw = unsafe {
                    read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
                };
                (raw & shamt_mask) as u32
            }
            OpSrc::Scalar(val) => (*val & shamt_mask) as u32,
        };
        let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
        // SAFETY: `vd` alignment checked by caller
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    if any_sat {
        // vxsat is sticky: OR in the new saturation flag
        ext_state.set_vxsat(true);
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
565
566/// Verify that the destination SEW is valid for narrowing (must be at most 32 in Zve64x).
567///
568/// Returns `Err(IllegalInstruction)` when `sew.bits() > 32`.
569#[inline(always)]
570#[doc(hidden)]
571pub fn check_narrowing_sew<Reg, Memory, PC, CustomError>(
572    program_counter: &PC,
573    sew: Vsew,
574) -> Result<(), ExecutionError<Reg::Type, CustomError>>
575where
576    Reg: Register,
577    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
578{
579    if sew.bits() > 32 {
580        return Err(ExecutionError::IllegalInstruction {
581            address: program_counter.old_pc(INSTRUCTION_SIZE),
582        });
583    }
584    Ok(())
585}
586
587/// Check that `vs2` for a narrowing instruction is aligned to `2 * group_regs` and fits in [0,32).
588#[inline(always)]
589#[doc(hidden)]
590pub fn check_vs2_narrowing_alignment<Reg, Memory, PC, CustomError>(
591    program_counter: &PC,
592    vs2: VReg,
593    group_regs: u8,
594) -> Result<(), ExecutionError<Reg::Type, CustomError>>
595where
596    Reg: Register,
597    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
598{
599    // Per v-spec §5.2: narrowing requires EMUL_src = 2*LMUL <= 8.
600    // LMUL=8 with a narrowing instruction is reserved.
601    if group_regs > 4 {
602        return Err(ExecutionError::IllegalInstruction {
603            address: program_counter.old_pc(INSTRUCTION_SIZE),
604        });
605    }
606    let double_group = group_regs * 2;
607    let vs2_idx = vs2.bits();
608    if !vs2_idx.is_multiple_of(double_group) || vs2_idx + double_group > 32 {
609        return Err(ExecutionError::IllegalInstruction {
610            address: program_counter.old_pc(INSTRUCTION_SIZE),
611        });
612    }
613    Ok(())
614}