1use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10use core::hint::cold_path;
11
12#[inline(always)]
15#[doc(hidden)]
16pub fn check_vd_widen_no_src_check<Reg, Memory, PC, CustomError>(
17 program_counter: &PC,
18 vd: VReg,
19 wide_group_regs: u8,
20) -> Result<(), ExecutionError<Reg::Type, CustomError>>
21where
22 Reg: Register,
23 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
24{
25 let vd_idx = vd.to_bits();
26 if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
27 cold_path();
28 return Err(ExecutionError::IllegalInstruction {
29 address: program_counter.old_pc(INSTRUCTION_SIZE),
30 });
31 }
32 Ok(())
33}
34
35#[inline(always)]
44#[doc(hidden)]
45pub fn check_vs_ext_alignment<Reg, Memory, PC, CustomError>(
46 program_counter: &PC,
47 vs2: VReg,
48 src_group_regs: u8,
49 vd: VReg,
50 group_regs: u8,
51) -> Result<(), ExecutionError<Reg::Type, CustomError>>
52where
53 Reg: Register,
54 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
55{
56 let vs2_idx = vs2.to_bits();
57 if !vs2_idx.is_multiple_of(src_group_regs) || vs2_idx + src_group_regs > 32 {
58 cold_path();
59 return Err(ExecutionError::IllegalInstruction {
60 address: program_counter.old_pc(INSTRUCTION_SIZE),
61 });
62 }
63 if widen_src_overlap_illegal(vd.to_bits(), group_regs, vs2_idx, src_group_regs) {
66 cold_path();
67 return Err(ExecutionError::IllegalInstruction {
68 address: program_counter.old_pc(INSTRUCTION_SIZE),
69 });
70 }
71 Ok(())
72}
73
74#[inline(always)]
87#[doc(hidden)]
88pub fn check_vd_widen_alignment<Reg, Memory, PC, CustomError>(
89 program_counter: &PC,
90 vd: VReg,
91 vs_a: VReg,
92 vs_b_opt: Option<VReg>,
93 group_regs: u8,
94 wide_group_regs: u8,
95) -> Result<(), ExecutionError<Reg::Type, CustomError>>
96where
97 Reg: Register,
98 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
99{
100 let vd_idx = vd.to_bits();
101 if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
102 cold_path();
103 return Err(ExecutionError::IllegalInstruction {
104 address: program_counter.old_pc(INSTRUCTION_SIZE),
105 });
106 }
107 if widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_a.to_bits(), group_regs) {
108 cold_path();
109 return Err(ExecutionError::IllegalInstruction {
110 address: program_counter.old_pc(INSTRUCTION_SIZE),
111 });
112 }
113 if let Some(vs_b) = vs_b_opt
114 && widen_src_overlap_illegal(vd_idx, wide_group_regs, vs_b.to_bits(), group_regs)
115 {
116 cold_path();
117 return Err(ExecutionError::IllegalInstruction {
118 address: program_counter.old_pc(INSTRUCTION_SIZE),
119 });
120 }
121 Ok(())
122}
123
124#[inline(always)]
133fn widen_src_overlap_illegal(vd_idx: u8, wide_group_regs: u8, vs_idx: u8, group_regs: u8) -> bool {
134 if !ranges_overlap(vd_idx, wide_group_regs, vs_idx, group_regs) {
135 return false;
136 }
137 let high_part_overlap =
138 wide_group_regs > group_regs && vs_idx == vd_idx + wide_group_regs - group_regs;
139 !high_part_overlap
140}
141
142#[inline(always)]
145#[doc(hidden)]
146pub fn check_vs_wide_alignment<Reg, Memory, PC, CustomError>(
147 program_counter: &PC,
148 vs: VReg,
149 wide_group_regs: u8,
150) -> Result<(), ExecutionError<Reg::Type, CustomError>>
151where
152 Reg: Register,
153 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
154{
155 let vs_idx = vs.to_bits();
156 if !vs_idx.is_multiple_of(wide_group_regs) || vs_idx + wide_group_regs > 32 {
157 cold_path();
158 return Err(ExecutionError::IllegalInstruction {
159 address: program_counter.old_pc(INSTRUCTION_SIZE),
160 });
161 }
162 Ok(())
163}
164
165#[inline(always)]
171#[doc(hidden)]
172pub fn check_vd_narrow_alignment<Reg, Memory, PC, CustomError>(
173 program_counter: &PC,
174 vd: VReg,
175 group_regs: u8,
176) -> Result<(), ExecutionError<Reg::Type, CustomError>>
177where
178 Reg: Register,
179 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
180{
181 let vd_idx = vd.to_bits();
182 if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
183 cold_path();
184 return Err(ExecutionError::IllegalInstruction {
185 address: program_counter.old_pc(INSTRUCTION_SIZE),
186 });
187 }
188 Ok(())
189}
190
/// Returns whether the half-open register ranges `[a_start, a_start + a_len)` and
/// `[b_start, b_start + b_len)` share at least one register.
#[inline(always)]
fn ranges_overlap(a_start: u8, a_len: u8, b_start: u8, b_len: u8) -> bool {
    // Two ranges are disjoint exactly when one starts at or after the other ends.
    let disjoint = a_start >= b_start + b_len || b_start >= a_start + a_len;
    !disjoint
}
196
/// Reads bit `i` of a packed mask.
///
/// The bit lives in byte `i / 8`, at position `i % 8` counted from the least-significant bit.
/// Indices past the end of `mask` read as `false` instead of panicking.
#[inline(always)]
fn mask_bit(mask: &[u8], i: u32) -> bool {
    let byte_index = (i / u8::BITS) as usize;
    let bit_index = i % u8::BITS;
    match mask.get(byte_index) {
        Some(&byte) => byte & (1 << bit_index) != 0,
        None => false,
    }
}
203
204#[inline(always)]
211unsafe fn snapshot_mask<const VLENB: usize>(
212 vregs: &VectorRegisterFile<VLENB>,
213 vm: bool,
214 vl: u32,
215) -> [u8; VLENB] {
216 let mut buf = [0u8; VLENB];
217 if vm {
218 buf = [0xffu8; VLENB];
219 } else {
220 let mask_bytes = vl.div_ceil(u8::BITS) as usize;
221 unsafe {
223 buf.get_unchecked_mut(..mask_bytes)
224 .copy_from_slice(vregs.get(VReg::V0).get_unchecked(..mask_bytes));
225 }
226 }
227 buf
228}
229
230#[inline(always)]
236unsafe fn read_element_u64<const VLENB: usize>(
237 vregs: &VectorRegisterFile<VLENB>,
238 base_reg: VReg,
239 elem_i: u32,
240 sew: Vsew,
241) -> u64 {
242 let sew_bytes = usize::from(sew.bytes_width());
243 let elems_per_reg = VLENB / sew_bytes;
244 let reg_off = elem_i as usize / elems_per_reg;
245 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
246 let reg = unsafe {
248 vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
249 };
250 let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
252 let mut buf = [0u8; 8];
253 unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
255 u64::from_le_bytes(buf)
256}
257
258#[inline(always)]
263unsafe fn write_element_u64<const VLENB: usize>(
264 vregs: &mut VectorRegisterFile<VLENB>,
265 base_reg: VReg,
266 elem_i: u32,
267 sew: Vsew,
268 value: u64,
269) {
270 let sew_bytes = usize::from(sew.bytes_width());
271 let elems_per_reg = VLENB / sew_bytes;
272 let reg_off = elem_i as usize / elems_per_reg;
273 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
274 let buf = value.to_le_bytes();
275 let reg = unsafe {
277 vregs.get_mut(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
278 };
279 let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
281 dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
283}
284
/// Sign-extends the low `bits` bits of `val` to a full `i64`.
///
/// Bits of `val` above position `bits` are ignored. `bits` is expected to be in `1..=64`.
/// Outside that range, either the subtraction or the shift amount overflows.
#[inline(always)]
#[doc(hidden)]
pub fn sign_extend_bits(val: u64, bits: u8) -> i64 {
    let unused_high_bits = u64::BITS - u32::from(bits);
    // Move the field's sign bit into bit 63, then shift back arithmetically to replicate it.
    let left_aligned = (val << unused_high_bits).cast_signed();
    left_aligned >> unused_high_bits
}
292
/// Reduces a scalar operand to its low `sew_bits` bits, zero-extended to 64 bits.
///
/// `sew_bits` is expected to be in `1..=64`.
#[inline(always)]
fn scalar_unsigned_for_sew(val: u64, sew_bits: u8) -> u64 {
    let discarded_bits = u64::BITS - u32::from(sew_bits);
    let keep_mask = u64::MAX >> discarded_bits;
    val & keep_mask
}
317
318#[inline(always)]
340fn scalar_signed_for_sew(val: u64, sew_bits: u8) -> u64 {
341 sign_extend_bits(val, sew_bits).cast_unsigned()
342}
343
344#[inline(always)]
362#[doc(hidden)]
363pub unsafe fn execute_widen_op<const ZERO_EXTEND_AB: bool, Reg, ExtState, CustomError, F>(
364 ext_state: &mut ExtState,
365 vd: VReg,
366 vs2: VReg,
367 src: OpSrc,
368 vm: bool,
369 sew: Vsew,
370 op: F,
371) where
372 Reg: Register,
373 ExtState: VectorRegistersExt<Reg, CustomError>,
374 [(); ExtState::ELEN as usize]:,
375 [(); ExtState::VLEN as usize]:,
376 [(); ExtState::VLENB as usize]:,
377 CustomError: fmt::Debug,
378 F: Fn(u64, u64) -> u64,
379{
380 let vl = ext_state.vl();
381 let vstart = ext_state.vstart();
382 let wide_sew = sew
383 .double_width()
384 .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
385
386 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
388
389 for i in u32::from(vstart)..vl {
390 if !mask_bit(&mask_buf, i) {
391 continue;
392 }
393 let raw_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
396 let wide_a = if ZERO_EXTEND_AB {
397 raw_a
398 } else {
399 sign_extend_bits(raw_a, sew.bits_width()).cast_unsigned()
400 };
401 let wide_b = match src {
402 OpSrc::Vreg(vs1_base) => {
403 let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
405 if ZERO_EXTEND_AB {
406 raw_b
407 } else {
408 sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
409 }
410 }
411 OpSrc::Scalar(val) => {
412 if ZERO_EXTEND_AB {
413 scalar_unsigned_for_sew(val, sew.bits_width())
414 } else {
415 scalar_signed_for_sew(val, sew.bits_width())
416 }
417 }
418 };
419 let result = op(wide_a, wide_b);
420 unsafe {
424 write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
425 }
426 }
427 ext_state.mark_vs_dirty();
428 ext_state.reset_vstart();
429}
430
431#[inline(always)]
444#[doc(hidden)]
445pub unsafe fn execute_widen_w_op<const ZERO_EXTEND_B: bool, Reg, ExtState, CustomError, F>(
446 ext_state: &mut ExtState,
447 vd: VReg,
448 vs2: VReg,
449 src: OpSrc,
450 vm: bool,
451 sew: Vsew,
452 op: F,
453) where
454 Reg: Register,
455 ExtState: VectorRegistersExt<Reg, CustomError>,
456 [(); ExtState::ELEN as usize]:,
457 [(); ExtState::VLEN as usize]:,
458 [(); ExtState::VLENB as usize]:,
459 CustomError: fmt::Debug,
460 F: Fn(u64, u64) -> u64,
461{
462 let vl = ext_state.vl();
463 let vstart = ext_state.vstart();
464 let wide_sew = sew
465 .double_width()
466 .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
467
468 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
470
471 for i in u32::from(vstart)..vl {
472 if !mask_bit(&mask_buf, i) {
473 continue;
474 }
475 let wide_a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
478 let wide_b = match src {
479 OpSrc::Vreg(vs1) => {
480 let raw_b = unsafe { read_element_u64(ext_state.read_vregs(), vs1, i, sew) };
484 if ZERO_EXTEND_B {
485 raw_b
486 } else {
487 sign_extend_bits(raw_b, sew.bits_width()).cast_unsigned()
488 }
489 }
490 OpSrc::Scalar(val) => {
491 if ZERO_EXTEND_B {
492 scalar_unsigned_for_sew(val, sew.bits_width())
493 } else {
494 scalar_signed_for_sew(val, sew.bits_width())
495 }
496 }
497 };
498 let result = op(wide_a, wide_b);
499 unsafe {
501 write_element_u64(ext_state.write_vregs(), vd, i, wide_sew, result);
502 }
503 }
504 ext_state.mark_vs_dirty();
505 ext_state.reset_vstart();
506}
507
508#[inline(always)]
524#[doc(hidden)]
525pub unsafe fn execute_narrow_shift<const ARITHMETIC: bool, Reg, ExtState, CustomError>(
526 ext_state: &mut ExtState,
527 vd: VReg,
528 vs2: VReg,
529 src: OpSrc,
530 vm: bool,
531 sew: Vsew,
532) where
533 Reg: Register,
534 ExtState: VectorRegistersExt<Reg, CustomError>,
535 [(); ExtState::ELEN as usize]:,
536 [(); ExtState::VLEN as usize]:,
537 [(); ExtState::VLENB as usize]:,
538 CustomError: fmt::Debug,
539{
540 let vl = ext_state.vl();
541 let vstart = ext_state.vstart();
542 let wide_sew = sew
543 .double_width()
544 .expect("SEW < 64 is enforced by caller, hence this is always valid; qed");
545 let shamt_mask = u64::from(wide_sew.bits_width() - 1);
547
548 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
550
551 for i in u32::from(vstart)..vl {
552 if !mask_bit(&mask_buf, i) {
553 continue;
554 }
555 let wide_val = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, wide_sew) };
557 let shamt = match src {
558 OpSrc::Vreg(vs1_base) => {
559 let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
563 raw & shamt_mask
564 }
565 OpSrc::Scalar(val) => val & shamt_mask,
567 };
568 let result_wide = if ARITHMETIC {
569 (sign_extend_bits(wide_val, wide_sew.bits_width()) >> shamt).cast_unsigned()
573 } else {
574 wide_val >> shamt
575 };
576 let result = result_wide & ((1u64 << sew.bits_width()) - 1);
578 unsafe {
580 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
581 }
582 }
583 ext_state.mark_vs_dirty();
584 ext_state.reset_vstart();
585}
586
587#[inline(always)]
602#[doc(hidden)]
603pub unsafe fn execute_extension<const SIGN: bool, Reg, ExtState, CustomError>(
604 ext_state: &mut ExtState,
605 vd: VReg,
606 vs2: VReg,
607 vm: bool,
608 sew: Vsew,
609 factor: VsewFactor,
610) where
611 Reg: Register,
612 ExtState: VectorRegistersExt<Reg, CustomError>,
613 [(); ExtState::ELEN as usize]:,
614 [(); ExtState::VLEN as usize]:,
615 [(); ExtState::VLENB as usize]:,
616 CustomError: fmt::Debug,
617{
618 let vl = ext_state.vl();
619 let vstart = ext_state.vstart();
620 let src_sew = sew
621 .divide_by_factor(factor)
622 .expect("SEW >= factor*8 and valid according to function contract; qed");
623
624 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
626
627 for i in u32::from(vstart)..vl {
628 if !mask_bit(&mask_buf, i) {
629 continue;
630 }
631 let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, src_sew) };
633 let result = if SIGN {
634 sign_extend_bits(raw, src_sew.bits_width()).cast_unsigned()
635 } else {
636 raw
637 };
638 unsafe {
640 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
641 }
642 }
643 ext_state.mark_vs_dirty();
644 ext_state.reset_vstart();
645}