1use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5 OpSrc, check_vreg_group_alignment, sew_mask,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8 read_element_u64, sign_extend, write_element_u64,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
11use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
12use crate::{ExecutionError, ProgramCounter};
13use ab_riscv_primitives::prelude::*;
14use core::fmt;
15
16#[inline(always)]
21fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
22 if shift == 0 {
23 return 0;
24 }
25 let d_minus1_bit = (val >> (shift - 1)) & 1;
27 let sticky = if shift >= 2 {
29 (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
31 } else {
32 false
33 };
34 match mode {
35 Vxrm::Rnu => d_minus1_bit,
37 Vxrm::Rne => d_minus1_bit & u64::from(sticky || current_result_lsb != 0),
39 Vxrm::Rdn => 0,
41 Vxrm::Rod => u64::from(current_result_lsb == 0 && (d_minus1_bit != 0 || sticky)),
43 }
44}
45
46#[inline(always)]
50#[doc(hidden)]
51pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
52 let truncated = val >> shift;
53 let r = round_increment(val, shift, mode, truncated & 1);
54 truncated.wrapping_add(r)
55}
56
57#[inline(always)]
61#[doc(hidden)]
62pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
63 let signed = sign_extend(val, sew);
64 let truncated_signed = signed >> shift;
68 let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
69 truncated_signed.cast_unsigned().wrapping_add(r)
70}
71
72#[inline(always)]
76#[doc(hidden)]
77pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
78 let mask = sew_mask(sew);
79 let a_w = a & mask;
80 let b_w = b & mask;
81 let result = a_w.wrapping_add(b_w);
82 if result & mask < a_w {
83 *vxsat = true;
85 mask
86 } else {
87 result & mask
88 }
89}
90
91#[inline(always)]
95#[doc(hidden)]
96pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
97 let sa = i128::from(sign_extend(a, sew));
98 let sb = i128::from(sign_extend(b, sew));
99 let result = sa.wrapping_add(sb);
100 let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
101 let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
102 if result < min_val {
103 *vxsat = true;
104 (min_val as i64).cast_unsigned() & sew_mask(sew)
105 } else if result > max_val {
106 *vxsat = true;
107 (max_val as i64).cast_unsigned() & sew_mask(sew)
108 } else {
109 (result as i64).cast_unsigned() & sew_mask(sew)
110 }
111}
112
113#[inline(always)]
117#[doc(hidden)]
118pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
119 let mask = sew_mask(sew);
120 let a_w = a & mask;
121 let b_w = b & mask;
122 if a_w < b_w {
123 *vxsat = true;
124 0
125 } else {
126 (a_w - b_w) & mask
127 }
128}
129
130#[inline(always)]
134#[doc(hidden)]
135pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
136 let sa = i128::from(sign_extend(a, sew));
137 let sb = i128::from(sign_extend(b, sew));
138 let result = sa.wrapping_sub(sb);
139 let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits_width()));
140 let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits_width()));
141 if result < min_val {
142 *vxsat = true;
143 (min_val as i64).cast_unsigned() & sew_mask(sew)
144 } else if result > max_val {
145 *vxsat = true;
146 (max_val as i64).cast_unsigned() & sew_mask(sew)
147 } else {
148 (result as i64).cast_unsigned() & sew_mask(sew)
149 }
150}
151
152#[inline(always)]
156#[doc(hidden)]
157pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
158 let mask = sew_mask(sew);
159 let a_w = a & mask;
160 let b_w = b & mask;
161 let sum = a_w.wrapping_add(b_w);
165 let carry = u64::from(sum & mask < a_w);
167 let r = round_increment(sum & mask, 1, mode, (sum >> 1u8) & 1);
171 let shifted = (carry << (u32::from(sew.bits_width()) - 1)) | ((sum & mask) >> 1u8);
173 (shifted.wrapping_add(r)) & mask
174}
175
176#[inline(always)]
180#[doc(hidden)]
181pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
182 let sa = sign_extend(a, sew);
183 let sb = sign_extend(b, sew);
184 let sum = i128::from(sa).wrapping_add(i128::from(sb));
186 let r = match mode {
188 Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
189 Vxrm::Rne => {
190 let result_lsb = ((sum >> 1u8) & 1).cast_unsigned() as u64;
193 ((sum & 1).cast_unsigned() as u64) & result_lsb
194 }
195 Vxrm::Rdn => 0,
196 Vxrm::Rod => {
197 let result_lsb = (sum >> 1u8) & 1;
199 u64::from(result_lsb == 0 && (sum & 1) != 0)
200 }
201 };
202 let result = (sum >> 1u8) + i128::from(r);
203 (result as i64).cast_unsigned() & sew_mask(sew)
204}
205
206#[inline(always)]
210#[doc(hidden)]
211pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
212 let mask = sew_mask(sew);
213 let a_w = a & mask;
214 let b_w = b & mask;
215 let diff = a_w.wrapping_sub(b_w);
217 let borrow = u64::from(a_w < b_w);
219 let r = round_increment(diff & mask, 1, mode, (diff >> 1u8) & 1);
223 let sign_fill = borrow.wrapping_neg(); let shifted = (sign_fill << (u32::from(sew.bits_width()) - 1)) | ((diff & mask) >> 1u8);
230 (shifted.wrapping_add(r)) & mask
231}
232
233#[inline(always)]
237#[doc(hidden)]
238pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
239 let sa = sign_extend(a, sew);
240 let sb = sign_extend(b, sew);
241 let diff = i128::from(sa).wrapping_sub(i128::from(sb));
242 let r = match mode {
243 Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
244 Vxrm::Rne => {
245 let result_lsb = ((diff >> 1u8) & 1).cast_unsigned() as u64;
246 ((diff & 1).cast_unsigned() as u64) & result_lsb
247 }
248 Vxrm::Rdn => 0,
249 Vxrm::Rod => {
250 let result_lsb = (diff >> 1u8) & 1;
251 u64::from(result_lsb == 0 && (diff & 1) != 0)
252 }
253 };
254 let result = (diff >> 1u8) + i128::from(r);
255 (result as i64).cast_unsigned() & sew_mask(sew)
256}
257
258#[inline(always)]
266#[doc(hidden)]
267pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
268 let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
270 let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
271 let sa = i128::from(sign_extend(a, sew));
272 let sb = i128::from(sign_extend(b, sew));
273 if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
276 *vxsat = true;
277 return max_sew.cast_unsigned() & sew_mask(sew);
278 }
279 let product = sa * sb;
282 let doubled = product << 1u8;
285 let shift = u32::from(sew.bits_width());
288 let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
289 let truncated = doubled >> shift;
291 let r = round_increment(
292 low_bits,
293 shift.min(64),
294 mode,
295 (truncated.cast_unsigned() as u64) & 1,
296 );
297 let result = (truncated as i64).wrapping_add(r.cast_signed());
299 if result < min_sew {
301 *vxsat = true;
302 min_sew.cast_unsigned() & sew_mask(sew)
303 } else if result > max_sew {
304 *vxsat = true;
305 max_sew.cast_unsigned() & sew_mask(sew)
306 } else {
307 result.cast_unsigned() & sew_mask(sew)
308 }
309}
310
311#[inline(always)]
323#[doc(hidden)]
324pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
325 let shifted = rounded_srl(vs2_elem, shamt, mode);
327 let max_dst = sew_mask(sew);
329 if shifted > max_dst {
330 *vxsat = true;
331 max_dst
332 } else {
333 shifted & max_dst
334 }
335}
336
337#[inline(always)]
342#[doc(hidden)]
343pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
344 let double_sew_bits = sew.bits_width() * 2;
347 let shift_amt = i64::BITS - u32::from(double_sew_bits);
348 let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
349 let low_bits = signed_wide.cast_unsigned()
352 & if double_sew_bits == 64 {
353 u64::MAX
354 } else {
355 (1u64 << double_sew_bits) - 1
356 };
357 let truncated = signed_wide >> shamt;
358 let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
359 let rounded = truncated.wrapping_add(r.cast_signed());
360 let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits_width()));
362 let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits_width()));
363 if rounded < min_dst {
364 *vxsat = true;
365 min_dst.cast_unsigned() & sew_mask(sew)
366 } else if rounded > max_dst {
367 *vxsat = true;
368 max_dst.cast_unsigned() & sew_mask(sew)
369 } else {
370 rounded.cast_unsigned() & sew_mask(sew)
371 }
372}
373
374#[inline(always)]
386pub unsafe fn read_wide_element_u64<const VLENB: usize>(
387 vregs: &VectorRegisterFile<VLENB>,
388 base_reg: VReg,
389 elem_i: u32,
390 sew: Vsew,
391) -> u64 {
392 let double_sew_bytes = usize::from(sew.bytes_width()) * 2;
393 let elems_per_reg = VLENB / double_sew_bytes;
394 let reg_off = elem_i as usize / elems_per_reg;
395 let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
396 let reg = unsafe {
398 vregs.get(VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked())
399 };
400 let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
402 let mut buf = [0u8; 8];
403 unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
405 u64::from_le_bytes(buf)
406}
407
408#[inline(always)]
416#[doc(hidden)]
417pub unsafe fn execute_fixed_point_op<Reg, ExtState, CustomError, F>(
418 ext_state: &mut ExtState,
419 vd: VReg,
420 vs2: VReg,
421 src: OpSrc,
422 vm: bool,
423 sew: Vsew,
424 op: F,
425) where
426 Reg: Register,
427 ExtState: VectorRegistersExt<Reg, CustomError>,
428 [(); ExtState::ELEN as usize]:,
429 [(); ExtState::VLEN as usize]:,
430 [(); ExtState::VLENB as usize]:,
431 CustomError: fmt::Debug,
432 F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
434{
435 let vl = ext_state.vl();
436 let vstart = ext_state.vstart();
437 let vxrm = ext_state.vxrm();
438 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
440 let mut any_sat = false;
441 for i in u32::from(vstart)..vl {
442 if !mask_bit(&mask_buf, i) {
443 continue;
444 }
445 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
447 let b = match src {
448 OpSrc::Vreg(vs1_base) => {
449 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
451 }
452 OpSrc::Scalar(val) => val,
453 };
454 let result = op(a, b, sew, vxrm, &mut any_sat);
455 unsafe {
457 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
458 }
459 }
460 if any_sat {
461 ext_state.set_vxsat(true);
463 }
464 ext_state.mark_vs_dirty();
465 ext_state.reset_vstart();
466}
467
468#[inline(always)]
483#[doc(hidden)]
484pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, CustomError, F>(
485 ext_state: &mut ExtState,
486 vd: VReg,
487 vs2: VReg,
488 src: OpSrc,
489 vm: bool,
490 sew: Vsew,
491 op: F,
492) where
493 Reg: Register,
494 ExtState: VectorRegistersExt<Reg, CustomError>,
495 [(); ExtState::ELEN as usize]:,
496 [(); ExtState::VLEN as usize]:,
497 [(); ExtState::VLENB as usize]:,
498 CustomError: fmt::Debug,
499 F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
501{
502 let vl = ext_state.vl();
503 let vstart = ext_state.vstart();
504 let vxrm = ext_state.vxrm();
505 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
507 let mut any_sat = false;
508 let shamt_mask = u64::from(sew.bits_width() * 2 - 1);
510 for i in u32::from(vstart)..vl {
511 if !mask_bit(&mask_buf, i) {
512 continue;
513 }
514 let wide_a = unsafe { read_wide_element_u64(ext_state.read_vregs(), vs2, i, sew) };
517 let shamt = match src {
518 OpSrc::Vreg(vs1_base) => {
519 let raw = unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) };
521 (raw & shamt_mask) as u32
522 }
523 OpSrc::Scalar(val) => (val & shamt_mask) as u32,
524 };
525 let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
526 unsafe {
528 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
529 }
530 }
531 if any_sat {
532 ext_state.set_vxsat(true);
533 }
534 ext_state.mark_vs_dirty();
535 ext_state.reset_vstart();
536}
537
538#[inline(always)]
542#[doc(hidden)]
543pub fn check_narrowing_sew<Reg, Memory, PC, CustomError>(
544 program_counter: &PC,
545 sew: Vsew,
546) -> Result<(), ExecutionError<Reg::Type, CustomError>>
547where
548 Reg: Register,
549 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
550{
551 if sew.bits_width() > 32 {
552 return Err(ExecutionError::IllegalInstruction {
553 address: program_counter.old_pc(INSTRUCTION_SIZE),
554 });
555 }
556 Ok(())
557}
558
559#[inline(always)]
569#[doc(hidden)]
570pub fn check_vs2_narrowing_alignment<Reg, Memory, PC, CustomError>(
571 program_counter: &PC,
572 vs2: VReg,
573 vlmul: Vlmul,
574 sew: Vsew,
575) -> Result<(), ExecutionError<Reg::Type, CustomError>>
576where
577 Reg: Register,
578 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
579{
580 let wide_eew = match sew {
583 Vsew::E8 => Eew::E16,
584 Vsew::E16 => Eew::E32,
585 Vsew::E32 => Eew::E64,
586 Vsew::E64 => {
587 return Err(ExecutionError::IllegalInstruction {
588 address: program_counter.old_pc(INSTRUCTION_SIZE),
589 });
590 }
591 };
592 let wide_group =
594 vlmul
595 .data_register_count(wide_eew, sew)
596 .ok_or(ExecutionError::IllegalInstruction {
597 address: program_counter.old_pc(INSTRUCTION_SIZE),
598 })?;
599 let vs2_idx = vs2.to_bits();
600 if !vs2_idx.is_multiple_of(wide_group) || vs2_idx + wide_group > 32 {
601 return Err(ExecutionError::IllegalInstruction {
602 address: program_counter.old_pc(INSTRUCTION_SIZE),
603 });
604 }
605 Ok(())
606}