1use crate::v::vector_registers::VectorRegistersExt;
4pub use crate::v::zve64x::arith::zve64x_arith_helpers::{OpSrc, check_vreg_group_alignment};
5use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::prelude::*;
9use core::fmt;
10
11#[inline(always)]
14#[doc(hidden)]
15pub fn check_vd_widen_no_src_check<Reg, Memory, PC, CustomError>(
16 program_counter: &PC,
17 vd: VReg,
18 wide_group_regs: u8,
19) -> Result<(), ExecutionError<Reg::Type, CustomError>>
20where
21 Reg: Register,
22 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
23{
24 let vd_idx = vd.bits();
25 if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
26 return Err(ExecutionError::IllegalInstruction {
27 address: program_counter.old_pc(INSTRUCTION_SIZE),
28 });
29 }
30 Ok(())
31}
32
33#[inline(always)]
36#[doc(hidden)]
37pub fn check_vs_ext_alignment<Reg, Memory, PC, CustomError>(
38 program_counter: &PC,
39 vs2: VReg,
40 src_group_regs: u8,
41 vd: VReg,
42 group_regs: u8,
43) -> Result<(), ExecutionError<Reg::Type, CustomError>>
44where
45 Reg: Register,
46 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
47{
48 let vs2_idx = vs2.bits();
49 if !vs2_idx.is_multiple_of(src_group_regs) || vs2_idx + src_group_regs > 32 {
50 return Err(ExecutionError::IllegalInstruction {
51 address: program_counter.old_pc(INSTRUCTION_SIZE),
52 });
53 }
54 if ranges_overlap(vd.bits(), group_regs, vs2_idx, src_group_regs) {
56 return Err(ExecutionError::IllegalInstruction {
57 address: program_counter.old_pc(INSTRUCTION_SIZE),
58 });
59 }
60 Ok(())
61}
62
63#[inline(always)]
69#[doc(hidden)]
70pub fn check_vd_widen_alignment<Reg, Memory, PC, CustomError>(
71 program_counter: &PC,
72 vd: VReg,
73 vs_a: VReg,
74 vs_b_opt: Option<VReg>,
75 group_regs: u8,
76 wide_group_regs: u8,
77) -> Result<(), ExecutionError<Reg::Type, CustomError>>
78where
79 Reg: Register,
80 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
81{
82 let vd_idx = vd.bits();
83 if !vd_idx.is_multiple_of(wide_group_regs) || vd_idx + wide_group_regs > 32 {
84 return Err(ExecutionError::IllegalInstruction {
85 address: program_counter.old_pc(INSTRUCTION_SIZE),
86 });
87 }
88 let va_idx = vs_a.bits();
89 if ranges_overlap(vd_idx, wide_group_regs, va_idx, group_regs) {
90 return Err(ExecutionError::IllegalInstruction {
91 address: program_counter.old_pc(INSTRUCTION_SIZE),
92 });
93 }
94 if let Some(vs_b) = vs_b_opt {
95 let vb_idx = vs_b.bits();
96 if ranges_overlap(vd_idx, wide_group_regs, vb_idx, group_regs) {
97 return Err(ExecutionError::IllegalInstruction {
98 address: program_counter.old_pc(INSTRUCTION_SIZE),
99 });
100 }
101 }
102 Ok(())
103}
104
105#[inline(always)]
108#[doc(hidden)]
109pub fn check_vs_wide_alignment<Reg, Memory, PC, CustomError>(
110 program_counter: &PC,
111 vs: VReg,
112 wide_group_regs: u8,
113) -> Result<(), ExecutionError<Reg::Type, CustomError>>
114where
115 Reg: Register,
116 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
117{
118 let vs_idx = vs.bits();
119 if !vs_idx.is_multiple_of(wide_group_regs) || vs_idx + wide_group_regs > 32 {
120 return Err(ExecutionError::IllegalInstruction {
121 address: program_counter.old_pc(INSTRUCTION_SIZE),
122 });
123 }
124 Ok(())
125}
126
127#[inline(always)]
133#[doc(hidden)]
134pub fn check_vd_narrow_alignment<Reg, Memory, PC, CustomError>(
135 program_counter: &PC,
136 vd: VReg,
137 group_regs: u8,
138) -> Result<(), ExecutionError<Reg::Type, CustomError>>
139where
140 Reg: Register,
141 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
142{
143 let vd_idx = vd.bits();
144 if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
145 return Err(ExecutionError::IllegalInstruction {
146 address: program_counter.old_pc(INSTRUCTION_SIZE),
147 });
148 }
149 Ok(())
150}
151
#[inline(always)]
fn ranges_overlap(a_start: u8, a_len: u8, b_start: u8, b_len: u8) -> bool {
    // Two half-open ranges [start, start + len) intersect exactly when
    // neither one ends at or before the other one begins.
    !(b_start + b_len <= a_start || a_start + a_len <= b_start)
}
157
#[inline(always)]
fn mask_bit(mask: &[u8], i: u32) -> bool {
    // LSB-first bit lookup; bits beyond the end of `mask` read as 0.
    let byte = (i / u8::BITS) as usize;
    let bit = i % u8::BITS;
    match mask.get(byte) {
        Some(b) => (b >> bit) & 1 == 1,
        None => false,
    }
}
164
/// Returns a local copy of the active element mask for `vl` elements:
/// all-ones when `vm` is set (instruction is unmasked), otherwise the low
/// `ceil(vl / 8)` bytes of `v0`.
///
/// Snapshotting the mask before the execution loop means later writes to the
/// register file cannot change masking decisions mid-loop.
///
/// # Safety
///
/// `vl.div_ceil(8)` must not exceed `VLENB` — the copy below reads and writes
/// that many bytes unchecked.
#[inline(always)]
unsafe fn snapshot_mask<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    vm: bool,
    vl: u32,
) -> [u8; VLENB] {
    let mut buf = [0u8; VLENB];
    if vm {
        // Unmasked: every element is active.
        buf = [0xffu8; VLENB];
    } else {
        // Only the first ceil(vl / 8) bytes of v0 are relevant for vl elements.
        let mask_bytes = vl.div_ceil(u8::BITS) as usize;
        // SAFETY: caller guarantees mask_bytes <= VLENB (see `# Safety`).
        unsafe {
            buf.get_unchecked_mut(..mask_bytes)
                .copy_from_slice(vreg[usize::from(VReg::V0.bits())].get_unchecked(..mask_bytes));
        }
    }
    buf
}
190
/// Reads element `elem_i` (of `sew_bytes` bytes, little-endian) from the
/// register group starting at `base_reg`, zero-extended into a `u64`.
///
/// # Safety
///
/// Caller must ensure `sew_bytes` is a power of two no larger than 8 that
/// divides `VLENB`, and that the addressed element lies inside the
/// 32-register file — all accesses below are unchecked.
#[inline(always)]
unsafe fn read_element_u64<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    base_reg: usize,
    elem_i: u32,
    sew_bytes: usize,
) -> u64 {
    let per_reg = VLENB / sew_bytes;
    let idx = elem_i as usize;
    let reg_skip = idx / per_reg;
    let start = (idx % per_reg) * sew_bytes;
    // SAFETY: caller guarantees the element is in bounds (see `# Safety`).
    let bytes = unsafe {
        vreg.get_unchecked(base_reg + reg_skip)
            .get_unchecked(start..start + sew_bytes)
    };
    let mut le = [0u8; 8];
    // SAFETY: caller guarantees sew_bytes <= 8.
    unsafe { le.get_unchecked_mut(..sew_bytes) }.copy_from_slice(bytes);
    u64::from_le_bytes(le)
}
215
/// Writes the low `sew_bytes` bytes of `value` (little-endian) into element
/// `elem_i` of the register group starting at `base_reg`.
///
/// # Safety
///
/// Caller must ensure `sew_bytes` is a power of two no larger than 8 that
/// divides `VLENB`, and that the addressed element lies inside the
/// 32-register file — all accesses below are unchecked.
#[inline(always)]
unsafe fn write_element_u64<const VLENB: usize>(
    vreg: &mut [[u8; VLENB]; 32],
    base_reg: u8,
    elem_i: u32,
    sew_bytes: usize,
    value: u64,
) {
    let per_reg = VLENB / sew_bytes;
    let idx = elem_i as usize;
    let start = (idx % per_reg) * sew_bytes;
    let le = value.to_le_bytes();
    // SAFETY: caller guarantees the element is in bounds and sew_bytes <= 8
    // (see `# Safety`).
    unsafe {
        vreg.get_unchecked_mut(usize::from(base_reg) + idx / per_reg)
            .get_unchecked_mut(start..start + sew_bytes)
            .copy_from_slice(le.get_unchecked(..sew_bytes));
    }
}
239
/// Sign-extends the low `bits` bits of `val` to a full `i64`.
///
/// `bits` must be in `1..=64`; `bits == 0` would shift by 64 and panic in
/// debug builds.
#[inline(always)]
#[doc(hidden)]
pub fn sign_extend_bits(val: u64, bits: u32) -> i64 {
    // Shift the `bits`-wide value up to the top of the word, then arithmetic
    // shift back down so the sign bit is replicated across the high bits.
    // (`u64 as i64` is a bit-for-bit reinterpretation, not a truncation.)
    let unused = u64::BITS - bits;
    ((val as i64) << unused) >> unused
}
247
/// Executes a widening element-wise operation: each `SEW`-wide source element
/// is extended to `2*SEW` (sign- or zero-extended per `zero_extend_a` /
/// `zero_extend_b`), combined with `op`, and the `2*SEW` result is written to
/// the group at `vd`. A scalar `OpSrc` value is used as-is — it is presumably
/// already extended by the caller (TODO confirm at call sites).
///
/// Elements below `vstart` or with a clear mask bit are skipped (left
/// undisturbed). Afterwards the dirty flag is set and `vstart` is reset.
///
/// # Safety
///
/// Caller must ensure `vd`, `vs2`, any `OpSrc::Vreg` index, `vl`, `vstart`
/// and `sew` describe element accesses that stay inside the 32-register file
/// and that `2 * sew.bytes() <= 8` — the element reads/writes below are
/// unchecked.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_widen_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    zero_extend_a: bool,
    zero_extend_b: bool,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64) -> u64,
{
    let sew_bytes = usize::from(sew.bytes());
    // Results are written at twice the source element width.
    let wide_sew_bytes = sew_bytes * 2;
    let sew_bits = u32::from(sew.bits());

    // Copy the mask up front so writes inside the loop cannot perturb it.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();

    for i in vstart..vl {
        // Masked-off elements are left untouched in the destination.
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        let raw_a =
            unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew_bytes) };
        // Extend operand A from SEW to the wide width.
        let wide_a = if zero_extend_a {
            raw_a
        } else {
            sign_extend_bits(raw_a, sew_bits).cast_unsigned()
        };
        let wide_b = match &src {
            OpSrc::Vreg(vs1_base) => {
                let raw_b = unsafe {
                    read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew_bytes)
                };
                if zero_extend_b {
                    raw_b
                } else {
                    sign_extend_bits(raw_b, sew_bits).cast_unsigned()
                }
            }
            // Scalar operand is used verbatim (no extension applied here).
            OpSrc::Scalar(val) => *val,
        };
        let result = op(wide_a, wide_b);
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, wide_sew_bytes, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
335
/// Executes a widening operation whose first operand (`vs2`) is already
/// `2*SEW` wide: the wide element is read directly, the second operand is
/// read at `SEW` and extended (sign- or zero- per `zero_extend_b`), and the
/// `2*SEW` result of `op` is written to `vd`. A scalar `OpSrc` value is used
/// as-is — presumably pre-extended by the caller (TODO confirm at call
/// sites).
///
/// Elements below `vstart` or with a clear mask bit are skipped (left
/// undisturbed). Afterwards the dirty flag is set and `vstart` is reset.
///
/// # Safety
///
/// Caller must ensure `vd`, `vs2`, any `OpSrc::Vreg` index, `vl`, `vstart`
/// and `sew` describe element accesses that stay inside the 32-register file
/// and that `2 * sew.bytes() <= 8` — the element reads/writes below are
/// unchecked.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_widen_w_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    zero_extend_b: bool,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64) -> u64,
{
    let sew_bytes = usize::from(sew.bytes());
    // vs2 elements and results are both at twice the narrow width.
    let wide_sew_bytes = sew_bytes * 2;
    let sew_bits = u32::from(sew.bits());

    // Copy the mask up front so writes inside the loop cannot perturb it.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();

    for i in vstart..vl {
        // Masked-off elements are left untouched in the destination.
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // Operand A is read at the wide width — no extension needed.
        let wide_a = unsafe {
            read_element_u64(
                ext_state.read_vreg(),
                usize::from(vs2_base),
                i,
                wide_sew_bytes,
            )
        };
        let wide_b = match &src {
            OpSrc::Vreg(vs1_base) => {
                let raw_b = unsafe {
                    read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew_bytes)
                };
                if zero_extend_b {
                    raw_b
                } else {
                    sign_extend_bits(raw_b, sew_bits).cast_unsigned()
                }
            }
            // Scalar operand is used verbatim (no extension applied here).
            OpSrc::Scalar(val) => *val,
        };
        let result = op(wide_a, wide_b);
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, wide_sew_bytes, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
419
/// Executes a narrowing right shift: each `2*SEW`-wide element of `vs2` is
/// shifted right by a shift amount taken modulo `2*SEW` (from a vector
/// element or scalar), truncated to `SEW` bits, and written to `vd`. When
/// `arithmetic` is set the wide value is sign-extended first so the shift
/// replicates the sign bit; otherwise the shift is logical.
///
/// Elements below `vstart` or with a clear mask bit are skipped (left
/// undisturbed). Afterwards the dirty flag is set and `vstart` is reset.
///
/// # Safety
///
/// Caller must ensure `vd`, `vs2`, any `OpSrc::Vreg` index, `vl`, `vstart`
/// and `sew` describe element accesses that stay inside the 32-register file
/// and that `2 * sew.bytes() <= 8` — the element reads/writes below are
/// unchecked.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_narrow_shift<Reg, ExtState, CustomError>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    arithmetic: bool,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
{
    let sew_bytes = usize::from(sew.bytes());
    // The source is read at twice the destination width.
    let wide_sew_bytes = sew_bytes * 2;
    let wide_sew_bits = u32::from(sew.bits()) * 2;
    // Shift amounts wrap at the wide element width (low log2(2*SEW) bits).
    let shamt_mask = u64::from(wide_sew_bits - 1);

    // Copy the mask up front so writes inside the loop cannot perturb it.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();

    for i in vstart..vl {
        // Masked-off elements are left untouched in the destination.
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        let wide_val = unsafe {
            read_element_u64(
                ext_state.read_vreg(),
                usize::from(vs2_base),
                i,
                wide_sew_bytes,
            )
        };
        let shamt = match &src {
            OpSrc::Vreg(vs1_base) => {
                let raw = unsafe {
                    read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew_bytes)
                };
                raw & shamt_mask
            }
            OpSrc::Scalar(val) => val & shamt_mask,
        };
        let result_wide = if arithmetic {
            // Arithmetic: sign-extend the wide value so >> replicates the sign.
            (sign_extend_bits(wide_val, wide_sew_bits) >> shamt).cast_unsigned()
        } else {
            // Logical: wide_val is already zero-extended in the u64.
            wide_val >> shamt
        };
        // Truncate to SEW bits before the narrow write (write_element_u64
        // would drop the high bytes anyway; the mask makes it explicit).
        let result = result_wide & ((1u64 << sew.bits()) - 1);
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew_bytes, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
512
/// Executes an integer extension: each source element of width `SEW / factor`
/// is read from `vs2` and widened to `SEW` — sign-extended when `sign` is
/// set, zero-extended otherwise — then written to `vd`.
///
/// `factor` is the widening ratio (destination SEW over source EEW) and must
/// divide `sew` evenly; `factor == 0` would divide by zero.
///
/// Elements below `vstart` or with a clear mask bit are skipped (left
/// undisturbed). Afterwards the dirty flag is set and `vstart` is reset.
///
/// # Safety
///
/// Caller must ensure `vd`, `vs2`, `vl`, `vstart`, `sew` and `factor`
/// describe element accesses that stay inside the 32-register file and that
/// `sew.bytes() <= 8` — the element reads/writes below are unchecked.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_extension<Reg, ExtState, CustomError>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    factor: u8,
    sign: bool,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
{
    let sew_bytes = usize::from(sew.bytes());
    // Source elements are narrower by `factor`.
    let src_sew_bytes = sew_bytes / usize::from(factor);
    let src_sew_bits = (u32::from(sew.bits())) / u32::from(factor);

    // Copy the mask up front so writes inside the loop cannot perturb it.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();

    for i in vstart..vl {
        // Masked-off elements are left untouched in the destination.
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        let raw = unsafe {
            read_element_u64(
                ext_state.read_vreg(),
                usize::from(vs2_base),
                i,
                src_sew_bytes,
            )
        };
        let result = if sign {
            sign_extend_bits(raw, src_sew_bits).cast_unsigned()
        } else {
            // read_element_u64 already zero-extends into the u64.
            raw
        };
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew_bytes, result);
        }
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}