Skip to main content

ab_riscv_interpreter/v/zvexx/widen_narrow/
zvexx_widen_narrow_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10
11/// Check that a widening destination `vd` is aligned to `wide_group_regs` and fits within
12/// `[0,32)`, without any source overlap check
13#[inline(always)]
14#[doc(hidden)]
15pub fn check_vd_widen_no_src_check<Reg, Memory, PC, CustomError>(
16    program_counter: &PC,
17    vd: VReg,
18    wide_group_regs: u8,
19) -> Result<(), ExecutionError<Reg::Type, CustomError>>
20where
21    Reg: Register,
22    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
23{
24    let vd_idx = vd.to_bits();
25    if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
26        return Err(ExecutionError::IllegalInstruction {
27            address: program_counter.old_pc(INSTRUCTION_SIZE),
28        });
29    }
30    Ok(())
31}
32
33/// Check that an extension source `vs2` is aligned to `src_group_regs`, fits in `[0,32)`, and only
34/// overlaps `vd` (which occupies `group_regs` registers) in a manner permitted by the spec.
35///
36/// Per the vector spec §5.2, the destination EEW (SEW) of an extension is greater than the source
37/// EEW (SEW/factor), so the destination may overlap the source only when the source EMUL is at
38/// least 1 and the overlap is in the highest-numbered part of the destination register group (e.g.
39/// `vzext.vf4 v0, v6` with LMUL=8, where the narrow source `{v6,v7}` aliases the high registers of
40/// the wide `{v0..v7}` destination). Any other overlap is illegal.
41#[inline(always)]
42#[doc(hidden)]
43pub fn check_vs_ext_alignment<Reg, Memory, PC, CustomError>(
44    program_counter: &PC,
45    vs2: VReg,
46    src_group_regs: u8,
47    vd: VReg,
48    group_regs: u8,
49) -> Result<(), ExecutionError<Reg::Type, CustomError>>
50where
51    Reg: Register,
52    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
53{
54    let vs2_idx = vs2.to_bits();
55    if !vs2_idx.is_multiple_of(src_group_regs) || vs2_idx + src_group_regs > 32 {
56        return Err(ExecutionError::IllegalInstruction {
57            address: program_counter.old_pc(INSTRUCTION_SIZE),
58        });
59    }
60    // The wide destination (group_regs) may overlap the narrow source (src_group_regs) only in the
61    // highest-numbered part of the destination group, and only when the source EMUL >= 1.
62    if widen_src_overlap_illegal(vd.to_bits(), group_regs, vs2_idx, src_group_regs) {
63        return Err(ExecutionError::IllegalInstruction {
64            address: program_counter.old_pc(INSTRUCTION_SIZE),
65        });
66    }
67    Ok(())
68}
69
70/// Check that a widening destination `vd` is aligned to `wide_group_regs`, fits within `[0, 32)`,
71/// and only overlaps the `group_regs`-register narrow source(s) starting at `vs_a`/`vs_b` in a
72/// manner permitted by the spec.
73///
74/// `wide_group_regs` is the pre-computed register count for the wide EMUL (2*LMUL), obtained via
75/// `Vlmul::index_register_count(wide_eew, sew)`. `group_regs` is the narrow LMUL register count.
76///
77/// Per the vector spec §5.2, a destination whose EEW (2*SEW) is greater than a source's EEW (SEW)
78/// may overlap that source only when the source EMUL is at least 1 and the overlap is in the
79/// highest-numbered part of the destination register group (e.g. `vwsubu.wv v2, v14, v3` with
80/// LMUL=1, where the narrow `v3` aliases the high register of the wide `{v2, v3}` destination).
81/// Any other overlap is illegal.
82#[inline(always)]
83#[doc(hidden)]
84pub fn check_vd_widen_alignment<Reg, Memory, PC, CustomError>(
85    program_counter: &PC,
86    vd: VReg,
87    vs_a: VReg,
88    vs_b_opt: Option<VReg>,
89    group_regs: u8,
90    wide_group_regs: u8,
91) -> Result<(), ExecutionError<Reg::Type, CustomError>>
92where
93    Reg: Register,
94    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
95{
96    let vd_idx = vd.to_bits();
97    if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
98        return Err(ExecutionError::IllegalInstruction {
99            address: program_counter.old_pc(INSTRUCTION_SIZE),
100        });
101    }
102    if widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_a.to_bits(), group_regs) {
103        return Err(ExecutionError::IllegalInstruction {
104            address: program_counter.old_pc(INSTRUCTION_SIZE),
105        });
106    }
107    if let Some(vs_b) = vs_b_opt
108        && widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_b.to_bits(), group_regs)
109    {
110        return Err(ExecutionError::IllegalInstruction {
111            address: program_counter.old_pc(INSTRUCTION_SIZE),
112        });
113    }
114    Ok(())
115}
116
117/// Returns `true` when a narrow source group of `group_regs` registers starting at `vs_idx`
118/// overlaps the wide destination group (`wide_group_regs` registers starting at `vd_idx`) in a way
119/// that is *not* permitted by the spec.
120///
121/// Overlap is only legal when the source EMUL is at least 1 - which, on widening, is exactly when
122/// the destination register count strictly exceeds the narrow source count (for fractional LMUL
123/// both counts collapse to 1) - and the source occupies the highest-numbered registers of the
124/// destination group.
125#[inline(always)]
126fn widen_src_overlap_illegal(vd_idx: u8, wide_group_regs: u8, vs_idx: u8, group_regs: u8) -> bool {
127    if !ranges_overlap(vd_idx, wide_group_regs, vs_idx, group_regs) {
128        return false;
129    }
130    let high_part_overlap =
131        wide_group_regs > group_regs && vs_idx == vd_idx + wide_group_regs - group_regs;
132    !high_part_overlap
133}
134
135/// Check that a widening source `vs2` that is already 2×SEW wide is aligned to `wide_group_regs`
136/// and fits within `[0, 32)`.
137#[inline(always)]
138#[doc(hidden)]
139pub fn check_vs_wide_alignment<Reg, Memory, PC, CustomError>(
140    program_counter: &PC,
141    vs: VReg,
142    wide_group_regs: u8,
143) -> Result<(), ExecutionError<Reg::Type, CustomError>>
144where
145    Reg: Register,
146    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
147{
148    let vs_idx = vs.to_bits();
149    if !vs_idx.is_multiple_of(wide_group_regs) || vs_idx + wide_group_regs > 32 {
150        return Err(ExecutionError::IllegalInstruction {
151            address: program_counter.old_pc(INSTRUCTION_SIZE),
152        });
153    }
154    Ok(())
155}
156
157/// Check that a narrowing destination `vd` is aligned to `group_regs` and fits
158/// within `[0, 32)`.
159///
160/// No overlap check against `vs2` is performed here because narrowing instructions
161/// permit `vd` to alias the low half of the wide `vs2` register group per spec §11.7.
162#[inline(always)]
163#[doc(hidden)]
164pub fn check_vd_narrow_alignment<Reg, Memory, PC, CustomError>(
165    program_counter: &PC,
166    vd: VReg,
167    group_regs: u8,
168) -> Result<(), ExecutionError<Reg::Type, CustomError>>
169where
170    Reg: Register,
171    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
172{
173    let vd_idx = vd.to_bits();
174    if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
175        return Err(ExecutionError::IllegalInstruction {
176            address: program_counter.old_pc(INSTRUCTION_SIZE),
177        });
178    }
179    Ok(())
180}
181
/// Test whether the half-open register ranges `[a_start, a_start + a_len)` and
/// `[b_start, b_start + b_len)` share at least one register.
#[inline(always)]
fn ranges_overlap(a_start: u8, a_len: u8, b_start: u8, b_len: u8) -> bool {
    // Two half-open intervals intersect iff each one starts before the other one ends
    let a_starts_before_b_ends = a_start < b_start + b_len;
    a_starts_before_b_ends && b_start < a_start + a_len
}
187
/// Read bit `i` of a packed mask, where byte `i / 8` holds bit `i % 8` (LSB first).
///
/// Bits past the end of `mask` read as `false`.
#[inline(always)]
fn mask_bit(mask: &[u8], i: u32) -> bool {
    let byte_idx = (i / u8::BITS) as usize;
    let bit_idx = i % u8::BITS;
    match mask.get(byte_idx) {
        Some(byte) => *byte & (1u8 << bit_idx) != 0,
        None => false,
    }
}
194
195/// Snapshot the mask register into a stack buffer.
196///
197/// When `vm=true` (unmasked), all bytes are `0xff`.
198///
199/// # Safety
200/// `vl.div_ceil(8) <= VLENB` must hold. This is guaranteed when `vl <= VLEN`.
201#[inline(always)]
202unsafe fn snapshot_mask<const VLENB: usize>(
203    vregs: &VectorRegisterFile<VLENB>,
204    vm: bool,
205    vl: u32,
206) -> [u8; VLENB] {
207    let mut buf = [0u8; VLENB];
208    if vm {
209        buf = [0xffu8; VLENB];
210    } else {
211        let mask_bytes = vl.div_ceil(u8::BITS) as usize;
212        // SAFETY: `mask_bytes <= VLENB` by precondition
213        unsafe {
214            buf.get_unchecked_mut(..mask_bytes)
215                .copy_from_slice(vregs.get(VReg::V0).get_unchecked(..mask_bytes));
216        }
217    }
218    buf
219}
220
221/// Read the low `sew.bytes_width()` of the element `elem_i` from the register group `base_reg`,
222/// zero-extended to `u64`.
223///
224/// # Safety
225/// `base_reg + elem_i / (VLENB / sew.bytes_width()) < 32`
226#[inline(always)]
227unsafe fn read_element_u64<const VLENB: usize>(
228    vregs: &VectorRegisterFile<VLENB>,
229    base_reg: VReg,
230    elem_i: u32,
231    sew: Vsew,
232) -> u64 {
233    let sew_bytes = usize::from(sew.bytes_width());
234    let elems_per_reg = VLENB / sew_bytes;
235    let reg_off = elem_i as usize / elems_per_reg;
236    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
237    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
238    let reg = unsafe {
239        vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
240    };
241    // SAFETY: `byte_off + sew_bytes <= VLENB`
242    let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
243    let mut buf = [0u8; 8];
244    // SAFETY: `sew_bytes <= 8`
245    unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
246    u64::from_le_bytes(buf)
247}
248
249/// Write the low `sew.bytes_width()` of `value` into element `elem_i` in register group `base_reg`.
250///
251/// # Safety
252/// `base_reg + elem_i / (VLENB / sew.bytes_width()) < 32`
253#[inline(always)]
254unsafe fn write_element_u64<const VLENB: usize>(
255    vregs: &mut VectorRegisterFile<VLENB>,
256    base_reg: VReg,
257    elem_i: u32,
258    sew: Vsew,
259    value: u64,
260) {
261    let sew_bytes = usize::from(sew.bytes_width());
262    let elems_per_reg = VLENB / sew_bytes;
263    let reg_off = elem_i as usize / elems_per_reg;
264    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
265    let buf = value.to_le_bytes();
266    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
267    let reg = unsafe {
268        vregs.get_mut(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
269    };
270    // SAFETY: `byte_off + sew_bytes <= VLENB`
271    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
272    // SAFETY: `sew_bytes <= 8`
273    dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
274}
275
/// Interpret the low `bits` bits of `val` as a two's-complement number and widen it to `i64`.
///
/// Callers pass `bits` in `1..=64`; other values overflow the shift-amount arithmetic.
#[inline(always)]
#[doc(hidden)]
pub fn sign_extend_bits(val: u64, bits: u8) -> i64 {
    // Move the field's sign bit up to bit 63, then shift arithmetically back down so that the
    // sign bit is replicated across every vacated high bit
    let unused_high_bits = u64::BITS - u32::from(bits);
    (val << unused_high_bits).cast_signed() >> unused_high_bits
}
283
/// Truncate a scalar operand to SEW bits and return it zero-extended (no sign extension).
///
/// RVV widening scalar instructions (`.vx`/`.wx`) treat the scalar as SEW bits wide rather than
/// XLEN bits wide. Examples on RV64:
///
/// SEW=8:
///     val = 0x0000_0000_0000_01ff
///     result = 0x0000_0000_0000_00ff
///
/// SEW=16:
///     val = 0x0000_0000_0000_01ff
///     result = 0x0000_0000_0000_01ff
///
/// SEW=32:
///     val = 0xffff_ffff_1234_5678
///     result = 0x0000_0000_1234_5678
#[inline(always)]
fn scalar_unsigned_for_sew(val: u64, sew_bits: u8) -> u64 {
    // Push the bits above SEW off the top, then shift zeros back in from the top
    let discarded = u64::BITS - u32::from(sew_bits);
    (val << discarded) >> discarded
}
308
309/// Interpret a scalar operand as a signed SEW-wide value.
310///
311/// The scalar is first truncated to SEW bits, then sign-extended back to
312/// 64 bits.
313///
314/// For example on RV64:
315///
316/// SEW=8:
317///     val = 0x0000_0000_0000_00ff
318///     result = 0xffff_ffff_ffff_ffff (-1)
319///
320/// SEW=8:
321///     val = 0x0000_0000_0000_007f
322///     result = 0x0000_0000_0000_007f (+127)
323///
324/// SEW=16:
325///     val = 0x0000_0000_0000_ffff
326///     result = 0xffff_ffff_ffff_ffff (-1)
327///
328/// This matches the signed widening behavior required by instructions such
329/// as vwadd.vx and vwsub.vx.
330#[inline(always)]
331fn scalar_signed_for_sew(val: u64, sew_bits: u8) -> u64 {
332    sign_extend_bits(val, sew_bits).cast_unsigned()
333}
334
335/// Execute a widening integer add/subtract.
336///
337/// Each source element is SEW-wide; the destination element is 2×SEW-wide.
338/// `zero_extend_a` and `zero_extend_b` select unsigned vs signed widening for each source
339/// (unsigned = zero-extend, signed = sign-extend).
340///
341/// `op` receives `(wide_a: u64, wide_b: u64) -> u64`.
342///
343/// # Safety
344/// - `vd` aligned to `2*group_regs`, fits in `[0,32)`, does not overlap `vs2` or `src` (verified by
345///   caller)
346/// - `vs2` aligned to `group_regs`, fits in `[0,32)` (verified by caller)
347/// - `src` register (when `WidenSrc::Vreg`) aligned to `group_regs`, fits in `[0,32)` (verified by
348///   caller)
349/// - `vl <= group_regs * VLENB / sew.bytes_width()` (all elements fit)
350/// - SEW < 64
351/// - When `vm=false`: `vd.to_bits() != 0`
352#[inline(always)]
353#[expect(clippy::too_many_arguments, reason = "Internal API")]
354#[doc(hidden)]
355pub unsafe fn execute_widen_op<Reg, ExtState, CustomError, F>(
356    ext_state: &mut ExtState,
357    vd: VReg,
358    vs2: VReg,
359    src: OpSrc,
360    vm: bool,
361    sew: Vsew,
362    zero_extend_a: bool,
363    zero_extend_b: bool,
364    op: F,
365) where
366    Reg: Register,
367    ExtState: VectorRegistersExt<Reg, CustomError>,
368    [(); ExtState::ELEN as usize]:,
369    [(); ExtState::VLEN as usize]:,
370    [(); ExtState::VLENB as usize]:,
371    CustomError: fmt::Debug,
372    F: Fn(u64, u64) -> u64,
373{
374    let vl = ext_state.vl();
375    let vstart = ext_state.vstart();
376    let wide_sew = sew
377        .double_width()
378        .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
379
380    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
381    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
382
383    for i in u32::from(vstart)..vl {
384        if !mask_bit(&mask_buf, i) {
385            continue;
386        }
387        // SAFETY: `vs2` aligned to `group_regs`;
388        // `i < vl <= group_regs * (VLENB / sew.bytes_width())`
389        let raw_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
390        let wide_a = if zero_extend_a {
391            raw_a
392        } else {
393            sign_extend_bits(raw_a, sew.bits_width()).cast_unsigned()
394        };
395        let wide_b = match src {
396            OpSrc::Vreg(vs1_base) => {
397                // SAFETY: same argument as vs2
398                let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
399                if zero_extend_b {
400                    raw_b
401                } else {
402                    sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
403                }
404            }
405            OpSrc::Scalar(val) => {
406                if zero_extend_b {
407                    scalar_unsigned_for_sew(val, sew.bits_width())
408                } else {
409                    scalar_signed_for_sew(val, sew.bits_width())
410                }
411            }
412        };
413        let result = op(wide_a, wide_b);
414        // SAFETY: `vd` aligned to `2*group_regs`;
415        // `i < vl <= group_regs * (VLENB / sew.bytes_width())` so
416        // `i < 2*group_regs * (VLENB / wide_sew.bytes_width())` - element fits in the wide group
417        unsafe {
418            write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
419        }
420    }
421    ext_state.mark_vs_dirty();
422    ext_state.reset_vstart();
423}
424
425/// Execute a widening add/subtract where `vs2` is already 2×SEW wide.
426///
427/// `vs2` is read at `wide_sew.bytes_width()`; `src` (narrow) is read at `sew.bytes_width()` and
428/// widened. `zero_extend_b` selects unsigned vs signed widening for the narrow source operand.
429///
430/// # Safety
431/// - `vd` aligned to `2*group_regs`, fits in `[0,32)`, does not overlap `vs2` or `src`
432/// - `vs2` aligned to `2*group_regs`, fits in `[0,32)` (wide source)
433/// - `src` register (when `WidenSrc::Vreg`) aligned to `group_regs`, fits in `[0,32)`
434/// - `vl <= group_regs * VLENB / sew.bytes_width()`
435/// - SEW < 64
436/// - When `vm=false`: `vd.to_bits() != 0`
437#[inline(always)]
438#[expect(clippy::too_many_arguments, reason = "Internal API")]
439#[doc(hidden)]
440pub unsafe fn execute_widen_w_op<Reg, ExtState, CustomError, F>(
441    ext_state: &mut ExtState,
442    vd: VReg,
443    vs2: VReg,
444    src: OpSrc,
445    vm: bool,
446    sew: Vsew,
447    zero_extend_b: bool,
448    op: F,
449) where
450    Reg: Register,
451    ExtState: VectorRegistersExt<Reg, CustomError>,
452    [(); ExtState::ELEN as usize]:,
453    [(); ExtState::VLEN as usize]:,
454    [(); ExtState::VLENB as usize]:,
455    CustomError: fmt::Debug,
456    F: Fn(u64, u64) -> u64,
457{
458    let vl = ext_state.vl();
459    let vstart = ext_state.vstart();
460    let wide_sew = sew
461        .double_width()
462        .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
463
464    // SAFETY: `vl <= VLEN`
465    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
466
467    for i in u32::from(vstart)..vl {
468        if !mask_bit(&mask_buf, i) {
469            continue;
470        }
471        // vs2 is already 2×SEW; read at wide width
472        // SAFETY: `vs2` aligned to `2*group_regs`; element `i` fits within it
473        let wide_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
474        let wide_b = match src {
475            OpSrc::Vreg(vs1) => {
476                // SAFETY: `vs1` is aligned to `group_regs` and fits within `[0, 32)`,
477                // verified by caller; `i < vl <= group_regs * (VLENB / sew.bytes_width())`,
478                // so `vs1_base + i / elems_per_reg < vs1_base + group_regs <= 32`
479                let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1, i, sew) };
480                if zero_extend_b {
481                    raw_b
482                } else {
483                    sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
484                }
485            }
486            OpSrc::Scalar(val) => {
487                if zero_extend_b {
488                    scalar_unsigned_for_sew(val, sew.bits_width())
489                } else {
490                    scalar_signed_for_sew(val, sew.bits_width())
491                }
492            }
493        };
494        let result = op(wide_a, wide_b);
495        // SAFETY: same as `execute_widen_op` for vd
496        unsafe {
497            write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
498        }
499    }
500    ext_state.mark_vs_dirty();
501    ext_state.reset_vstart();
502}
503
504/// Execute a narrowing right-shift.
505///
506/// `vs2` is 2×SEW wide; the shift amount comes from `src` (SEW-wide or scalar).
507/// The shift amount is masked to `log2(2*SEW)` bits per spec §12.6.
508/// `arithmetic` selects sign-extending (true) vs zero-extending (false) before shifting.
509///
510/// # Safety
511/// - `vd` aligned to `group_regs`, fits in `[0,32)`
512/// - `vs2` aligned to `wide_group_regs`, fits in `[0,32)`; aliasing with the low half of `vs2` is
513///   permitted per spec §11.7 - reads complete before writes to any overlapping element since the
514///   destination SEW is half the source SEW
515/// - `src` register (when `OpSrc::Vreg`) aligned to `group_regs`, fits in `[0,32)`
516/// - `vl <= group_regs * VLENB / sew.bytes_width()`
517/// - SEW < 64
518/// - When `vm=false`: `vd.to_bits() != 0`
519#[inline(always)]
520#[doc(hidden)]
521pub unsafe fn execute_narrow_shift<Reg, ExtState, CustomError>(
522    ext_state: &mut ExtState,
523    vd: VReg,
524    vs2: VReg,
525    src: OpSrc,
526    vm: bool,
527    sew: Vsew,
528    arithmetic: bool,
529) where
530    Reg: Register,
531    ExtState: VectorRegistersExt<Reg, CustomError>,
532    [(); ExtState::ELEN as usize]:,
533    [(); ExtState::VLEN as usize]:,
534    [(); ExtState::VLENB as usize]:,
535    CustomError: fmt::Debug,
536{
537    let vl = ext_state.vl();
538    let vstart = ext_state.vstart();
539    let wide_sew = sew
540        .double_width()
541        .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
542    // Shift amount mask: log2(2*SEW) bits = log2(SEW) + 1 bits
543    let shamt_mask = u64::from(wide_sew.bits_width() - 1);
544
545    // SAFETY: `vl <= VLEN`
546    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
547
548    for i in u32::from(vstart)..vl {
549        if !mask_bit(&mask_buf, i) {
550            continue;
551        }
552        // SAFETY: `vs2` is the wide source group
553        let wide_val = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
554        let shamt = match src {
555            OpSrc::Vreg(vs1_base) => {
556                // SAFETY: `vs1` is aligned to `group_regs` and fits within `[0, 32)`,
557                // verified by caller; `i < vl <= group_regs * (VLENB / sew.bytes_width())`,
558                // so `vs1_base + i / elems_per_reg < vs1_base + group_regs <= 32`
559                let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
560                raw & shamt_mask
561            }
562            // Scalar shift amount: only the low log2(2*SEW) bits are used per spec
563            OpSrc::Scalar(val) => val & shamt_mask,
564        };
565        let result_wide = if arithmetic {
566            // Sign-extend to i64 first, then shift arithmetically as i64 to
567            // preserve sign bits, then cast back. Shifting u64 after cast_unsigned()
568            // would be a logical shift and lose sign bits.
569            (sign_extend_bits(wide_val, wide_sew.bits_width()) >> shamt).cast_unsigned()
570        } else {
571            wide_val >> shamt
572        };
573        // Truncate to SEW bits
574        let result = result_wide & ((1u64 << sew.bits_width()) - 1);
575        // SAFETY: `vd` is the narrow destination group
576        unsafe {
577            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
578        }
579    }
580    ext_state.mark_vs_dirty();
581    ext_state.reset_vstart();
582}
583
584/// Execute an integer extension (vzext/vsext).
585///
586/// Source element width is `sew.divide_by_factor(factor).bytes_width()`; destination is
587/// `sew.bytes_width()`. `sign_extend` selects sign- vs zero-extension.
588///
589/// The source EMUL = LMUL / factor; the source register group is `max(1, group_regs / factor)`
590/// registers.
591///
592/// # Safety
593/// - `vd` aligned to `group_regs`, fits in `[0,32)`
594/// - `vs2` aligned to `src_group_regs`, fits in `[0,32)`, does not overlap `vd`
595/// - `vl <= group_regs * VLENB / sew.bytes_width()`
596/// - `sew.divide_by_factor(factor).is_some()`
597/// - When `vm=false`: `vd.to_bits() != 0`
598#[inline(always)]
599#[doc(hidden)]
600pub unsafe fn execute_extension<Reg, ExtState, CustomError>(
601    ext_state: &mut ExtState,
602    vd: VReg,
603    vs2: VReg,
604    vm: bool,
605    sew: Vsew,
606    factor: VsewFactor,
607    sign: bool,
608) where
609    Reg: Register,
610    ExtState: VectorRegistersExt<Reg, CustomError>,
611    [(); ExtState::ELEN as usize]:,
612    [(); ExtState::VLEN as usize]:,
613    [(); ExtState::VLENB as usize]:,
614    CustomError: fmt::Debug,
615{
616    let vl = ext_state.vl();
617    let vstart = ext_state.vstart();
618    let src_sew = sew
619        .divide_by_factor(factor)
620        .expect("SEW >= factor*8 and valid according to function contract; qed");
621
622    // SAFETY: `vl <= VLEN`
623    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
624
625    for i in u32::from(vstart)..vl {
626        if !mask_bit(&mask_buf, i) {
627            continue;
628        }
629        // SAFETY: vs2 group covers `vl` narrow elements
630        let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, src_sew) };
631        let result = if sign {
632            sign_extend_bits(raw, src_sew.bits_width()).cast_unsigned()
633        } else {
634            raw
635        };
636        // SAFETY: vd group covers `vl` wide elements
637        unsafe {
638            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
639        }
640    }
641    ext_state.mark_vs_dirty();
642    ext_state.reset_vstart();
643}