Skip to main content

ab_riscv_interpreter/v/zvexx/fixed_point/
zvexx_fixed_point_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5    OpSrc, check_vreg_group_alignment, sew_mask,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8    read_element_u64, sign_extend, write_element_u64,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
11use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
12use crate::{ExecutionError, ProgramCounter};
13use ab_riscv_primitives::prelude::*;
14use core::fmt;
15use core::hint::cold_path;
16
17/// Compute the rounding increment for a right shift of `val` by `shift` bits.
18///
19/// When `shift == 0` there are no fractional bits so the increment is always zero.
20/// `current_result_lsb` is the LSB of the truncated result, required for `Rne` and `Rod`.
21#[inline(always)]
22fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
23    if shift == 0 {
24        return 0;
25    }
26    // `d_minus1_bit`: the most-significant discarded bit (bit position `shift - 1`)
27    let d_minus1_bit = (val >> (shift - 1)) & 1;
28    // `sticky`: OR of all bits below position `shift - 1`
29    let sticky = if shift >= 2 {
30        // Any of bits [shift-2 : 0] set?
31        (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
32    } else {
33        false
34    };
35    match mode {
36        // Round nearest up: increment = v[d-1]
37        Vxrm::Rnu => d_minus1_bit,
38        // Round nearest even: increment = v[d-1] & (sticky | result_lsb)
39        Vxrm::Rne => d_minus1_bit & u64::from(sticky || current_result_lsb != 0),
40        // Round down / truncate: never increment
41        Vxrm::Rdn => 0,
42        // Round to odd: set result LSB if any discarded bit was non-zero
43        Vxrm::Rod => u64::from(current_result_lsb == 0 && (d_minus1_bit != 0 || sticky)),
44    }
45}
46
47/// Perform a rounded right shift of `val` by `shift` bits (logical / unsigned).
48///
49/// Returns `(val >> shift) + round_increment`.
50#[inline(always)]
51#[doc(hidden)]
52pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
53    let truncated = val >> shift;
54    let r = round_increment(val, shift, mode, truncated & 1);
55    truncated.wrapping_add(r)
56}
57
58/// Perform a rounded arithmetic right shift of `val` (sign-extended to SEW) by `shift` bits.
59///
60/// Returns the SEW-wide signed result as `u64` (sign bits above SEW are meaningful).
61#[inline(always)]
62#[doc(hidden)]
63pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
64    let signed = sign_extend(val, sew);
65    // Treat the raw bits for rounding purposes: rounding uses the unsigned representation of the
66    // SEW-wide value (only bits below `shift` matter, so masking is not needed here since the
67    // discarded bits are the same regardless of sign extension).
68    let truncated_signed = signed >> shift;
69    let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
70    truncated_signed.cast_unsigned().wrapping_add(r)
71}
72
73/// Saturating unsigned add: `vs2 + src`, clamped to `[0, 2^SEW - 1]`.
74///
75/// Sets `vxsat` to `true` on overflow.
76#[inline(always)]
77#[doc(hidden)]
78pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
79    let mask = sew_mask(sew);
80    let a_w = a & mask;
81    let b_w = b & mask;
82    let result = a_w.wrapping_add(b_w);
83    if result & mask < a_w {
84        // Overflow: wrapped around
85        *vxsat = true;
86        mask
87    } else {
88        result & mask
89    }
90}
91
92/// Saturating signed add: `vs2 + src`, clamped to `[-(2^(SEW-1)), 2^(SEW-1) - 1]`.
93///
94/// Sets `vxsat` to `true` on overflow.
95#[inline(always)]
96#[doc(hidden)]
97pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
98    let sa = i128::from(sign_extend(a, sew));
99    let sb = i128::from(sign_extend(b, sew));
100    let result = sa.wrapping_add(sb);
101    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
102    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
103    if result < min_val {
104        *vxsat = true;
105        (min_val as i64).cast_unsigned() & sew_mask(sew)
106    } else if result > max_val {
107        *vxsat = true;
108        (max_val as i64).cast_unsigned() & sew_mask(sew)
109    } else {
110        (result as i64).cast_unsigned() & sew_mask(sew)
111    }
112}
113
114/// Saturating unsigned subtract: `vs2 - src`, clamped to `[0, 2^SEW - 1]`.
115///
116/// Sets `vxsat` to `true` on overflow (underflow to negative).
117#[inline(always)]
118#[doc(hidden)]
119pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
120    let mask = sew_mask(sew);
121    let a_w = a & mask;
122    let b_w = b & mask;
123    if a_w < b_w {
124        *vxsat = true;
125        0
126    } else {
127        (a_w - b_w) & mask
128    }
129}
130
131/// Saturating signed subtract: `vs2 - src`, clamped to `[-(2^(SEW-1)), 2^(SEW-1) - 1]`.
132///
133/// Sets `vxsat` to `true` on overflow.
134#[inline(always)]
135#[doc(hidden)]
136pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
137    let sa = i128::from(sign_extend(a, sew));
138    let sb = i128::from(sign_extend(b, sew));
139    let result = sa.wrapping_sub(sb);
140    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
141    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
142    if result < min_val {
143        *vxsat = true;
144        (min_val as i64).cast_unsigned() & sew_mask(sew)
145    } else if result > max_val {
146        *vxsat = true;
147        (max_val as i64).cast_unsigned() & sew_mask(sew)
148    } else {
149        (result as i64).cast_unsigned() & sew_mask(sew)
150    }
151}
152
153/// Averaging unsigned add: `(vs2 + src) >> 1` with rounding per `vxrm`.
154///
155/// Uses a 1-bit wider intermediate to avoid overflow; no saturation, no `vxsat`.
156#[inline(always)]
157#[doc(hidden)]
158pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
159    let mask = sew_mask(sew);
160    let a_w = a & mask;
161    let b_w = b & mask;
162    // Compute full sum in one extra bit by using u128 or by widening trick.
163    // Since SEW <= 64 and both operands are SEW-bit values, the sum fits in SEW+1 bits.
164    // Use wrapping_add: the carry out of bit SEW-1 is the extra bit.
165    let sum = a_w.wrapping_add(b_w);
166    // Carry: set if unsigned sum overflowed SEW bits
167    let carry = u64::from(sum & mask < a_w);
168    // Full (SEW+1)-bit value: `carry` is at bit position SEW, `sum & mask` are low SEW bits.
169    // We need `(carry:sum) >> 1` with rounding.
170    // Bit 0 of `sum & mask` is the rounding bit for the truncated division.
171    let r = round_increment(sum & mask, 1, mode, (sum >> 1u8) & 1);
172    // Shift the (SEW+1)-bit quantity right by 1: result = (carry << (SEW-1)) | ((sum & mask) >> 1)
173    let shifted = (carry << (u32::from(sew.bits_width()) - 1)) | ((sum & mask) >> 1u8);
174    (shifted.wrapping_add(r)) & mask
175}
176
177/// Averaging signed add: `(vs2 + src) >> 1` with rounding per `vxrm`.
178///
179/// No saturation, no `vxsat`.
180#[inline(always)]
181#[doc(hidden)]
182pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
183    let sa = sign_extend(a, sew);
184    let sb = sign_extend(b, sew);
185    // Full sum as i128 to avoid overflow
186    let sum = i128::from(sa).wrapping_add(i128::from(sb));
187    // The low bit is the fractional bit for rounding
188    let r = match mode {
189        Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
190        Vxrm::Rne => {
191            // round-to-nearest-even: increment if fractional bit set AND (result LSB or sticky)
192            // For a single bit shift there are no lower sticky bits, so only check result LSB
193            let result_lsb = ((sum >> 1u8) & 1).cast_unsigned() as u64;
194            ((sum & 1).cast_unsigned() as u64) & result_lsb
195        }
196        Vxrm::Rdn => 0,
197        Vxrm::Rod => {
198            // Set result LSB if it would be 0 and the fractional bit is nonzero
199            let result_lsb = (sum >> 1u8) & 1;
200            u64::from(result_lsb == 0 && (sum & 1) != 0)
201        }
202    };
203    let result = (sum >> 1u8) + i128::from(r);
204    (result as i64).cast_unsigned() & sew_mask(sew)
205}
206
207/// Averaging unsigned subtract: `(vs2 - src) >> 1` with rounding per `vxrm`.
208///
209/// No saturation, no `vxsat`.
210#[inline(always)]
211#[doc(hidden)]
212pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
213    let mask = sew_mask(sew);
214    let a_w = a & mask;
215    let b_w = b & mask;
216    // Compute difference with borrow using wrapping sub; borrow extends to SEW+1 bit.
217    let diff = a_w.wrapping_sub(b_w);
218    // Borrow: set if a < b (unsigned)
219    let borrow = u64::from(a_w < b_w);
220    // Full (SEW+1)-bit two's-complement difference:
221    // If borrow: the SEW-bit `diff` is correct (it wrapped), and the sign extension bit is 1.
222    // Rounding: bit 0 of diff is the fractional bit.
223    let r = round_increment(diff & mask, 1, mode, (diff >> 1u8) & 1);
224    // Arithmetic right shift by 1 of the (SEW+1)-bit signed value.
225    // For unsigned averaging subtract: result = ((SEW+1)-bit diff) / 2 with rounding.
226    // The (SEW+1)-bit value is: borrow is the sign bit. If borrow set, value is negative.
227    // Result = (borrow << SEW | diff) >> 1 (arithmetic) + r
228    // Arithmetic shift: sign bit (`borrow`) propagates.
229    let sign_fill = borrow.wrapping_neg(); // all ones if borrow set, zero otherwise
230    let shifted = (sign_fill << (u32::from(sew.bits_width()) - 1)) | ((diff & mask) >> 1u8);
231    (shifted.wrapping_add(r)) & mask
232}
233
234/// Averaging signed subtract: `(vs2 - src) >> 1` with rounding per `vxrm`.
235///
236/// No saturation, no `vxsat`.
237#[inline(always)]
238#[doc(hidden)]
239pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
240    let sa = sign_extend(a, sew);
241    let sb = sign_extend(b, sew);
242    let diff = i128::from(sa).wrapping_sub(i128::from(sb));
243    let r = match mode {
244        Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
245        Vxrm::Rne => {
246            let result_lsb = ((diff >> 1u8) & 1).cast_unsigned() as u64;
247            ((diff & 1).cast_unsigned() as u64) & result_lsb
248        }
249        Vxrm::Rdn => 0,
250        Vxrm::Rod => {
251            let result_lsb = (diff >> 1u8) & 1;
252            u64::from(result_lsb == 0 && (diff & 1) != 0)
253        }
254    };
255    let result = (diff >> 1u8) + i128::from(r);
256    (result as i64).cast_unsigned() & sew_mask(sew)
257}
258
259/// Fractional multiply with rounding and saturation: `vsmul`.
260///
261/// Computes `(a * b * 2 + rounding) >> SEW`, saturating at the signed maximum when the
262/// product of two minimum signed values overflows (`INT_MIN * INT_MIN`).
263///
264/// Per spec §12.4: `vd[i] = clip(roundoff_signed(vs2[i] * vs1[i] * 2, SEW))`.
265/// Sets `vxsat` on overflow.
266#[inline(always)]
267#[doc(hidden)]
268pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
269    // SEW-wide signed min and max in i64 (valid for all SEW <= 64)
270    let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
271    let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
272    let sa = i128::from(sign_extend(a, sew));
273    let sb = i128::from(sign_extend(b, sew));
274    // The only case where `product * 2` overflows a 2*SEW signed result is INT_MIN * INT_MIN.
275    // Detect this before any multiply: for SEW=64 INT64_MIN^2 = 2^126 and <<1 would overflow i128.
276    if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
277        cold_path();
278        *vxsat = true;
279        return max_sew.cast_unsigned() & sew_mask(sew);
280    }
281    // Full 2*SEW-bit product; no overflow possible because at least one operand != INT_MIN,
282    // so |product| < INT_MIN^2 and the value fits in i128 for SEW <= 64.
283    let product = sa * sb;
284    // Left shift by 1 for the Q-format fractional interpretation; safe because
285    // |product| < INT_MIN^2, so after <<1 the result still fits in i128 for SEW <= 64.
286    let doubled = product << 1u8;
287    // Extract the low SEW bits (the discarded portion) for rounding.
288    // Cast to u128 first to avoid sign-extension contaminating the mask.
289    let shift = u32::from(sew.bits_width());
290    let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
291    // Arithmetic right shift by SEW gives the truncated signed result in SEW-wide range.
292    let truncated = doubled >> shift;
293    let r = round_increment(
294        low_bits,
295        shift.min(64),
296        mode,
297        (truncated.cast_unsigned() as u64) & 1,
298    );
299    // `truncated` fits in i64 after the SEW-bit shift (it is a SEW-wide signed value).
300    let result = (truncated as i64).wrapping_add(r.cast_signed());
301    // Clamp to SEW-wide signed range (only reachable if rounding pushed the value over)
302    if result < min_sew {
303        *vxsat = true;
304        min_sew.cast_unsigned() & sew_mask(sew)
305    } else if result > max_sew {
306        *vxsat = true;
307        max_sew.cast_unsigned() & sew_mask(sew)
308    } else {
309        result.cast_unsigned() & sew_mask(sew)
310    }
311}
312
313/// Narrowing unsigned clip: read a 2*SEW element from `vs2`, shift right by `shamt` with
314/// rounding, saturate to unsigned SEW range, set `vxsat` on clamp.
315///
316/// `vs2_elem` is the 2*SEW-bit element (zero-extended to u64 for SEW <= 32;
317/// for SEW = 64 the doubled width would be 128 bits, but Zve64x only supports SEW up to 64 and
318/// the narrowing destination is at most 64 bits wide, so 2*SEW = 128 - however the spec requires
319/// `ELEN >= 2*SEW` for narrowing instructions. Since `ELEN = 64` in Zve64x, narrowing is only
320/// valid for SEW <= 32 (`2*SEW <= 64`).  The caller must enforce this constraint by checking
321/// `vsew` before invoking narrowing operations.
322///
323/// `vs2_elem` is passed as `u64`; for SEW = 32 it holds a 64-bit (2*SEW) value.
324#[inline(always)]
325#[doc(hidden)]
326pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
327    // Shift right with rounding
328    let shifted = rounded_srl(vs2_elem, shamt, mode);
329    // Saturate to destination SEW unsigned range [0, 2^SEW - 1]
330    let max_dst = sew_mask(sew);
331    if shifted > max_dst {
332        *vxsat = true;
333        max_dst
334    } else {
335        shifted & max_dst
336    }
337}
338
339/// Narrowing signed clip: read a 2*SEW signed element from `vs2`, shift right arithmetically
340/// with rounding, saturate to signed SEW range.
341///
342/// Same SEW constraint as [`nclipu`].
343#[inline(always)]
344#[doc(hidden)]
345pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
346    // Sign-extend vs2_elem to full i64 treating it as a 2*SEW-bit signed value.
347    // For SEW=8 the source is 16-bit, for SEW=16 it is 32-bit, for SEW=32 it is 64-bit.
348    // TODO: Use `sew.double_width()`
349    let double_sew_bits = sew.bits_width() * 2;
350    let shift_amt = i64::BITS - u32::from(double_sew_bits);
351    let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
352    // Arithmetic right shift with rounding
353    // For rounding we need the raw low bits of the wide value before shifting
354    let low_bits = signed_wide.cast_unsigned()
355        & if double_sew_bits == 64 {
356            u64::MAX
357        } else {
358            (1u64 << double_sew_bits) - 1
359        };
360    let truncated = signed_wide >> shamt;
361    let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
362    let rounded = truncated.wrapping_add(r.cast_signed());
363    // Saturate to signed SEW range
364    let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
365    let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
366    if rounded < min_dst {
367        *vxsat = true;
368        min_dst.cast_unsigned() & sew_mask(sew)
369    } else if rounded > max_dst {
370        *vxsat = true;
371        max_dst.cast_unsigned() & sew_mask(sew)
372    } else {
373        rounded.cast_unsigned() & sew_mask(sew)
374    }
375}
376
377/// Read a 2*SEW-wide element as `u64` from the double-width source register group of a narrowing
378/// instruction.
379///
380/// For narrowing instructions `vs2` holds elements of width `2*SEW`. The register group size is
381/// `2 * group_regs`. Element `i` of width `2*SEW` is located in the same way as a SEW-wide
382/// element of width `2*SEW` (i.e., treating `2*SEW` as the element width). For `SEW = 32` this
383/// reads 64-bit elements; for `SEW <= 16` it reads narrower elements but zero-extends to `u64`.
384///
385/// # Safety
386/// - `2*SEW <= 64` (Zve64x constraint: only valid for SEW <= 32; caller must verify)
387/// - `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32`
388#[inline(always)]
389pub unsafe fn read_wide_element_u64<const VLENB: usize>(
390    vregs: &VectorRegisterFile<VLENB>,
391    base_reg: VReg,
392    elem_i: u32,
393    sew: Vsew,
394) -> u64 {
395    let double_sew_bytes = usize::from(sew.bytes_width()) * 2;
396    let elems_per_reg = VLENB / double_sew_bytes;
397    let reg_off = elem_i as usize / elems_per_reg;
398    let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
399    // SAFETY: caller guarantees bounds
400    let reg = unsafe {
401        vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
402    };
403    // SAFETY: `byte_off + double_sew_bytes <= VLENB`
404    let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
405    let mut buf = [0u8; 8];
406    // SAFETY: `double_sew_bytes <= 8` (SEW <= 32 for Zve64x narrowing)
407    unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
408    u64::from_le_bytes(buf)
409}
410
411/// Execute a single-width fixed-point arithmetic operation that may set `vxsat`.
412///
413/// `op` receives `(vs2_elem, src_elem, sew, vxrm)` and returns `(result, saturated)`.
414/// The helper ORs any saturation flag into `vxsat` after the loop.
415///
416/// # Safety
417/// Same preconditions as `execute_arith_op` in the arithmetic helpers.
418#[inline(always)]
419#[doc(hidden)]
420pub unsafe fn execute_fixed_point_op<Reg, ExtState, CustomError, F>(
421    ext_state: &mut ExtState,
422    vd: VReg,
423    vs2: VReg,
424    src: OpSrc,
425    vm: bool,
426    sew: Vsew,
427    op: F,
428) where
429    Reg: Register,
430    ExtState: VectorRegistersExt<Reg, CustomError>,
431    [(); ExtState::ELEN as usize]:,
432    [(); ExtState::VLEN as usize]:,
433    [(); ExtState::VLENB as usize]:,
434    CustomError: fmt::Debug,
435    // op: (vs2_elem, src_elem, sew, vxrm) -> result
436    F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
437{
438    let vl = ext_state.vl();
439    let vstart = ext_state.vstart();
440    let vxrm = ext_state.vxrm();
441    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
442    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
443    let mut any_sat = false;
444    for i in u32::from(vstart)..vl {
445        if !mask_bit(&mask_buf, i) {
446            continue;
447        }
448        // SAFETY: alignment and bounds checked by caller
449        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
450        let b = match src {
451            OpSrc::Vreg(vs1_base) => {
452                // SAFETY: same argument as vs2
453                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
454            }
455            OpSrc::Scalar(val) => val,
456        };
457        let result = op(a, b, sew, vxrm, &mut any_sat);
458        // SAFETY: alignment and bounds checked by caller
459        unsafe {
460            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
461        }
462    }
463    if any_sat {
464        // vxsat is sticky: OR in the new saturation flag
465        ext_state.set_vxsat(true);
466    }
467    ext_state.mark_vs_dirty();
468    ext_state.reset_vstart();
469}
470
471/// Execute a narrowing fixed-point clip operation.
472///
473/// `vs2` holds a double-width register group (2x `group_regs` registers). `vd` holds the
474/// single-width destination. `src` provides the shift amount (Vreg or Scalar).
475///
476/// For Zve64x narrowing instructions, `SEW` must be at most 32 because `2*SEW` must fit in 64
477/// bits. The caller must verify this constraint before invoking this function.
478///
479/// # Safety
480/// - `sew.bits_width() <= 32` (Zve64x ELEN = 64 constraint for narrowing)
481/// - `vs2.to_bits() % (2 * group_regs) == 0` and `vs2.to_bits() + 2 * group_regs <= 32`
482/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
483/// - `vl <= group_regs * VLENB / sew_bytes`
484/// - When `vm=false`: `vd.to_bits() != 0`
485#[inline(always)]
486#[doc(hidden)]
487pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, CustomError, F>(
488    ext_state: &mut ExtState,
489    vd: VReg,
490    vs2: VReg,
491    src: OpSrc,
492    vm: bool,
493    sew: Vsew,
494    op: F,
495) where
496    Reg: Register,
497    ExtState: VectorRegistersExt<Reg, CustomError>,
498    [(); ExtState::ELEN as usize]:,
499    [(); ExtState::VLEN as usize]:,
500    [(); ExtState::VLENB as usize]:,
501    CustomError: fmt::Debug,
502    // op: (vs2_wide_elem, shamt, sew, vxrm, vxsat) -> result
503    F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
504{
505    let vl = ext_state.vl();
506    let vstart = ext_state.vstart();
507    let vxrm = ext_state.vxrm();
508    // SAFETY: `vl <= VLEN`
509    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
510    let mut any_sat = false;
511    // Mask shift amount to log2(2*SEW) bits per spec §12.11
512    let shamt_mask = u64::from(sew.bits_width() * 2 - 1);
513    for i in u32::from(vstart)..vl {
514        if !mask_bit(&mask_buf, i) {
515            continue;
516        }
517        // Read 2*SEW-wide source element
518        // SAFETY: `vs2` double-width alignment checked by caller
519        let wide_a = unsafe { read_wide_element_u64(ext_state.read_vregs(), vs2, i, sew) };
520        let shamt = match src {
521            OpSrc::Vreg(vs1_base) => {
522                // SAFETY: vs1 SEW-wide alignment checked by caller
523                let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
524                (raw & shamt_mask) as u32
525            }
526            OpSrc::Scalar(val) => (val & shamt_mask) as u32,
527        };
528        let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
529        // SAFETY: `vd` alignment checked by caller
530        unsafe {
531            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
532        }
533    }
534    if any_sat {
535        ext_state.set_vxsat(true);
536    }
537    ext_state.mark_vs_dirty();
538    ext_state.reset_vstart();
539}
540
541/// Verify that the destination SEW is valid for narrowing (must be at most 32 in Zve64x).
542///
543/// Returns `Err(IllegalInstruction)` when `sew.bits_width() > 32`.
544#[inline(always)]
545#[doc(hidden)]
546pub fn check_narrowing_sew<Reg, Memory, PC, CustomError>(
547    program_counter: &PC,
548    sew: Vsew,
549) -> Result<(), ExecutionError<Reg::Type, CustomError>>
550where
551    Reg: Register,
552    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
553{
554    if sew.bits_width() > 32 {
555        cold_path();
556        return Err(ExecutionError::IllegalInstruction {
557            address: program_counter.old_pc(INSTRUCTION_SIZE),
558        });
559    }
560    Ok(())
561}
562
563/// Check that the double-width source `vs2` of a narrowing instruction is aligned to its register
564/// group and fits in `[0, 32)`.
565///
566/// The source operand has `EEW = 2*SEW`, so its `EMUL = 2*LMUL`. Per v-spec §5.2 the group must be
567/// aligned to `EMUL` registers, and `EMUL` outside the legal range `[1/8, 8]` (e.g. `LMUL=8`, which
568/// would need `EMUL=16`) is reserved. Unlike `2 * register_count()`, this correctly yields a single
569/// register with no alignment constraint for fractional `LMUL` (where `2*LMUL <= 1`).
570///
571/// `sew` is the destination (narrow) SEW; it must be at most 32 (see [`check_narrowing_sew()`]).
572#[inline(always)]
573#[doc(hidden)]
574pub fn check_vs2_narrowing_alignment<Reg, Memory, PC, CustomError>(
575    program_counter: &PC,
576    vs2: VReg,
577    vlmul: Vlmul,
578    sew: Vsew,
579) -> Result<(), ExecutionError<Reg::Type, CustomError>>
580where
581    Reg: Register,
582    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
583{
584    // Source EEW is double the destination SEW. SEW=64 is rejected earlier by
585    // `check_narrowing_sew`.
586    let wide_eew = match sew {
587        Vsew::E8 => Eew::E16,
588        Vsew::E16 => Eew::E32,
589        Vsew::E32 => Eew::E64,
590        Vsew::E64 => {
591            cold_path();
592            return Err(ExecutionError::IllegalInstruction {
593                address: program_counter.old_pc(INSTRUCTION_SIZE),
594            });
595        }
596    };
597    // `EMUL = 2*LMUL`; `None` when reserved (e.g. LMUL=8 -> EMUL=16).
598    let Some(wide_group) = vlmul.data_register_count(wide_eew, sew) else {
599        cold_path();
600        return Err(ExecutionError::IllegalInstruction {
601            address: program_counter.old_pc(INSTRUCTION_SIZE),
602        });
603    };
604    let vs2_idx = vs2.to_bits();
605    if !vs2_idx.is_multiple_of(wide_group) || vs2_idx + wide_group > 32 {
606        cold_path();
607        return Err(ExecutionError::IllegalInstruction {
608            address: program_counter.old_pc(INSTRUCTION_SIZE),
609        });
610    }
611    Ok(())
612}