Skip to main content

ab_riscv_interpreter/v/zvexx/fixed_point/
zvexx_fixed_point_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5    OpSrc, check_vreg_group_alignment, sew_mask,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8    read_element_u64, sign_extend, write_element_u64,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
11use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
12use crate::{ExecutionError, ProgramCounter};
13use ab_riscv_primitives::prelude::*;
14use core::fmt;
15
16/// Compute the rounding increment for a right shift of `val` by `shift` bits.
17///
18/// When `shift == 0` there are no fractional bits so the increment is always zero.
19/// `current_result_lsb` is the LSB of the truncated result, required for `Rne` and `Rod`.
20#[inline(always)]
21fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
22    if shift == 0 {
23        return 0;
24    }
25    // `d_minus1_bit`: the most-significant discarded bit (bit position `shift - 1`)
26    let d_minus1_bit = (val >> (shift - 1)) & 1;
27    // `sticky`: OR of all bits below position `shift - 1`
28    let sticky = if shift >= 2 {
29        // Any of bits [shift-2 : 0] set?
30        (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
31    } else {
32        false
33    };
34    match mode {
35        // Round nearest up: increment = v[d-1]
36        Vxrm::Rnu => d_minus1_bit,
37        // Round nearest even: increment = v[d-1] & (sticky | result_lsb)
38        Vxrm::Rne => d_minus1_bit & u64::from(sticky || current_result_lsb != 0),
39        // Round down / truncate: never increment
40        Vxrm::Rdn => 0,
41        // Round to odd: set result LSB if any discarded bit was non-zero
42        Vxrm::Rod => u64::from(current_result_lsb == 0 && (d_minus1_bit != 0 || sticky)),
43    }
44}
45
46/// Perform a rounded right shift of `val` by `shift` bits (logical / unsigned).
47///
48/// Returns `(val >> shift) + round_increment`.
49#[inline(always)]
50#[doc(hidden)]
51pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
52    let truncated = val >> shift;
53    let r = round_increment(val, shift, mode, truncated & 1);
54    truncated.wrapping_add(r)
55}
56
57/// Perform a rounded arithmetic right shift of `val` (sign-extended to SEW) by `shift` bits.
58///
59/// Returns the SEW-wide signed result as `u64` (sign bits above SEW are meaningful).
60#[inline(always)]
61#[doc(hidden)]
62pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
63    let signed = sign_extend(val, sew);
64    // Treat the raw bits for rounding purposes: rounding uses the unsigned representation of the
65    // SEW-wide value (only bits below `shift` matter, so masking is not needed here since the
66    // discarded bits are the same regardless of sign extension).
67    let truncated_signed = signed >> shift;
68    let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
69    truncated_signed.cast_unsigned().wrapping_add(r)
70}
71
72/// Saturating unsigned add: `vs2 + src`, clamped to `[0, 2^SEW - 1]`.
73///
74/// Sets `vxsat` to `true` on overflow.
75#[inline(always)]
76#[doc(hidden)]
77pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
78    let mask = sew_mask(sew);
79    let a_w = a & mask;
80    let b_w = b & mask;
81    let result = a_w.wrapping_add(b_w);
82    if result & mask < a_w {
83        // Overflow: wrapped around
84        *vxsat = true;
85        mask
86    } else {
87        result & mask
88    }
89}
90
91/// Saturating signed add: `vs2 + src`, clamped to `[-(2^(SEW-1)), 2^(SEW-1) - 1]`.
92///
93/// Sets `vxsat` to `true` on overflow.
94#[inline(always)]
95#[doc(hidden)]
96pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
97    let sa = i128::from(sign_extend(a, sew));
98    let sb = i128::from(sign_extend(b, sew));
99    let result = sa.wrapping_add(sb);
100    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
101    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
102    if result < min_val {
103        *vxsat = true;
104        (min_val as i64).cast_unsigned() & sew_mask(sew)
105    } else if result > max_val {
106        *vxsat = true;
107        (max_val as i64).cast_unsigned() & sew_mask(sew)
108    } else {
109        (result as i64).cast_unsigned() & sew_mask(sew)
110    }
111}
112
113/// Saturating unsigned subtract: `vs2 - src`, clamped to `[0, 2^SEW - 1]`.
114///
115/// Sets `vxsat` to `true` on overflow (underflow to negative).
116#[inline(always)]
117#[doc(hidden)]
118pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
119    let mask = sew_mask(sew);
120    let a_w = a & mask;
121    let b_w = b & mask;
122    if a_w < b_w {
123        *vxsat = true;
124        0
125    } else {
126        (a_w - b_w) & mask
127    }
128}
129
130/// Saturating signed subtract: `vs2 - src`, clamped to `[-(2^(SEW-1)), 2^(SEW-1) - 1]`.
131///
132/// Sets `vxsat` to `true` on overflow.
133#[inline(always)]
134#[doc(hidden)]
135pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
136    let sa = i128::from(sign_extend(a, sew));
137    let sb = i128::from(sign_extend(b, sew));
138    let result = sa.wrapping_sub(sb);
139    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
140    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
141    if result < min_val {
142        *vxsat = true;
143        (min_val as i64).cast_unsigned() & sew_mask(sew)
144    } else if result > max_val {
145        *vxsat = true;
146        (max_val as i64).cast_unsigned() & sew_mask(sew)
147    } else {
148        (result as i64).cast_unsigned() & sew_mask(sew)
149    }
150}
151
152/// Averaging unsigned add: `(vs2 + src) >> 1` with rounding per `vxrm`.
153///
154/// Uses a 1-bit wider intermediate to avoid overflow; no saturation, no `vxsat`.
155#[inline(always)]
156#[doc(hidden)]
157pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
158    let mask = sew_mask(sew);
159    let a_w = a & mask;
160    let b_w = b & mask;
161    // Compute full sum in one extra bit by using u128 or by widening trick.
162    // Since SEW <= 64 and both operands are SEW-bit values, the sum fits in SEW+1 bits.
163    // Use wrapping_add: the carry out of bit SEW-1 is the extra bit.
164    let sum = a_w.wrapping_add(b_w);
165    // Carry: set if unsigned sum overflowed SEW bits
166    let carry = u64::from(sum & mask < a_w);
167    // Full (SEW+1)-bit value: `carry` is at bit position SEW, `sum & mask` are low SEW bits.
168    // We need `(carry:sum) >> 1` with rounding.
169    // Bit 0 of `sum & mask` is the rounding bit for the truncated division.
170    let r = round_increment(sum & mask, 1, mode, (sum >> 1u8) & 1);
171    // Shift the (SEW+1)-bit quantity right by 1: result = (carry << (SEW-1)) | ((sum & mask) >> 1)
172    let shifted = (carry << (u32::from(sew.bits_width()) - 1)) | ((sum & mask) >> 1u8);
173    (shifted.wrapping_add(r)) & mask
174}
175
176/// Averaging signed add: `(vs2 + src) >> 1` with rounding per `vxrm`.
177///
178/// No saturation, no `vxsat`.
179#[inline(always)]
180#[doc(hidden)]
181pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
182    let sa = sign_extend(a, sew);
183    let sb = sign_extend(b, sew);
184    // Full sum as i128 to avoid overflow
185    let sum = i128::from(sa).wrapping_add(i128::from(sb));
186    // The low bit is the fractional bit for rounding
187    let r = match mode {
188        Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
189        Vxrm::Rne => {
190            // round-to-nearest-even: increment if fractional bit set AND (result LSB or sticky)
191            // For a single bit shift there are no lower sticky bits, so only check result LSB
192            let result_lsb = ((sum >> 1u8) & 1).cast_unsigned() as u64;
193            ((sum & 1).cast_unsigned() as u64) & result_lsb
194        }
195        Vxrm::Rdn => 0,
196        Vxrm::Rod => {
197            // Set result LSB if it would be 0 and the fractional bit is nonzero
198            let result_lsb = (sum >> 1u8) & 1;
199            u64::from(result_lsb == 0 && (sum & 1) != 0)
200        }
201    };
202    let result = (sum >> 1u8) + i128::from(r);
203    (result as i64).cast_unsigned() & sew_mask(sew)
204}
205
206/// Averaging unsigned subtract: `(vs2 - src) >> 1` with rounding per `vxrm`.
207///
208/// No saturation, no `vxsat`.
209#[inline(always)]
210#[doc(hidden)]
211pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
212    let mask = sew_mask(sew);
213    let a_w = a & mask;
214    let b_w = b & mask;
215    // Compute difference with borrow using wrapping sub; borrow extends to SEW+1 bit.
216    let diff = a_w.wrapping_sub(b_w);
217    // Borrow: set if a < b (unsigned)
218    let borrow = u64::from(a_w < b_w);
219    // Full (SEW+1)-bit two's-complement difference:
220    // If borrow: the SEW-bit `diff` is correct (it wrapped), and the sign extension bit is 1.
221    // Rounding: bit 0 of diff is the fractional bit.
222    let r = round_increment(diff & mask, 1, mode, (diff >> 1u8) & 1);
223    // Arithmetic right shift by 1 of the (SEW+1)-bit signed value.
224    // For unsigned averaging subtract: result = ((SEW+1)-bit diff) / 2 with rounding.
225    // The (SEW+1)-bit value is: borrow is the sign bit. If borrow set, value is negative.
226    // Result = (borrow << SEW | diff) >> 1 (arithmetic) + r
227    // Arithmetic shift: sign bit (`borrow`) propagates.
228    let sign_fill = borrow.wrapping_neg(); // all ones if borrow set, zero otherwise
229    let shifted = (sign_fill << (u32::from(sew.bits_width()) - 1)) | ((diff & mask) >> 1u8);
230    (shifted.wrapping_add(r)) & mask
231}
232
233/// Averaging signed subtract: `(vs2 - src) >> 1` with rounding per `vxrm`.
234///
235/// No saturation, no `vxsat`.
236#[inline(always)]
237#[doc(hidden)]
238pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
239    let sa = sign_extend(a, sew);
240    let sb = sign_extend(b, sew);
241    let diff = i128::from(sa).wrapping_sub(i128::from(sb));
242    let r = match mode {
243        Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
244        Vxrm::Rne => {
245            let result_lsb = ((diff >> 1u8) & 1).cast_unsigned() as u64;
246            ((diff & 1).cast_unsigned() as u64) & result_lsb
247        }
248        Vxrm::Rdn => 0,
249        Vxrm::Rod => {
250            let result_lsb = (diff >> 1u8) & 1;
251            u64::from(result_lsb == 0 && (diff & 1) != 0)
252        }
253    };
254    let result = (diff >> 1u8) + i128::from(r);
255    (result as i64).cast_unsigned() & sew_mask(sew)
256}
257
258/// Fractional multiply with rounding and saturation: `vsmul`.
259///
260/// Computes `(a * b * 2 + rounding) >> SEW`, saturating at the signed maximum when the
261/// product of two minimum signed values overflows (`INT_MIN * INT_MIN`).
262///
263/// Per spec §12.4: `vd[i] = clip(roundoff_signed(vs2[i] * vs1[i] * 2, SEW))`.
264/// Sets `vxsat` on overflow.
265#[inline(always)]
266#[doc(hidden)]
267pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
268    // SEW-wide signed min and max in i64 (valid for all SEW <= 64)
269    let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
270    let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
271    let sa = i128::from(sign_extend(a, sew));
272    let sb = i128::from(sign_extend(b, sew));
273    // The only case where `product * 2` overflows a 2*SEW signed result is INT_MIN * INT_MIN.
274    // Detect this before any multiply: for SEW=64 INT64_MIN^2 = 2^126 and <<1 would overflow i128.
275    if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
276        *vxsat = true;
277        return max_sew.cast_unsigned() & sew_mask(sew);
278    }
279    // Full 2*SEW-bit product; no overflow possible because at least one operand != INT_MIN,
280    // so |product| < INT_MIN^2 and the value fits in i128 for SEW <= 64.
281    let product = sa * sb;
282    // Left shift by 1 for the Q-format fractional interpretation; safe because
283    // |product| < INT_MIN^2, so after <<1 the result still fits in i128 for SEW <= 64.
284    let doubled = product << 1u8;
285    // Extract the low SEW bits (the discarded portion) for rounding.
286    // Cast to u128 first to avoid sign-extension contaminating the mask.
287    let shift = u32::from(sew.bits_width());
288    let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
289    // Arithmetic right shift by SEW gives the truncated signed result in SEW-wide range.
290    let truncated = doubled >> shift;
291    let r = round_increment(
292        low_bits,
293        shift.min(64),
294        mode,
295        (truncated.cast_unsigned() as u64) & 1,
296    );
297    // `truncated` fits in i64 after the SEW-bit shift (it is a SEW-wide signed value).
298    let result = (truncated as i64).wrapping_add(r.cast_signed());
299    // Clamp to SEW-wide signed range (only reachable if rounding pushed the value over)
300    if result < min_sew {
301        *vxsat = true;
302        min_sew.cast_unsigned() & sew_mask(sew)
303    } else if result > max_sew {
304        *vxsat = true;
305        max_sew.cast_unsigned() & sew_mask(sew)
306    } else {
307        result.cast_unsigned() & sew_mask(sew)
308    }
309}
310
311/// Narrowing unsigned clip: read a 2*SEW element from `vs2`, shift right by `shamt` with
312/// rounding, saturate to unsigned SEW range, set `vxsat` on clamp.
313///
314/// `vs2_elem` is the 2*SEW-bit element (zero-extended to u64 for SEW <= 32;
315/// for SEW = 64 the doubled width would be 128 bits, but Zve64x only supports SEW up to 64 and
316/// the narrowing destination is at most 64 bits wide, so 2*SEW = 128 - however the spec requires
317/// `ELEN >= 2*SEW` for narrowing instructions. Since `ELEN = 64` in Zve64x, narrowing is only
318/// valid for SEW <= 32 (`2*SEW <= 64`).  The caller must enforce this constraint by checking
319/// `vsew` before invoking narrowing operations.
320///
321/// `vs2_elem` is passed as `u64`; for SEW = 32 it holds a 64-bit (2*SEW) value.
322#[inline(always)]
323#[doc(hidden)]
324pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
325    // Shift right with rounding
326    let shifted = rounded_srl(vs2_elem, shamt, mode);
327    // Saturate to destination SEW unsigned range [0, 2^SEW - 1]
328    let max_dst = sew_mask(sew);
329    if shifted > max_dst {
330        *vxsat = true;
331        max_dst
332    } else {
333        shifted & max_dst
334    }
335}
336
337/// Narrowing signed clip: read a 2*SEW signed element from `vs2`, shift right arithmetically
338/// with rounding, saturate to signed SEW range.
339///
340/// Same SEW constraint as [`nclipu`].
341#[inline(always)]
342#[doc(hidden)]
343pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
344    // Sign-extend vs2_elem to full i64 treating it as a 2*SEW-bit signed value.
345    // For SEW=8 the source is 16-bit, for SEW=16 it is 32-bit, for SEW=32 it is 64-bit.
346    let double_sew_bits = sew.bits_width() * 2;
347    let shift_amt = i64::BITS - u32::from(double_sew_bits);
348    let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
349    // Arithmetic right shift with rounding
350    // For rounding we need the raw low bits of the wide value before shifting
351    let low_bits = signed_wide.cast_unsigned()
352        & if double_sew_bits == 64 {
353            u64::MAX
354        } else {
355            (1u64 << double_sew_bits) - 1
356        };
357    let truncated = signed_wide >> shamt;
358    let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
359    let rounded = truncated.wrapping_add(r.cast_signed());
360    // Saturate to signed SEW range
361    let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
362    let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
363    if rounded < min_dst {
364        *vxsat = true;
365        min_dst.cast_unsigned() & sew_mask(sew)
366    } else if rounded > max_dst {
367        *vxsat = true;
368        max_dst.cast_unsigned() & sew_mask(sew)
369    } else {
370        rounded.cast_unsigned() & sew_mask(sew)
371    }
372}
373
374/// Read a 2*SEW-wide element as `u64` from the double-width source register group of a narrowing
375/// instruction.
376///
377/// For narrowing instructions `vs2` holds elements of width `2*SEW`. The register group size is
378/// `2 * group_regs`. Element `i` of width `2*SEW` is located in the same way as a SEW-wide
379/// element of width `2*SEW` (i.e., treating `2*SEW` as the element width). For `SEW = 32` this
380/// reads 64-bit elements; for `SEW <= 16` it reads narrower elements but zero-extends to `u64`.
381///
382/// # Safety
383/// - `2*SEW <= 64` (Zve64x constraint: only valid for SEW <= 32; caller must verify)
384/// - `base_reg + elem_i / (VLENB / (2*sew_bytes)) < 32`
385#[inline(always)]
386pub unsafe fn read_wide_element_u64<const VLENB: usize>(
387    vregs: &VectorRegisterFile<VLENB>,
388    base_reg: VReg,
389    elem_i: u32,
390    sew: Vsew,
391) -> u64 {
392    let double_sew_bytes = usize::from(sew.bytes_width()) * 2;
393    let elems_per_reg = VLENB / double_sew_bytes;
394    let reg_off = elem_i as usize / elems_per_reg;
395    let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
396    // SAFETY: caller guarantees bounds
397    let reg = unsafe {
398        vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
399    };
400    // SAFETY: `byte_off + double_sew_bytes <= VLENB`
401    let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
402    let mut buf = [0u8; 8];
403    // SAFETY: `double_sew_bytes <= 8` (SEW <= 32 for Zve64x narrowing)
404    unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
405    u64::from_le_bytes(buf)
406}
407
408/// Execute a single-width fixed-point arithmetic operation that may set `vxsat`.
409///
410/// `op` receives `(vs2_elem, src_elem, sew, vxrm)` and returns `(result, saturated)`.
411/// The helper ORs any saturation flag into `vxsat` after the loop.
412///
413/// # Safety
414/// Same preconditions as `execute_arith_op` in the arithmetic helpers.
415#[inline(always)]
416#[doc(hidden)]
417pub unsafe fn execute_fixed_point_op<Reg, ExtState, CustomError, F>(
418    ext_state: &mut ExtState,
419    vd: VReg,
420    vs2: VReg,
421    src: OpSrc,
422    vm: bool,
423    sew: Vsew,
424    op: F,
425) where
426    Reg: Register,
427    ExtState: VectorRegistersExt<Reg, CustomError>,
428    [(); ExtState::ELEN as usize]:,
429    [(); ExtState::VLEN as usize]:,
430    [(); ExtState::VLENB as usize]:,
431    CustomError: fmt::Debug,
432    // op: (vs2_elem, src_elem, sew, vxrm) -> result
433    F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
434{
435    let vl = ext_state.vl();
436    let vstart = ext_state.vstart();
437    let vxrm = ext_state.vxrm();
438    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`
439    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
440    let mut any_sat = false;
441    for i in u32::from(vstart)..vl {
442        if !mask_bit(&mask_buf, i) {
443            continue;
444        }
445        // SAFETY: alignment and bounds checked by caller
446        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
447        let b = match src {
448            OpSrc::Vreg(vs1_base) => {
449                // SAFETY: same argument as vs2
450                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
451            }
452            OpSrc::Scalar(val) => val,
453        };
454        let result = op(a, b, sew, vxrm, &mut any_sat);
455        // SAFETY: alignment and bounds checked by caller
456        unsafe {
457            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
458        }
459    }
460    if any_sat {
461        // vxsat is sticky: OR in the new saturation flag
462        ext_state.set_vxsat(true);
463    }
464    ext_state.mark_vs_dirty();
465    ext_state.reset_vstart();
466}
467
468/// Execute a narrowing fixed-point clip operation.
469///
470/// `vs2` holds a double-width register group (2x `group_regs` registers). `vd` holds the
471/// single-width destination. `src` provides the shift amount (Vreg or Scalar).
472///
473/// For Zve64x narrowing instructions, `SEW` must be at most 32 because `2*SEW` must fit in 64
474/// bits. The caller must verify this constraint before invoking this function.
475///
476/// # Safety
477/// - `sew.bits_width() <= 32` (Zve64x ELEN = 64 constraint for narrowing)
478/// - `vs2.to_bits() % (2 * group_regs) == 0` and `vs2.to_bits() + 2 * group_regs <= 32`
479/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
480/// - `vl <= group_regs * VLENB / sew_bytes`
481/// - When `vm=false`: `vd.to_bits() != 0`
482#[inline(always)]
483#[doc(hidden)]
484pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, CustomError, F>(
485    ext_state: &mut ExtState,
486    vd: VReg,
487    vs2: VReg,
488    src: OpSrc,
489    vm: bool,
490    sew: Vsew,
491    op: F,
492) where
493    Reg: Register,
494    ExtState: VectorRegistersExt<Reg, CustomError>,
495    [(); ExtState::ELEN as usize]:,
496    [(); ExtState::VLEN as usize]:,
497    [(); ExtState::VLENB as usize]:,
498    CustomError: fmt::Debug,
499    // op: (vs2_wide_elem, shamt, sew, vxrm, vxsat) -> result
500    F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
501{
502    let vl = ext_state.vl();
503    let vstart = ext_state.vstart();
504    let vxrm = ext_state.vxrm();
505    // SAFETY: `vl <= VLEN`
506    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
507    let mut any_sat = false;
508    // Mask shift amount to log2(2*SEW) bits per spec §12.11
509    let shamt_mask = u64::from(sew.bits_width() * 2 - 1);
510    for i in u32::from(vstart)..vl {
511        if !mask_bit(&mask_buf, i) {
512            continue;
513        }
514        // Read 2*SEW-wide source element
515        // SAFETY: `vs2` double-width alignment checked by caller
516        let wide_a = unsafe { read_wide_element_u64(ext_state.read_vregs(), vs2, i, sew) };
517        let shamt = match src {
518            OpSrc::Vreg(vs1_base) => {
519                // SAFETY: vs1 SEW-wide alignment checked by caller
520                let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
521                (raw & shamt_mask) as u32
522            }
523            OpSrc::Scalar(val) => (val & shamt_mask) as u32,
524        };
525        let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
526        // SAFETY: `vd` alignment checked by caller
527        unsafe {
528            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
529        }
530    }
531    if any_sat {
532        ext_state.set_vxsat(true);
533    }
534    ext_state.mark_vs_dirty();
535    ext_state.reset_vstart();
536}
537
538/// Verify that the destination SEW is valid for narrowing (must be at most 32 in Zve64x).
539///
540/// Returns `Err(IllegalInstruction)` when `sew.bits_width() > 32`.
541#[inline(always)]
542#[doc(hidden)]
543pub fn check_narrowing_sew<Reg, Memory, PC, CustomError>(
544    program_counter: &PC,
545    sew: Vsew,
546) -> Result<(), ExecutionError<Reg::Type, CustomError>>
547where
548    Reg: Register,
549    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
550{
551    if sew.bits_width() > 32 {
552        return Err(ExecutionError::IllegalInstruction {
553            address: program_counter.old_pc(INSTRUCTION_SIZE),
554        });
555    }
556    Ok(())
557}
558
559/// Check that the double-width source `vs2` of a narrowing instruction is aligned to its register
560/// group and fits in `[0, 32)`.
561///
562/// The source operand has `EEW = 2*SEW`, so its `EMUL = 2*LMUL`. Per v-spec §5.2 the group must be
563/// aligned to `EMUL` registers, and `EMUL` outside the legal range `[1/8, 8]` (e.g. `LMUL=8`, which
564/// would need `EMUL=16`) is reserved. Unlike `2 * register_count()`, this correctly yields a single
565/// register with no alignment constraint for fractional `LMUL` (where `2*LMUL <= 1`).
566///
567/// `sew` is the destination (narrow) SEW; it must be at most 32 (see [`check_narrowing_sew()`]).
568#[inline(always)]
569#[doc(hidden)]
570pub fn check_vs2_narrowing_alignment<Reg, Memory, PC, CustomError>(
571    program_counter: &PC,
572    vs2: VReg,
573    vlmul: Vlmul,
574    sew: Vsew,
575) -> Result<(), ExecutionError<Reg::Type, CustomError>>
576where
577    Reg: Register,
578    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
579{
580    // Source EEW is double the destination SEW. SEW=64 is rejected earlier by
581    // `check_narrowing_sew`.
582    let wide_eew = match sew {
583        Vsew::E8 => Eew::E16,
584        Vsew::E16 => Eew::E32,
585        Vsew::E32 => Eew::E64,
586        Vsew::E64 => {
587            return Err(ExecutionError::IllegalInstruction {
588                address: program_counter.old_pc(INSTRUCTION_SIZE),
589            });
590        }
591    };
592    // `EMUL = 2*LMUL`; `None` when reserved (e.g. LMUL=8 -> EMUL=16).
593    let wide_group =
594        vlmul
595            .data_register_count(wide_eew, sew)
596            .ok_or(ExecutionError::IllegalInstruction {
597                address: program_counter.old_pc(INSTRUCTION_SIZE),
598            })?;
599    let vs2_idx = vs2.to_bits();
600    if !vs2_idx.is_multiple_of(wide_group) || vs2_idx + wide_group > 32 {
601        return Err(ExecutionError::IllegalInstruction {
602            address: program_counter.old_pc(INSTRUCTION_SIZE),
603        });
604    }
605    Ok(())
606}