Skip to main content

ab_riscv_interpreter/v/zvexx/widen_narrow/
zvexx_widen_narrow_helpers.rs

//! Opaque helpers for the ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10use core::hint::cold_path;
11
12/// Check that a widening destination `vd` is aligned to `wide_group_regs` and fits within
13/// `[0,32)`, without any source overlap check
14#[inline(always)]
15#[doc(hidden)]
16pub fn check_vd_widen_no_src_check<Reg, Memory, PC, CustomError>(
17    program_counter: &PC,
18    vd: VReg,
19    wide_group_regs: u8,
20) -> Result<(), ExecutionError<Reg::Type, CustomError>>
21where
22    Reg: Register,
23    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
24{
25    let vd_idx = vd.to_bits();
26    if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
27        cold_path();
28        return Err(ExecutionError::IllegalInstruction {
29            address: program_counter.old_pc(INSTRUCTION_SIZE),
30        });
31    }
32    Ok(())
33}
34
35/// Check that an extension source `vs2` is aligned to `src_group_regs`, fits in `[0,32)`, and only
36/// overlaps `vd` (which occupies `group_regs` registers) in a manner permitted by the spec.
37///
38/// Per the vector spec §5.2, the destination EEW (SEW) of an extension is greater than the source
39/// EEW (SEW/factor), so the destination may overlap the source only when the source EMUL is at
40/// least 1 and the overlap is in the highest-numbered part of the destination register group (e.g.
41/// `vzext.vf4 v0, v6` with LMUL=8, where the narrow source `{v6,v7}` aliases the high registers of
42/// the wide `{v0..v7}` destination). Any other overlap is illegal.
43#[inline(always)]
44#[doc(hidden)]
45pub fn check_vs_ext_alignment<Reg, Memory, PC, CustomError>(
46    program_counter: &PC,
47    vs2: VReg,
48    src_group_regs: u8,
49    vd: VReg,
50    group_regs: u8,
51) -> Result<(), ExecutionError<Reg::Type, CustomError>>
52where
53    Reg: Register,
54    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
55{
56    let vs2_idx = vs2.to_bits();
57    if !vs2_idx.is_multiple_of(src_group_regs) || vs2_idx + src_group_regs > 32 {
58        cold_path();
59        return Err(ExecutionError::IllegalInstruction {
60            address: program_counter.old_pc(INSTRUCTION_SIZE),
61        });
62    }
63    // The wide destination (group_regs) may overlap the narrow source (src_group_regs) only in the
64    // highest-numbered part of the destination group, and only when the source EMUL >= 1.
65    if widen_src_overlap_illegal(vd.to_bits(), group_regs, vs2_idx, src_group_regs) {
66        cold_path();
67        return Err(ExecutionError::IllegalInstruction {
68            address: program_counter.old_pc(INSTRUCTION_SIZE),
69        });
70    }
71    Ok(())
72}
73
74/// Check that a widening destination `vd` is aligned to `wide_group_regs`, fits within `[0, 32)`,
75/// and only overlaps the `group_regs`-register narrow source(s) starting at `vs_a`/`vs_b` in a
76/// manner permitted by the spec.
77///
78/// `wide_group_regs` is the pre-computed register count for the wide EMUL (2*LMUL), obtained via
79/// `Vlmul::index_register_count(wide_eew, sew)`. `group_regs` is the narrow LMUL register count.
80///
81/// Per the vector spec §5.2, a destination whose EEW (2*SEW) is greater than a source's EEW (SEW)
82/// may overlap that source only when the source EMUL is at least 1 and the overlap is in the
83/// highest-numbered part of the destination register group (e.g. `vwsubu.wv v2, v14, v3` with
84/// LMUL=1, where the narrow `v3` aliases the high register of the wide `{v2, v3}` destination).
85/// Any other overlap is illegal.
86#[inline(always)]
87#[doc(hidden)]
88pub fn check_vd_widen_alignment<Reg, Memory, PC, CustomError>(
89    program_counter: &PC,
90    vd: VReg,
91    vs_a: VReg,
92    vs_b_opt: Option<VReg>,
93    group_regs: u8,
94    wide_group_regs: u8,
95) -> Result<(), ExecutionError<Reg::Type, CustomError>>
96where
97    Reg: Register,
98    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
99{
100    let vd_idx = vd.to_bits();
101    if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
102        cold_path();
103        return Err(ExecutionError::IllegalInstruction {
104            address: program_counter.old_pc(INSTRUCTION_SIZE),
105        });
106    }
107    if widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_a.to_bits(), group_regs) {
108        cold_path();
109        return Err(ExecutionError::IllegalInstruction {
110            address: program_counter.old_pc(INSTRUCTION_SIZE),
111        });
112    }
113    if let Some(vs_b) = vs_b_opt
114        && widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_b.to_bits(), group_regs)
115    {
116        cold_path();
117        return Err(ExecutionError::IllegalInstruction {
118            address: program_counter.old_pc(INSTRUCTION_SIZE),
119        });
120    }
121    Ok(())
122}
123
124/// Returns `true` when a narrow source group of `group_regs` registers starting at `vs_idx`
125/// overlaps the wide destination group (`wide_group_regs` registers starting at `vd_idx`) in a way
126/// that is *not* permitted by the spec.
127///
128/// Overlap is only legal when the source EMUL is at least 1 - which, on widening, is exactly when
129/// the destination register count strictly exceeds the narrow source count (for fractional LMUL
130/// both counts collapse to 1) - and the source occupies the highest-numbered registers of the
131/// destination group.
132#[inline(always)]
133fn widen_src_overlap_illegal(vd_idx: u8, wide_group_regs: u8, vs_idx: u8, group_regs: u8) -> bool {
134    if !ranges_overlap(vd_idx, wide_group_regs, vs_idx, group_regs) {
135        return false;
136    }
137    let high_part_overlap =
138        wide_group_regs > group_regs && vs_idx == vd_idx + wide_group_regs - group_regs;
139    !high_part_overlap
140}
141
142/// Check that a widening source `vs2` that is already 2×SEW wide is aligned to `wide_group_regs`
143/// and fits within `[0, 32)`.
144#[inline(always)]
145#[doc(hidden)]
146pub fn check_vs_wide_alignment<Reg, Memory, PC, CustomError>(
147    program_counter: &PC,
148    vs: VReg,
149    wide_group_regs: u8,
150) -> Result<(), ExecutionError<Reg::Type, CustomError>>
151where
152    Reg: Register,
153    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
154{
155    let vs_idx = vs.to_bits();
156    if !vs_idx.is_multiple_of(wide_group_regs) || vs_idx + wide_group_regs > 32 {
157        cold_path();
158        return Err(ExecutionError::IllegalInstruction {
159            address: program_counter.old_pc(INSTRUCTION_SIZE),
160        });
161    }
162    Ok(())
163}
164
165/// Check that a narrowing destination `vd` is aligned to `group_regs` and fits
166/// within `[0, 32)`.
167///
168/// No overlap check against `vs2` is performed here because narrowing instructions
169/// permit `vd` to alias the low half of the wide `vs2` register group per spec §11.7.
170#[inline(always)]
171#[doc(hidden)]
172pub fn check_vd_narrow_alignment<Reg, Memory, PC, CustomError>(
173    program_counter: &PC,
174    vd: VReg,
175    group_regs: u8,
176) -> Result<(), ExecutionError<Reg::Type, CustomError>>
177where
178    Reg: Register,
179    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
180{
181    let vd_idx = vd.to_bits();
182    if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
183        cold_path();
184        return Err(ExecutionError::IllegalInstruction {
185            address: program_counter.old_pc(INSTRUCTION_SIZE),
186        });
187    }
188    Ok(())
189}
190
/// Tell whether the half-open ranges `[a_start, a_start+a_len)` and `[b_start, b_start+b_len)`
/// share at least one index.
#[inline(always)]
fn ranges_overlap(a_start: u8, a_len: u8, b_start: u8, b_len: u8) -> bool {
    // `a` begins at or after the end of `b`: nothing in common
    if a_start >= b_start + b_len {
        return false;
    }
    // Otherwise they intersect exactly when `b` begins before the end of `a`
    b_start < a_start + a_len
}
196
/// Look up mask bit `i` in `mask`, where bit 0 is the least-significant bit of byte 0.
///
/// An index that lies beyond the end of `mask` is reported as not set.
#[inline(always)]
fn mask_bit(mask: &[u8], i: u32) -> bool {
    let byte_index = (i / u8::BITS) as usize;
    let bit_index = i % u8::BITS;
    match mask.get(byte_index) {
        Some(&byte) => (byte >> bit_index) & 1 == 1,
        None => false,
    }
}
203
204/// Snapshot the mask register into a stack buffer.
205///
206/// When `vm=true` (unmasked), all bytes are `0xff`.
207///
208/// # Safety
209/// `vl.div_ceil(8) <= VLENB` must hold. This is guaranteed when `vl <= VLEN`.
210#[inline(always)]
211unsafe fn snapshot_mask<const VLENB: usize>(
212    vregs: &VectorRegisterFile<VLENB>,
213    vm: bool,
214    vl: u32,
215) -> [u8; VLENB] {
216    let mut buf = [0u8; VLENB];
217    if vm {
218        buf = [0xffu8; VLENB];
219    } else {
220        let mask_bytes = vl.div_ceil(u8::BITS) as usize;
221        // SAFETY: `mask_bytes <= VLENB` by precondition
222        unsafe {
223            buf.get_unchecked_mut(..mask_bytes)
224                .copy_from_slice(vregs.get(VReg::V0).get_unchecked(..mask_bytes));
225        }
226    }
227    buf
228}
229
230/// Read the low `sew.bytes_width()` of the element `elem_i` from the register group `base_reg`,
231/// zero-extended to `u64`.
232///
233/// # Safety
234/// `base_reg + elem_i / (VLENB / sew.bytes_width()) < 32`
235#[inline(always)]
236unsafe fn read_element_u64<const VLENB: usize>(
237    vregs: &VectorRegisterFile<VLENB>,
238    base_reg: VReg,
239    elem_i: u32,
240    sew: Vsew,
241) -> u64 {
242    let sew_bytes = usize::from(sew.bytes_width());
243    let elems_per_reg = VLENB / sew_bytes;
244    let reg_off = elem_i as usize / elems_per_reg;
245    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
246    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
247    let reg = unsafe {
248        vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
249    };
250    // SAFETY: `byte_off + sew_bytes <= VLENB`
251    let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
252    let mut buf = [0u8; 8];
253    // SAFETY: `sew_bytes <= 8`
254    unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
255    u64::from_le_bytes(buf)
256}
257
258/// Write the low `sew.bytes_width()` of `value` into element `elem_i` in register group `base_reg`.
259///
260/// # Safety
261/// `base_reg + elem_i / (VLENB / sew.bytes_width()) < 32`
262#[inline(always)]
263unsafe fn write_element_u64<const VLENB: usize>(
264    vregs: &mut VectorRegisterFile<VLENB>,
265    base_reg: VReg,
266    elem_i: u32,
267    sew: Vsew,
268    value: u64,
269) {
270    let sew_bytes = usize::from(sew.bytes_width());
271    let elems_per_reg = VLENB / sew_bytes;
272    let reg_off = elem_i as usize / elems_per_reg;
273    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
274    let buf = value.to_le_bytes();
275    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
276    let reg = unsafe {
277        vregs.get_mut(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
278    };
279    // SAFETY: `byte_off + sew_bytes <= VLENB`
280    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
281    // SAFETY: `sew_bytes <= 8`
282    dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
283}
284
/// Reinterpret the low `bits` bits of `val` as a two's-complement number and widen it to `i64`.
///
/// Bits of `val` above position `bits - 1` are ignored.
#[inline(always)]
#[doc(hidden)]
pub fn sign_extend_bits(val: u64, bits: u8) -> i64 {
    let unused_high_bits = u64::BITS - u32::from(bits);
    // Park the field's sign bit in bit 63, then let the arithmetic right shift replicate it
    let left_aligned = (val << unused_high_bits).cast_signed();
    left_aligned >> unused_high_bits
}
292
/// Reduce a scalar operand to an unsigned value of `sew_bits` bits.
///
/// RVV widening scalar instructions (.vx/.wx) conceptually use a scalar
/// operand whose width matches the current SEW, not the full XLEN width,
/// so everything above the low `sew_bits` bits is cleared.
///
/// Examples on RV64:
///
/// | SEW | `val`                   | result                  |
/// |-----|-------------------------|-------------------------|
/// | 8   | `0x0000_0000_0000_01ff` | `0x0000_0000_0000_00ff` |
/// | 16  | `0x0000_0000_0000_01ff` | `0x0000_0000_0000_01ff` |
/// | 32  | `0xffff_ffff_1234_5678` | `0x0000_0000_1234_5678` |
///
/// No sign extension takes place.
#[inline(always)]
fn scalar_unsigned_for_sew(val: u64, sew_bits: u8) -> u64 {
    let discarded_bits = u64::BITS - u32::from(sew_bits);
    // Shifting the unwanted high bits out and back in leaves zeros in their place
    (val << discarded_bits) >> discarded_bits
}
317
318/// Interpret a scalar operand as a signed SEW-wide value.
319///
320/// The scalar is first truncated to SEW bits, then sign-extended back to
321/// 64 bits.
322///
323/// For example on RV64:
324///
325/// SEW=8:
326///     val = 0x0000_0000_0000_00ff
327///     result = 0xffff_ffff_ffff_ffff (-1)
328///
329/// SEW=8:
330///     val = 0x0000_0000_0000_007f
331///     result = 0x0000_0000_0000_007f (+127)
332///
333/// SEW=16:
334///     val = 0x0000_0000_0000_ffff
335///     result = 0xffff_ffff_ffff_ffff (-1)
336///
337/// This matches the signed widening behavior required by instructions such
338/// as vwadd.vx and vwsub.vx.
339#[inline(always)]
340fn scalar_signed_for_sew(val: u64, sew_bits: u8) -> u64 {
341    sign_extend_bits(val, sew_bits).cast_unsigned()
342}
343
344/// Execute a widening integer add/subtract.
345///
346/// Each source element is SEW-wide; the destination element is 2×SEW-wide.
347/// `ZERO_EXTEND_AB` selects unsigned or signed widening for sources (unsigned = zero-extend,
348/// signed = sign-extend).
349///
350/// `op` receives `(wide_a: u64, wide_b: u64) -> u64`.
351///
352/// # Safety
353/// - `vd` aligned to `2*group_regs`, fits in `[0,32)`, does not overlap `vs2` or `src` (verified by
354///   caller)
355/// - `vs2` aligned to `group_regs`, fits in `[0,32)` (verified by caller)
356/// - `src` register (when `WidenSrc::Vreg`) aligned to `group_regs`, fits in `[0,32)` (verified by
357///   caller)
358/// - `vl <= group_regs * VLENB / sew.bytes_width()` (all elements fit)
359/// - SEW < 64
360/// - When `vm=false`: `vd.to_bits() != 0`
361#[inline(always)]
362#[doc(hidden)]
363pub unsafe fn execute_widen_op<const ZERO_EXTEND_AB: bool, Reg, ExtState, CustomError, F>(
364    ext_state: &mut ExtState,
365    vd: VReg,
366    vs2: VReg,
367    src: OpSrc,
368    vm: bool,
369    sew: Vsew,
370    op: F,
371) where
372    Reg: Register,
373    ExtState: VectorRegistersExt<Reg, CustomError>,
374    [(); ExtState::ELEN as usize]:,
375    [(); ExtState::VLEN as usize]:,
376    [(); ExtState::VLENB as usize]:,
377    CustomError: fmt::Debug,
378    F: Fn(u64, u64) -> u64,
379{
380    let vl = ext_state.vl();
381    let vstart = ext_state.vstart();
382    let wide_sew = sew
383        .double_width()
384        .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
385
386    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
387    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
388
389    for i in u32::from(vstart)..vl {
390        if !mask_bit(&mask_buf, i) {
391            continue;
392        }
393        // SAFETY: `vs2` aligned to `group_regs`;
394        // `i < vl <= group_regs * (VLENB / sew.bytes_width())`
395        let raw_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
396        let wide_a = if ZERO_EXTEND_AB {
397            raw_a
398        } else {
399            sign_extend_bits(raw_a, sew.bits_width()).cast_unsigned()
400        };
401        let wide_b = match src {
402            OpSrc::Vreg(vs1_base) => {
403                // SAFETY: same argument as vs2
404                let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
405                if ZERO_EXTEND_AB {
406                    raw_b
407                } else {
408                    sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
409                }
410            }
411            OpSrc::Scalar(val) => {
412                if ZERO_EXTEND_AB {
413                    scalar_unsigned_for_sew(val, sew.bits_width())
414                } else {
415                    scalar_signed_for_sew(val, sew.bits_width())
416                }
417            }
418        };
419        let result = op(wide_a, wide_b);
420        // SAFETY: `vd` aligned to `2*group_regs`;
421        // `i < vl <= group_regs * (VLENB / sew.bytes_width())` so
422        // `i < 2*group_regs * (VLENB / wide_sew.bytes_width())` - element fits in the wide group
423        unsafe {
424            write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
425        }
426    }
427    ext_state.mark_vs_dirty();
428    ext_state.reset_vstart();
429}
430
431/// Execute a widening add/subtract where `vs2` is already 2×SEW wide.
432///
433/// `vs2` is read at `wide_sew.bytes_width()`; `src` (narrow) is read at `sew.bytes_width()` and
434/// widened. `ZERO_EXTEND_B` selects unsigned vs signed widening for the narrow source operand.
435///
436/// # Safety
437/// - `vd` aligned to `2*group_regs`, fits in `[0,32)`, does not overlap `vs2` or `src`
438/// - `vs2` aligned to `2*group_regs`, fits in `[0,32)` (wide source)
439/// - `src` register (when `WidenSrc::Vreg`) aligned to `group_regs`, fits in `[0,32)`
440/// - `vl <= group_regs * VLENB / sew.bytes_width()`
441/// - SEW < 64
442/// - When `vm=false`: `vd.to_bits() != 0`
443#[inline(always)]
444#[doc(hidden)]
445pub unsafe fn execute_widen_w_op<const ZERO_EXTEND_B: bool, Reg, ExtState, CustomError, F>(
446    ext_state: &mut ExtState,
447    vd: VReg,
448    vs2: VReg,
449    src: OpSrc,
450    vm: bool,
451    sew: Vsew,
452    op: F,
453) where
454    Reg: Register,
455    ExtState: VectorRegistersExt<Reg, CustomError>,
456    [(); ExtState::ELEN as usize]:,
457    [(); ExtState::VLEN as usize]:,
458    [(); ExtState::VLENB as usize]:,
459    CustomError: fmt::Debug,
460    F: Fn(u64, u64) -> u64,
461{
462    let vl = ext_state.vl();
463    let vstart = ext_state.vstart();
464    let wide_sew = sew
465        .double_width()
466        .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
467
468    // SAFETY: `vl <= VLEN`
469    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
470
471    for i in u32::from(vstart)..vl {
472        if !mask_bit(&mask_buf, i) {
473            continue;
474        }
475        // vs2 is already 2×SEW; read at wide width
476        // SAFETY: `vs2` aligned to `2*group_regs`; element `i` fits within it
477        let wide_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
478        let wide_b = match src {
479            OpSrc::Vreg(vs1) => {
480                // SAFETY: `vs1` is aligned to `group_regs` and fits within `[0, 32)`,
481                // verified by caller; `i < vl <= group_regs * (VLENB / sew.bytes_width())`,
482                // so `vs1_base + i / elems_per_reg < vs1_base + group_regs <= 32`
483                let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1, i, sew) };
484                if ZERO_EXTEND_B {
485                    raw_b
486                } else {
487                    sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
488                }
489            }
490            OpSrc::Scalar(val) => {
491                if ZERO_EXTEND_B {
492                    scalar_unsigned_for_sew(val, sew.bits_width())
493                } else {
494                    scalar_signed_for_sew(val, sew.bits_width())
495                }
496            }
497        };
498        let result = op(wide_a, wide_b);
499        // SAFETY: same as `execute_widen_op` for vd
500        unsafe {
501            write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
502        }
503    }
504    ext_state.mark_vs_dirty();
505    ext_state.reset_vstart();
506}
507
508/// Execute a narrowing right-shift.
509///
510/// `vs2` is 2×SEW wide; the shift amount comes from `src` (SEW-wide or scalar).
511/// The shift amount is masked to `log2(2*SEW)` bits per spec §12.6.
512/// `ARITHMETIC` selects sign-extending (true) vs zero-extending (false) before shifting.
513///
514/// # Safety
515/// - `vd` aligned to `group_regs`, fits in `[0,32)`
516/// - `vs2` aligned to `wide_group_regs`, fits in `[0,32)`; aliasing with the low half of `vs2` is
517///   permitted per spec §11.7 - reads complete before writes to any overlapping element since the
518///   destination SEW is half the source SEW
519/// - `src` register (when `OpSrc::Vreg`) aligned to `group_regs`, fits in `[0,32)`
520/// - `vl <= group_regs * VLENB / sew.bytes_width()`
521/// - SEW < 64
522/// - When `vm=false`: `vd.to_bits() != 0`
523#[inline(always)]
524#[doc(hidden)]
525pub unsafe fn execute_narrow_shift<const ARITHMETIC: bool, Reg, ExtState, CustomError>(
526    ext_state: &mut ExtState,
527    vd: VReg,
528    vs2: VReg,
529    src: OpSrc,
530    vm: bool,
531    sew: Vsew,
532) where
533    Reg: Register,
534    ExtState: VectorRegistersExt<Reg, CustomError>,
535    [(); ExtState::ELEN as usize]:,
536    [(); ExtState::VLEN as usize]:,
537    [(); ExtState::VLENB as usize]:,
538    CustomError: fmt::Debug,
539{
540    let vl = ext_state.vl();
541    let vstart = ext_state.vstart();
542    let wide_sew = sew
543        .double_width()
544        .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
545    // Shift amount mask: log2(2*SEW) bits = log2(SEW) + 1 bits
546    let shamt_mask = u64::from(wide_sew.bits_width() - 1);
547
548    // SAFETY: `vl <= VLEN`
549    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
550
551    for i in u32::from(vstart)..vl {
552        if !mask_bit(&mask_buf, i) {
553            continue;
554        }
555        // SAFETY: `vs2` is the wide source group
556        let wide_val = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
557        let shamt = match src {
558            OpSrc::Vreg(vs1_base) => {
559                // SAFETY: `vs1` is aligned to `group_regs` and fits within `[0, 32)`,
560                // verified by caller; `i < vl <= group_regs * (VLENB / sew.bytes_width())`,
561                // so `vs1_base + i / elems_per_reg < vs1_base + group_regs <= 32`
562                let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
563                raw & shamt_mask
564            }
565            // Scalar shift amount: only the low log2(2*SEW) bits are used per spec
566            OpSrc::Scalar(val) => val & shamt_mask,
567        };
568        let result_wide = if ARITHMETIC {
569            // Sign-extend to i64 first, then shift arithmetically as i64 to
570            // preserve sign bits, then cast back. Shifting u64 after cast_unsigned()
571            // would be a logical shift and lose sign bits.
572            (sign_extend_bits(wide_val, wide_sew.bits_width()) >> shamt).cast_unsigned()
573        } else {
574            wide_val >> shamt
575        };
576        // Truncate to SEW bits
577        let result = result_wide & ((1u64 << sew.bits_width()) - 1);
578        // SAFETY: `vd` is the narrow destination group
579        unsafe {
580            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
581        }
582    }
583    ext_state.mark_vs_dirty();
584    ext_state.reset_vstart();
585}
586
587/// Execute an integer extension (vzext/vsext).
588///
589/// Source element width is `sew.divide_by_factor(factor).bytes_width()`; destination is
590/// `sew.bytes_width()`. `SIGN` selects sign- or zero-extension.
591///
592/// The source EMUL = LMUL / factor; the source register group is `max(1, group_regs / factor)`
593/// registers.
594///
595/// # Safety
596/// - `vd` aligned to `group_regs`, fits in `[0,32)`
597/// - `vs2` aligned to `src_group_regs`, fits in `[0,32)`, does not overlap `vd`
598/// - `vl <= group_regs * VLENB / sew.bytes_width()`
599/// - `sew.divide_by_factor(factor).is_some()`
600/// - When `vm=false`: `vd.to_bits() != 0`
601#[inline(always)]
602#[doc(hidden)]
603pub unsafe fn execute_extension<const SIGN: bool, Reg, ExtState, CustomError>(
604    ext_state: &mut ExtState,
605    vd: VReg,
606    vs2: VReg,
607    vm: bool,
608    sew: Vsew,
609    factor: VsewFactor,
610) where
611    Reg: Register,
612    ExtState: VectorRegistersExt<Reg, CustomError>,
613    [(); ExtState::ELEN as usize]:,
614    [(); ExtState::VLEN as usize]:,
615    [(); ExtState::VLENB as usize]:,
616    CustomError: fmt::Debug,
617{
618    let vl = ext_state.vl();
619    let vstart = ext_state.vstart();
620    let src_sew = sew
621        .divide_by_factor(factor)
622        .expect("SEW >= factor*8 and valid according to function contract; qed");
623
624    // SAFETY: `vl <= VLEN`
625    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
626
627    for i in u32::from(vstart)..vl {
628        if !mask_bit(&mask_buf, i) {
629            continue;
630        }
631        // SAFETY: vs2 group covers `vl` narrow elements
632        let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, src_sew) };
633        let result = if SIGN {
634            sign_extend_bits(raw, src_sew.bits_width()).cast_unsigned()
635        } else {
636            raw
637        };
638        // SAFETY: vd group covers `vl` wide elements
639        unsafe {
640            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
641        }
642    }
643    ext_state.mark_vs_dirty();
644    ext_state.reset_vstart();
645}