1use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5 OpSrc, check_vreg_group_alignment, sew_mask,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8 read_element_u64, sign_extend, write_element_u64,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
11use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
12use crate::{ExecutionError, ProgramCounter};
13use ab_riscv_primitives::prelude::*;
14use core::fmt;
15use core::hint::cold_path;
16
17#[inline(always)]
22fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
23 if shift == 0 {
24 return 0;
25 }
26 let d_minus1_bit = (val >> (shift - 1)) & 1;
28 let sticky = if shift >= 2 {
30 (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
32 } else {
33 false
34 };
35 match mode {
36 Vxrm::Rnu => d_minus1_bit,
38 Vxrm::Rne => d_minus1_bit & u64::from(sticky || current_result_lsb != 0),
40 Vxrm::Rdn => 0,
42 Vxrm::Rod => u64::from(current_result_lsb == 0 && (d_minus1_bit != 0 || sticky)),
44 }
45}
46
47#[inline(always)]
51#[doc(hidden)]
52pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
53 let truncated = val >> shift;
54 let r = round_increment(val, shift, mode, truncated & 1);
55 truncated.wrapping_add(r)
56}
57
58#[inline(always)]
62#[doc(hidden)]
63pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
64 let signed = sign_extend(val, sew);
65 let truncated_signed = signed >> shift;
69 let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
70 truncated_signed.cast_unsigned().wrapping_add(r)
71}
72
73#[inline(always)]
77#[doc(hidden)]
78pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
79 let mask = sew_mask(sew);
80 let a_w = a & mask;
81 let b_w = b & mask;
82 let result = a_w.wrapping_add(b_w);
83 if result & mask < a_w {
84 *vxsat = true;
86 mask
87 } else {
88 result & mask
89 }
90}
91
92#[inline(always)]
96#[doc(hidden)]
97pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
98 let sa = i128::from(sign_extend(a, sew));
99 let sb = i128::from(sign_extend(b, sew));
100 let result = sa.wrapping_add(sb);
101 let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
102 let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
103 if result < min_val {
104 *vxsat = true;
105 (min_val as i64).cast_unsigned() & sew_mask(sew)
106 } else if result > max_val {
107 *vxsat = true;
108 (max_val as i64).cast_unsigned() & sew_mask(sew)
109 } else {
110 (result as i64).cast_unsigned() & sew_mask(sew)
111 }
112}
113
114#[inline(always)]
118#[doc(hidden)]
119pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
120 let mask = sew_mask(sew);
121 let a_w = a & mask;
122 let b_w = b & mask;
123 if a_w < b_w {
124 *vxsat = true;
125 0
126 } else {
127 (a_w - b_w) & mask
128 }
129}
130
131#[inline(always)]
135#[doc(hidden)]
136pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
137 let sa = i128::from(sign_extend(a, sew));
138 let sb = i128::from(sign_extend(b, sew));
139 let result = sa.wrapping_sub(sb);
140 let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
141 let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
142 if result < min_val {
143 *vxsat = true;
144 (min_val as i64).cast_unsigned() & sew_mask(sew)
145 } else if result > max_val {
146 *vxsat = true;
147 (max_val as i64).cast_unsigned() & sew_mask(sew)
148 } else {
149 (result as i64).cast_unsigned() & sew_mask(sew)
150 }
151}
152
153#[inline(always)]
157#[doc(hidden)]
158pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
159 let mask = sew_mask(sew);
160 let a_w = a & mask;
161 let b_w = b & mask;
162 let sum = a_w.wrapping_add(b_w);
166 let carry = u64::from(sum & mask < a_w);
168 let r = round_increment(sum & mask, 1, mode, (sum >> 1u8) & 1);
172 let shifted = (carry << (u32::from(sew.bits_width()) - 1)) | ((sum & mask) >> 1u8);
174 (shifted.wrapping_add(r)) & mask
175}
176
177#[inline(always)]
181#[doc(hidden)]
182pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
183 let sa = sign_extend(a, sew);
184 let sb = sign_extend(b, sew);
185 let sum = i128::from(sa).wrapping_add(i128::from(sb));
187 let r = match mode {
189 Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
190 Vxrm::Rne => {
191 let result_lsb = ((sum >> 1u8) & 1).cast_unsigned() as u64;
194 ((sum & 1).cast_unsigned() as u64) & result_lsb
195 }
196 Vxrm::Rdn => 0,
197 Vxrm::Rod => {
198 let result_lsb = (sum >> 1u8) & 1;
200 u64::from(result_lsb == 0 && (sum & 1) != 0)
201 }
202 };
203 let result = (sum >> 1u8) + i128::from(r);
204 (result as i64).cast_unsigned() & sew_mask(sew)
205}
206
207#[inline(always)]
211#[doc(hidden)]
212pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
213 let mask = sew_mask(sew);
214 let a_w = a & mask;
215 let b_w = b & mask;
216 let diff = a_w.wrapping_sub(b_w);
218 let borrow = u64::from(a_w < b_w);
220 let r = round_increment(diff & mask, 1, mode, (diff >> 1u8) & 1);
224 let sign_fill = borrow.wrapping_neg(); let shifted = (sign_fill << (u32::from(sew.bits_width()) - 1)) | ((diff & mask) >> 1u8);
231 (shifted.wrapping_add(r)) & mask
232}
233
234#[inline(always)]
238#[doc(hidden)]
239pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
240 let sa = sign_extend(a, sew);
241 let sb = sign_extend(b, sew);
242 let diff = i128::from(sa).wrapping_sub(i128::from(sb));
243 let r = match mode {
244 Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
245 Vxrm::Rne => {
246 let result_lsb = ((diff >> 1u8) & 1).cast_unsigned() as u64;
247 ((diff & 1).cast_unsigned() as u64) & result_lsb
248 }
249 Vxrm::Rdn => 0,
250 Vxrm::Rod => {
251 let result_lsb = (diff >> 1u8) & 1;
252 u64::from(result_lsb == 0 && (diff & 1) != 0)
253 }
254 };
255 let result = (diff >> 1u8) + i128::from(r);
256 (result as i64).cast_unsigned() & sew_mask(sew)
257}
258
259#[inline(always)]
267#[doc(hidden)]
268pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
269 let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
271 let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
272 let sa = i128::from(sign_extend(a, sew));
273 let sb = i128::from(sign_extend(b, sew));
274 if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
277 cold_path();
278 *vxsat = true;
279 return max_sew.cast_unsigned() & sew_mask(sew);
280 }
281 let product = sa * sb;
284 let doubled = product << 1u8;
287 let shift = u32::from(sew.bits_width());
290 let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
291 let truncated = doubled >> shift;
293 let r = round_increment(
294 low_bits,
295 shift.min(64),
296 mode,
297 (truncated.cast_unsigned() as u64) & 1,
298 );
299 let result = (truncated as i64).wrapping_add(r.cast_signed());
301 if result < min_sew {
303 *vxsat = true;
304 min_sew.cast_unsigned() & sew_mask(sew)
305 } else if result > max_sew {
306 *vxsat = true;
307 max_sew.cast_unsigned() & sew_mask(sew)
308 } else {
309 result.cast_unsigned() & sew_mask(sew)
310 }
311}
312
313#[inline(always)]
325#[doc(hidden)]
326pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
327 let shifted = rounded_srl(vs2_elem, shamt, mode);
329 let max_dst = sew_mask(sew);
331 if shifted > max_dst {
332 *vxsat = true;
333 max_dst
334 } else {
335 shifted & max_dst
336 }
337}
338
339#[inline(always)]
344#[doc(hidden)]
345pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
346 let double_sew_bits = sew.bits_width() * 2;
350 let shift_amt = i64::BITS - u32::from(double_sew_bits);
351 let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
352 let low_bits = signed_wide.cast_unsigned()
355 & if double_sew_bits == 64 {
356 u64::MAX
357 } else {
358 (1u64 << double_sew_bits) - 1
359 };
360 let truncated = signed_wide >> shamt;
361 let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
362 let rounded = truncated.wrapping_add(r.cast_signed());
363 let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
365 let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
366 if rounded < min_dst {
367 *vxsat = true;
368 min_dst.cast_unsigned() & sew_mask(sew)
369 } else if rounded > max_dst {
370 *vxsat = true;
371 max_dst.cast_unsigned() & sew_mask(sew)
372 } else {
373 rounded.cast_unsigned() & sew_mask(sew)
374 }
375}
376
377#[inline(always)]
389pub unsafe fn read_wide_element_u64<const VLENB: usize>(
390 vregs: &VectorRegisterFile<VLENB>,
391 base_reg: VReg,
392 elem_i: u32,
393 sew: Vsew,
394) -> u64 {
395 let double_sew_bytes = usize::from(sew.bytes_width()) * 2;
396 let elems_per_reg = VLENB / double_sew_bytes;
397 let reg_off = elem_i as usize / elems_per_reg;
398 let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
399 let reg = unsafe {
401 vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
402 };
403 let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
405 let mut buf = [0u8; 8];
406 unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
408 u64::from_le_bytes(buf)
409}
410
411#[inline(always)]
419#[doc(hidden)]
420pub unsafe fn execute_fixed_point_op<Reg, ExtState, CustomError, F>(
421 ext_state: &mut ExtState,
422 vd: VReg,
423 vs2: VReg,
424 src: OpSrc,
425 vm: bool,
426 sew: Vsew,
427 op: F,
428) where
429 Reg: Register,
430 ExtState: VectorRegistersExt<Reg, CustomError>,
431 [(); ExtState::ELEN as usize]:,
432 [(); ExtState::VLEN as usize]:,
433 [(); ExtState::VLENB as usize]:,
434 CustomError: fmt::Debug,
435 F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
437{
438 let vl = ext_state.vl();
439 let vstart = ext_state.vstart();
440 let vxrm = ext_state.vxrm();
441 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
443 let mut any_sat = false;
444 for i in u32::from(vstart)..vl {
445 if !mask_bit(&mask_buf, i) {
446 continue;
447 }
448 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
450 let b = match src {
451 OpSrc::Vreg(vs1_base) => {
452 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
454 }
455 OpSrc::Scalar(val) => val,
456 };
457 let result = op(a, b, sew, vxrm, &mut any_sat);
458 unsafe {
460 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
461 }
462 }
463 if any_sat {
464 ext_state.set_vxsat(true);
466 }
467 ext_state.mark_vs_dirty();
468 ext_state.reset_vstart();
469}
470
471#[inline(always)]
486#[doc(hidden)]
487pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, CustomError, F>(
488 ext_state: &mut ExtState,
489 vd: VReg,
490 vs2: VReg,
491 src: OpSrc,
492 vm: bool,
493 sew: Vsew,
494 op: F,
495) where
496 Reg: Register,
497 ExtState: VectorRegistersExt<Reg, CustomError>,
498 [(); ExtState::ELEN as usize]:,
499 [(); ExtState::VLEN as usize]:,
500 [(); ExtState::VLENB as usize]:,
501 CustomError: fmt::Debug,
502 F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
504{
505 let vl = ext_state.vl();
506 let vstart = ext_state.vstart();
507 let vxrm = ext_state.vxrm();
508 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
510 let mut any_sat = false;
511 let shamt_mask = u64::from(sew.bits_width() * 2 - 1);
513 for i in u32::from(vstart)..vl {
514 if !mask_bit(&mask_buf, i) {
515 continue;
516 }
517 let wide_a = unsafe { read_wide_element_u64(ext_state.read_vregs(), vs2, i, sew) };
520 let shamt = match src {
521 OpSrc::Vreg(vs1_base) => {
522 let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
524 (raw & shamt_mask) as u32
525 }
526 OpSrc::Scalar(val) => (val & shamt_mask) as u32,
527 };
528 let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
529 unsafe {
531 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
532 }
533 }
534 if any_sat {
535 ext_state.set_vxsat(true);
536 }
537 ext_state.mark_vs_dirty();
538 ext_state.reset_vstart();
539}
540
541#[inline(always)]
545#[doc(hidden)]
546pub fn check_narrowing_sew<Reg, Memory, PC, CustomError>(
547 program_counter: &PC,
548 sew: Vsew,
549) -> Result<(), ExecutionError<Reg::Type, CustomError>>
550where
551 Reg: Register,
552 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
553{
554 if sew.bits_width() > 32 {
555 cold_path();
556 return Err(ExecutionError::IllegalInstruction {
557 address: program_counter.old_pc(INSTRUCTION_SIZE),
558 });
559 }
560 Ok(())
561}
562
563#[inline(always)]
573#[doc(hidden)]
574pub fn check_vs2_narrowing_alignment<Reg, Memory, PC, CustomError>(
575 program_counter: &PC,
576 vs2: VReg,
577 vlmul: Vlmul,
578 sew: Vsew,
579) -> Result<(), ExecutionError<Reg::Type, CustomError>>
580where
581 Reg: Register,
582 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
583{
584 let wide_eew = match sew {
587 Vsew::E8 => Eew::E16,
588 Vsew::E16 => Eew::E32,
589 Vsew::E32 => Eew::E64,
590 Vsew::E64 => {
591 cold_path();
592 return Err(ExecutionError::IllegalInstruction {
593 address: program_counter.old_pc(INSTRUCTION_SIZE),
594 });
595 }
596 };
597 let Some(wide_group) = vlmul.data_register_count(wide_eew, sew) else {
599 cold_path();
600 return Err(ExecutionError::IllegalInstruction {
601 address: program_counter.old_pc(INSTRUCTION_SIZE),
602 });
603 };
604 let vs2_idx = vs2.to_bits();
605 if !vs2_idx.is_multiple_of(wide_group) || vs2_idx + wide_group > 32 {
606 cold_path();
607 return Err(ExecutionError::IllegalInstruction {
608 address: program_counter.old_pc(INSTRUCTION_SIZE),
609 });
610 }
611 Ok(())
612}