1use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10
11#[inline(always)]
14#[doc(hidden)]
15pub fn check_vd_widen_no_src_check<Reg, Memory, PC, CustomError>(
16 program_counter: &PC,
17 vd: VReg,
18 wide_group_regs: u8,
19) -> Result<(), ExecutionError<Reg::Type, CustomError>>
20where
21 Reg: Register,
22 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
23{
24 let vd_idx = vd.to_bits();
25 if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
26 return Err(ExecutionError::IllegalInstruction {
27 address: program_counter.old_pc(INSTRUCTION_SIZE),
28 });
29 }
30 Ok(())
31}
32
33#[inline(always)]
42#[doc(hidden)]
43pub fn check_vs_ext_alignment<Reg, Memory, PC, CustomError>(
44 program_counter: &PC,
45 vs2: VReg,
46 src_group_regs: u8,
47 vd: VReg,
48 group_regs: u8,
49) -> Result<(), ExecutionError<Reg::Type, CustomError>>
50where
51 Reg: Register,
52 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
53{
54 let vs2_idx = vs2.to_bits();
55 if !vs2_idx.is_multiple_of(src_group_regs) || vs2_idx + src_group_regs > 32 {
56 return Err(ExecutionError::IllegalInstruction {
57 address: program_counter.old_pc(INSTRUCTION_SIZE),
58 });
59 }
60 if widen_src_overlap_illegal(vd.to_bits(), group_regs, vs2_idx, src_group_regs) {
63 return Err(ExecutionError::IllegalInstruction {
64 address: program_counter.old_pc(INSTRUCTION_SIZE),
65 });
66 }
67 Ok(())
68}
69
70#[inline(always)]
83#[doc(hidden)]
84pub fn check_vd_widen_alignment<Reg, Memory, PC, CustomError>(
85 program_counter: &PC,
86 vd: VReg,
87 vs_a: VReg,
88 vs_b_opt: Option<VReg>,
89 group_regs: u8,
90 wide_group_regs: u8,
91) -> Result<(), ExecutionError<Reg::Type, CustomError>>
92where
93 Reg: Register,
94 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
95{
96 let vd_idx = vd.to_bits();
97 if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
98 return Err(ExecutionError::IllegalInstruction {
99 address: program_counter.old_pc(INSTRUCTION_SIZE),
100 });
101 }
102 if widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_a.to_bits(), group_regs) {
103 return Err(ExecutionError::IllegalInstruction {
104 address: program_counter.old_pc(INSTRUCTION_SIZE),
105 });
106 }
107 if let Some(vs_b) = vs_b_opt
108 && widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_b.to_bits(), group_regs)
109 {
110 return Err(ExecutionError::IllegalInstruction {
111 address: program_counter.old_pc(INSTRUCTION_SIZE),
112 });
113 }
114 Ok(())
115}
116
117#[inline(always)]
126fn widen_src_overlap_illegal(vd_idx: u8, wide_group_regs: u8, vs_idx: u8, group_regs: u8) -> bool {
127 if !ranges_overlap(vd_idx, wide_group_regs, vs_idx, group_regs) {
128 return false;
129 }
130 let high_part_overlap =
131 wide_group_regs > group_regs && vs_idx == vd_idx + wide_group_regs - group_regs;
132 !high_part_overlap
133}
134
135#[inline(always)]
138#[doc(hidden)]
139pub fn check_vs_wide_alignment<Reg, Memory, PC, CustomError>(
140 program_counter: &PC,
141 vs: VReg,
142 wide_group_regs: u8,
143) -> Result<(), ExecutionError<Reg::Type, CustomError>>
144where
145 Reg: Register,
146 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
147{
148 let vs_idx = vs.to_bits();
149 if !vs_idx.is_multiple_of(wide_group_regs) || vs_idx + wide_group_regs > 32 {
150 return Err(ExecutionError::IllegalInstruction {
151 address: program_counter.old_pc(INSTRUCTION_SIZE),
152 });
153 }
154 Ok(())
155}
156
157#[inline(always)]
163#[doc(hidden)]
164pub fn check_vd_narrow_alignment<Reg, Memory, PC, CustomError>(
165 program_counter: &PC,
166 vd: VReg,
167 group_regs: u8,
168) -> Result<(), ExecutionError<Reg::Type, CustomError>>
169where
170 Reg: Register,
171 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
172{
173 let vd_idx = vd.to_bits();
174 if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
175 return Err(ExecutionError::IllegalInstruction {
176 address: program_counter.old_pc(INSTRUCTION_SIZE),
177 });
178 }
179 Ok(())
180}
181
/// Returns `true` when the half-open register ranges `[a_start, a_start + a_len)` and
/// `[b_start, b_start + b_len)` share at least one register.
///
/// The ranges are disjoint exactly when one ends at or before the point where the other
/// begins. An empty range never overlaps anything.
#[inline(always)]
fn ranges_overlap(a_start: u8, a_len: u8, b_start: u8, b_len: u8) -> bool {
    !(b_start + b_len <= a_start || a_start + a_len <= b_start)
}
187
/// Reads bit `i` of a little-endian bit mask (bit 0 is the LSB of byte 0).
///
/// Bits past the end of `mask` read as `false`.
#[inline(always)]
fn mask_bit(mask: &[u8], i: u32) -> bool {
    let byte_index = (i / u8::BITS) as usize;
    let bit_in_byte = i % u8::BITS;
    match mask.get(byte_index) {
        Some(&byte) => byte & (1 << bit_in_byte) != 0,
        None => false,
    }
}
194
195#[inline(always)]
202unsafe fn snapshot_mask<const VLENB: usize>(
203 vregs: &VectorRegisterFile<VLENB>,
204 vm: bool,
205 vl: u32,
206) -> [u8; VLENB] {
207 let mut buf = [0u8; VLENB];
208 if vm {
209 buf = [0xffu8; VLENB];
210 } else {
211 let mask_bytes = vl.div_ceil(u8::BITS) as usize;
212 unsafe {
214 buf.get_unchecked_mut(..mask_bytes)
215 .copy_from_slice(vregs.get(VReg::V0).get_unchecked(..mask_bytes));
216 }
217 }
218 buf
219}
220
221#[inline(always)]
227unsafe fn read_element_u64<const VLENB: usize>(
228 vregs: &VectorRegisterFile<VLENB>,
229 base_reg: VReg,
230 elem_i: u32,
231 sew: Vsew,
232) -> u64 {
233 let sew_bytes = usize::from(sew.bytes_width());
234 let elems_per_reg = VLENB / sew_bytes;
235 let reg_off = elem_i as usize / elems_per_reg;
236 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
237 let reg = unsafe {
239 vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
240 };
241 let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
243 let mut buf = [0u8; 8];
244 unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
246 u64::from_le_bytes(buf)
247}
248
249#[inline(always)]
254unsafe fn write_element_u64<const VLENB: usize>(
255 vregs: &mut VectorRegisterFile<VLENB>,
256 base_reg: VReg,
257 elem_i: u32,
258 sew: Vsew,
259 value: u64,
260) {
261 let sew_bytes = usize::from(sew.bytes_width());
262 let elems_per_reg = VLENB / sew_bytes;
263 let reg_off = elem_i as usize / elems_per_reg;
264 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
265 let buf = value.to_le_bytes();
266 let reg = unsafe {
268 vregs.get_mut(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
269 };
270 let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
272 dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
274}
275
/// Interprets the low `bits` bits of `val` as a two's-complement integer and sign-extends it
/// to `i64`. Bits of `val` above `bits` are ignored.
///
/// `bits` is expected to be in `1..=64`. With `0`, the shift amount would equal the register
/// width.
#[inline(always)]
#[doc(hidden)]
pub fn sign_extend_bits(val: u64, bits: u8) -> i64 {
    // Push the field's sign bit up to bit 63, then arithmetic-shift it back down so that it is
    // replicated through every higher bit.
    let spare_bits = u64::BITS - u32::from(bits);
    let left_aligned = val.cast_signed() << spare_bits;
    left_aligned >> spare_bits
}
283
/// Truncates `val` to its low `sew_bits` bits, which is equivalent to treating it as an
/// unsigned SEW-bit value zero-extended to 64 bits.
///
/// `sew_bits` is expected to be in `1..=64`; `64` returns `val` unchanged.
#[inline(always)]
fn scalar_unsigned_for_sew(val: u64, sew_bits: u8) -> u64 {
    let discarded = u64::BITS - u32::from(sew_bits);
    let low_mask = u64::MAX >> discarded;
    val & low_mask
}
308
309#[inline(always)]
331fn scalar_signed_for_sew(val: u64, sew_bits: u8) -> u64 {
332 sign_extend_bits(val, sew_bits).cast_unsigned()
333}
334
335#[inline(always)]
353#[expect(clippy::too_many_arguments, reason = "Internal API")]
354#[doc(hidden)]
355pub unsafe fn execute_widen_op<Reg, ExtState, CustomError, F>(
356 ext_state: &mut ExtState,
357 vd: VReg,
358 vs2: VReg,
359 src: OpSrc,
360 vm: bool,
361 sew: Vsew,
362 zero_extend_a: bool,
363 zero_extend_b: bool,
364 op: F,
365) where
366 Reg: Register,
367 ExtState: VectorRegistersExt<Reg, CustomError>,
368 [(); ExtState::ELEN as usize]:,
369 [(); ExtState::VLEN as usize]:,
370 [(); ExtState::VLENB as usize]:,
371 CustomError: fmt::Debug,
372 F: Fn(u64, u64) -> u64,
373{
374 let vl = ext_state.vl();
375 let vstart = ext_state.vstart();
376 let wide_sew = sew
377 .double_width()
378 .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
379
380 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
382
383 for i in u32::from(vstart)..vl {
384 if !mask_bit(&mask_buf, i) {
385 continue;
386 }
387 let raw_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
390 let wide_a = if zero_extend_a {
391 raw_a
392 } else {
393 sign_extend_bits(raw_a, sew.bits_width()).cast_unsigned()
394 };
395 let wide_b = match src {
396 OpSrc::Vreg(vs1_base) => {
397 let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
399 if zero_extend_b {
400 raw_b
401 } else {
402 sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
403 }
404 }
405 OpSrc::Scalar(val) => {
406 if zero_extend_b {
407 scalar_unsigned_for_sew(val, sew.bits_width())
408 } else {
409 scalar_signed_for_sew(val, sew.bits_width())
410 }
411 }
412 };
413 let result = op(wide_a, wide_b);
414 unsafe {
418 write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
419 }
420 }
421 ext_state.mark_vs_dirty();
422 ext_state.reset_vstart();
423}
424
425#[inline(always)]
438#[expect(clippy::too_many_arguments, reason = "Internal API")]
439#[doc(hidden)]
440pub unsafe fn execute_widen_w_op<Reg, ExtState, CustomError, F>(
441 ext_state: &mut ExtState,
442 vd: VReg,
443 vs2: VReg,
444 src: OpSrc,
445 vm: bool,
446 sew: Vsew,
447 zero_extend_b: bool,
448 op: F,
449) where
450 Reg: Register,
451 ExtState: VectorRegistersExt<Reg, CustomError>,
452 [(); ExtState::ELEN as usize]:,
453 [(); ExtState::VLEN as usize]:,
454 [(); ExtState::VLENB as usize]:,
455 CustomError: fmt::Debug,
456 F: Fn(u64, u64) -> u64,
457{
458 let vl = ext_state.vl();
459 let vstart = ext_state.vstart();
460 let wide_sew = sew
461 .double_width()
462 .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
463
464 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
466
467 for i in u32::from(vstart)..vl {
468 if !mask_bit(&mask_buf, i) {
469 continue;
470 }
471 let wide_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
474 let wide_b = match src {
475 OpSrc::Vreg(vs1) => {
476 let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1, i, sew) };
480 if zero_extend_b {
481 raw_b
482 } else {
483 sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
484 }
485 }
486 OpSrc::Scalar(val) => {
487 if zero_extend_b {
488 scalar_unsigned_for_sew(val, sew.bits_width())
489 } else {
490 scalar_signed_for_sew(val, sew.bits_width())
491 }
492 }
493 };
494 let result = op(wide_a, wide_b);
495 unsafe {
497 write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
498 }
499 }
500 ext_state.mark_vs_dirty();
501 ext_state.reset_vstart();
502}
503
504#[inline(always)]
520#[doc(hidden)]
521pub unsafe fn execute_narrow_shift<Reg, ExtState, CustomError>(
522 ext_state: &mut ExtState,
523 vd: VReg,
524 vs2: VReg,
525 src: OpSrc,
526 vm: bool,
527 sew: Vsew,
528 arithmetic: bool,
529) where
530 Reg: Register,
531 ExtState: VectorRegistersExt<Reg, CustomError>,
532 [(); ExtState::ELEN as usize]:,
533 [(); ExtState::VLEN as usize]:,
534 [(); ExtState::VLENB as usize]:,
535 CustomError: fmt::Debug,
536{
537 let vl = ext_state.vl();
538 let vstart = ext_state.vstart();
539 let wide_sew = sew
540 .double_width()
541 .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
542 let shamt_mask = u64::from(wide_sew.bits_width() - 1);
544
545 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
547
548 for i in u32::from(vstart)..vl {
549 if !mask_bit(&mask_buf, i) {
550 continue;
551 }
552 let wide_val = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
554 let shamt = match src {
555 OpSrc::Vreg(vs1_base) => {
556 let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
560 raw & shamt_mask
561 }
562 OpSrc::Scalar(val) => val & shamt_mask,
564 };
565 let result_wide = if arithmetic {
566 (sign_extend_bits(wide_val, wide_sew.bits_width()) >> shamt).cast_unsigned()
570 } else {
571 wide_val >> shamt
572 };
573 let result = result_wide & ((1u64 << sew.bits_width()) - 1);
575 unsafe {
577 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
578 }
579 }
580 ext_state.mark_vs_dirty();
581 ext_state.reset_vstart();
582}
583
584#[inline(always)]
599#[doc(hidden)]
600pub unsafe fn execute_extension<Reg, ExtState, CustomError>(
601 ext_state: &mut ExtState,
602 vd: VReg,
603 vs2: VReg,
604 vm: bool,
605 sew: Vsew,
606 factor: VsewFactor,
607 sign: bool,
608) where
609 Reg: Register,
610 ExtState: VectorRegistersExt<Reg, CustomError>,
611 [(); ExtState::ELEN as usize]:,
612 [(); ExtState::VLEN as usize]:,
613 [(); ExtState::VLENB as usize]:,
614 CustomError: fmt::Debug,
615{
616 let vl = ext_state.vl();
617 let vstart = ext_state.vstart();
618 let src_sew = sew
619 .divide_by_factor(factor)
620 .expect("SEW >= factor*8 and valid according to function contract; qed");
621
622 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
624
625 for i in u32::from(vstart)..vl {
626 if !mask_bit(&mask_buf, i) {
627 continue;
628 }
629 let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, src_sew) };
631 let result = if sign {
632 sign_extend_bits(raw, src_sew.bits_width()).cast_unsigned()
633 } else {
634 raw
635 };
636 unsafe {
638 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
639 }
640 }
641 ext_state.mark_vs_dirty();
642 ext_state.reset_vstart();
643}