use crate::v::vector_registers::VectorRegistersExt;
// Shared group-alignment check, re-exported under operand-specific aliases
// for the fixed-point instruction modules.
pub use crate::v::zve64x::arith::zve64x_arith_helpers::{
    OpSrc, check_vreg_group_alignment as check_vd, check_vreg_group_alignment as check_vs, sew_mask,
};
use crate::v::zve64x::arith::zve64x_arith_helpers::{
    read_element_u64, sign_extend, write_element_u64,
};
use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, ProgramCounter};
use ab_riscv_primitives::prelude::*;
use core::fmt;

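/// Computes the rounding increment for a right shift of `val` by `shift`
/// bits under the RVV `vxrm` rounding mode: `rnu` rounds half up, `rne`
/// rounds half to even, `rdn` truncates, and `rod` rounds to odd.
/// `current_result_lsb` must be the least-significant bit of the
/// already-truncated result.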
#[inline(always)]
fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
    if shift == 0 {
        return 0;
    }
    // v[d-1] in the spec's notation: the highest bit shifted out.
    let d_minus1_bit = (val >> (shift - 1)) & 1;
    // Sticky bit: any set bit strictly below v[d-1].
    let sticky = if shift >= 2 {
        (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
    } else {
        false
    };
    match mode {
        Vxrm::Rnu => d_minus1_bit,
        Vxrm::Rne => {
            d_minus1_bit
                & (if sticky || current_result_lsb != 0 {
                    1
                } else {
                    0
                })
        }
        Vxrm::Rdn => 0,
        Vxrm::Rod => {
            if current_result_lsb == 0 && (d_minus1_bit != 0 || sticky) {
                1
            } else {
                0
            }
        }
    }
}

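/// Logical right shift with vxrm rounding. For example, `0b1011 >> 2`
/// truncates to `0b10` with dropped bits `0b11`, so `rnu` yields `0b11`.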
#[inline(always)]
#[doc(hidden)]
pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
    let truncated = val >> shift;
    let r = round_increment(val, shift, mode, truncated & 1);
    truncated.wrapping_add(r)
}

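/// Arithmetic right shift with vxrm rounding: `val` holds a SEW-wide value
/// that is sign-extended before the shift.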
#[inline(always)]
#[doc(hidden)]
pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
    let signed = sign_extend(val, sew);
    let truncated_signed = signed >> shift;
    // Rounding looks only at the bits shifted out, so the raw unsigned value
    // works as the rounding source.
    let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
    truncated_signed.cast_unsigned().wrapping_add(r)
}

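/// Saturating unsigned add (as in `vsaddu`): clamps the SEW-wide sum to the
/// all-ones maximum and sets `vxsat` on overflow.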
#[inline(always)]
#[doc(hidden)]
pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    let result = a_w.wrapping_add(b_w);
    if result & mask < a_w {
        *vxsat = true;
        // Unsigned overflow saturates to the all-ones SEW-wide value.
        mask
    } else {
        result & mask
    }
}

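/// Saturating signed add (as in `vsadd`): clamps the sum to the signed
/// SEW-wide range and sets `vxsat` on overflow.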
#[inline(always)]
#[doc(hidden)]
pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let sa = sign_extend(a, sew) as i128;
    let sb = sign_extend(b, sew) as i128;
    let result = sa.wrapping_add(sb);
    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits()));
    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits()));
    if result < min_val {
        *vxsat = true;
        (min_val as i64).cast_unsigned() & sew_mask(sew)
    } else if result > max_val {
        *vxsat = true;
        (max_val as i64).cast_unsigned() & sew_mask(sew)
    } else {
        (result as i64).cast_unsigned() & sew_mask(sew)
    }
}

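/// Saturating unsigned subtract (as in `vssubu`): clamps the difference to
/// zero and sets `vxsat` on underflow.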
#[inline(always)]
#[doc(hidden)]
pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    if a_w < b_w {
        *vxsat = true;
        0
    } else {
        (a_w - b_w) & mask
    }
}

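/// Saturating signed subtract (as in `vssub`): clamps the difference to the
/// signed SEW-wide range and sets `vxsat` on overflow.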
#[inline(always)]
#[doc(hidden)]
pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let sa = sign_extend(a, sew) as i128;
    let sb = sign_extend(b, sew) as i128;
    let result = sa.wrapping_sub(sb);
    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits()));
    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits()));
    if result < min_val {
        *vxsat = true;
        (min_val as i64).cast_unsigned() & sew_mask(sew)
    } else if result > max_val {
        *vxsat = true;
        (max_val as i64).cast_unsigned() & sew_mask(sew)
    } else {
        (result as i64).cast_unsigned() & sew_mask(sew)
    }
}

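/// Averaging unsigned add (as in `vaaddu`): computes `(a + b) >> 1` over the
/// full SEW+1-bit sum with vxrm rounding; the result never saturates.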
#[inline(always)]
#[doc(hidden)]
pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    let sum = a_w.wrapping_add(b_w);
    // The true sum is SEW+1 bits wide; recover the carry out of bit SEW-1.
    let carry = if sum & mask < a_w { 1u64 } else { 0u64 };
    // Bit 1 of the full sum becomes the result LSB after the shift by one.
    let r = round_increment(sum & mask, 1, mode, (sum >> 1) & 1);
    // Shift the (carry:sum) value right by one, reinserting the carry at the top.
    let shifted = (carry << (sew.bits() as u32 - 1)) | ((sum & mask) >> 1);
    (shifted.wrapping_add(r)) & mask
}

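/// Averaging signed add (as in `vaadd`): computes `(a + b) >> 1` in wide
/// signed arithmetic with vxrm rounding.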
#[inline(always)]
#[doc(hidden)]
pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let sa = sign_extend(a, sew);
    let sb = sign_extend(b, sew);
    let sum = (sa as i128).wrapping_add(sb as i128);
    // With a shift of one there is no sticky bit, so the increment reduces
    // to a function of the two lowest bits of the sum.
    let r = match mode {
        Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
        Vxrm::Rne => {
            let result_lsb = ((sum >> 1) & 1).cast_unsigned() as u64;
            // Round to even: increment only when both the dropped bit and
            // the result LSB are set.
            ((sum & 1).cast_unsigned() as u64) & result_lsb
        }
        Vxrm::Rdn => 0,
        Vxrm::Rod => {
            let result_lsb = (sum >> 1) & 1;
            if result_lsb == 0 && (sum & 1) != 0 {
                1
            } else {
                0
            }
        }
    };
    let result = (sum >> 1) + r as i128;
    (result as i64).cast_unsigned() & sew_mask(sew)
}

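/// Averaging unsigned subtract (as in `vasubu`): computes `(a - b) >> 1`
/// over the signed SEW+1-bit difference with vxrm rounding.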
#[inline(always)]
#[doc(hidden)]
pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    let diff = a_w.wrapping_sub(b_w);
    // The true difference is a signed SEW+1-bit value; record the borrow.
    let borrow = if a_w < b_w { 1u64 } else { 0u64 };
    let r = round_increment(diff & mask, 1, mode, (diff >> 1) & 1);
    // Arithmetic shift of the (borrow:diff) value: a borrow means the result
    // is negative, so the vacated top bit is filled with one.
    let sign_fill = borrow.wrapping_neg();
    let shifted = (sign_fill << (sew.bits() as u32 - 1)) | ((diff & mask) >> 1);
    (shifted.wrapping_add(r)) & mask
}

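/// Averaging signed subtract (as in `vasub`): computes `(a - b) >> 1` in
/// wide signed arithmetic with vxrm rounding.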
#[inline(always)]
#[doc(hidden)]
pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let sa = sign_extend(a, sew);
    let sb = sign_extend(b, sew);
    let diff = (sa as i128).wrapping_sub(sb as i128);
    let r = match mode {
        Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
        Vxrm::Rne => {
            let result_lsb = ((diff >> 1) & 1).cast_unsigned() as u64;
            ((diff & 1).cast_unsigned() as u64) & result_lsb
        }
        Vxrm::Rdn => 0,
        Vxrm::Rod => {
            let result_lsb = (diff >> 1) & 1;
            if result_lsb == 0 && (diff & 1) != 0 {
                1
            } else {
                0
            }
        }
    };
    let result = (diff >> 1) + r as i128;
    (result as i64).cast_unsigned() & sew_mask(sew)
}

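/// Fractional signed multiply (as in `vsmul`): computes `(2 * a * b) >> SEW`
/// with vxrm rounding, saturating `MIN * MIN` (the one case that overflows
/// the doubled product range) as well as any rounded result outside the
/// signed SEW-wide range, and sets `vxsat` when it saturates.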
#[inline(always)]
#[doc(hidden)]
pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits()));
    let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits()));
    let sa = i128::from(sign_extend(a, sew));
    let sb = i128::from(sign_extend(b, sew));
    // MIN * MIN is the only input whose doubled product does not fit the
    // double-width range; the spec saturates it to MAX.
    if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
        *vxsat = true;
        return max_sew.cast_unsigned() & sew_mask(sew);
    }
    let product = sa * sb;
    // vsmul computes (2 * product) >> SEW with rounding.
    let doubled = product << 1;
    let shift = u32::from(sew.bits());
    // Only the low SEW bits participate in the rounding decision.
    let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
    let truncated = doubled >> shift;
    let r = round_increment(
        low_bits,
        shift.min(64),
        mode,
        (truncated.cast_unsigned() as u64) & 1,
    );
    let result = (truncated as i64).wrapping_add(r.cast_signed());
    // Clamp to the signed SEW-wide range, recording saturation.
    if result < min_sew {
        *vxsat = true;
        min_sew.cast_unsigned() & sew_mask(sew)
    } else if result > max_sew {
        *vxsat = true;
        max_sew.cast_unsigned() & sew_mask(sew)
    } else {
        result.cast_unsigned() & sew_mask(sew)
    }
}

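/// Unsigned narrowing clip (as in `vnclipu`): right-shifts a 2*SEW-bit
/// source with vxrm rounding and saturates to the unsigned SEW-wide range,
/// setting `vxsat` on saturation.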
#[inline(always)]
#[doc(hidden)]
pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    let shifted = rounded_srl(vs2_elem, shamt, mode);
    // The destination element is SEW bits wide; anything above saturates.
    let max_dst = sew_mask(sew);
    if shifted > max_dst {
        *vxsat = true;
        max_dst
    } else {
        shifted & max_dst
    }
}

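/// Signed narrowing clip (as in `vnclip`): right-shifts a signed 2*SEW-bit
/// source with vxrm rounding and saturates to the signed SEW-wide range,
/// setting `vxsat` on saturation.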
#[inline(always)]
#[doc(hidden)]
pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    let double_sew_bits = sew.bits() * 2;
    // Sign-extend the 2*SEW-bit source from its true sign bit.
    let shift_amt = i64::BITS - u32::from(double_sew_bits);
    let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
    // Keep only the low 2*SEW bits for the rounding decision.
    let low_bits = signed_wide.cast_unsigned()
        & if double_sew_bits == 64 {
            u64::MAX
        } else {
            (1u64 << double_sew_bits) - 1
        };
    let truncated = signed_wide >> shamt;
    let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
    let rounded = truncated.wrapping_add(r.cast_signed());
    let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits()));
    let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits()));
    if rounded < min_dst {
        *vxsat = true;
        min_dst.cast_unsigned() & sew_mask(sew)
    } else if rounded > max_dst {
        *vxsat = true;
        max_dst.cast_unsigned() & sew_mask(sew)
    } else {
        rounded.cast_unsigned() & sew_mask(sew)
    }
}

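/// Reads element `elem_i` of the register group starting at `base_reg` at
/// twice the element width (2*SEW), zero-extended to `u64` (little-endian).
///
/// # Safety
///
/// The caller must guarantee that `base_reg` plus the group offset for
/// `elem_i` stays within the 32 vector registers and that the byte range is
/// in bounds, since all accesses below are unchecked.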
#[inline(always)]
pub unsafe fn read_wide_element_u64<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    base_reg: usize,
    elem_i: u32,
    sew: Vsew,
) -> u64 {
    let double_sew_bytes = usize::from(sew.bytes()) * 2;
    let elems_per_reg = VLENB / double_sew_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
    // SAFETY: in-bounds access is the caller's obligation (see the
    // function-level safety contract).
    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
    let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
    // Zero-pad to 8 bytes so narrower elements decode as zero-extended u64.
    let mut buf = [0u8; 8];
    unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
    u64::from_le_bytes(buf)
}

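/// Applies a fixed-point binary op element-wise between the `vs2` group and
/// `src`, writing SEW-wide results to the `vd` group under the mask, and
/// sets the sticky `vxsat` flag if any element saturated.
///
/// # Safety
///
/// The caller must have validated `vl` and the alignment of the register
/// groups so that every element access is in bounds.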
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_fixed_point_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
{
    let vxrm = ext_state.vxrm();
    // Snapshot the mask bits up front; the destination group may overlap v0.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    let mut any_sat = false;
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let b = match &src {
            OpSrc::Vreg(vs1_base) => {
                unsafe { read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew) }
            }
            OpSrc::Scalar(val) => *val,
        };
        let result = op(a, b, sew, vxrm, &mut any_sat);
        unsafe {
            // SAFETY: group alignment and vl were validated by the caller.
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    // vxsat is sticky: set it when any element saturated, never clear it.
    if any_sat {
        ext_state.set_vxsat(true);
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

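/// Applies a narrowing clip op element-wise: reads 2*SEW-bit elements from
/// the `vs2` group, applies `op` with the shift amount taken modulo 2*SEW,
/// and writes SEW-wide results to the `vd` group under the mask.
///
/// # Safety
///
/// The caller must have validated the narrowing SEW and the alignment of the
/// register groups so that every element access is in bounds.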
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
{
    let vxrm = ext_state.vxrm();
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    let mut any_sat = false;
    // Narrowing shifts take their shift amount modulo 2*SEW.
    let shamt_mask = u64::from(sew.bits() * 2 - 1);
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        // The source operand is read at twice the destination element width.
        let wide_a =
            unsafe { read_wide_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let shamt = match &src {
            OpSrc::Vreg(vs1_base) => {
                let raw = unsafe {
                    read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
                };
                (raw & shamt_mask) as u32
            }
            OpSrc::Scalar(val) => (*val & shamt_mask) as u32,
        };
        let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    if any_sat {
        ext_state.set_vxsat(true);
    }
    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}

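/// Narrowing ops read their source at 2*SEW, so with ELEN = 64 any SEW above
/// 32 bits makes the instruction illegal.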
#[inline(always)]
#[doc(hidden)]
pub fn check_narrowing_sew<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    sew: Vsew,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    if sew.bits() > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

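/// Checks that the `vs2` source group of a narrowing op, which spans
/// `2 * group_regs` registers, is aligned to its group size and fits within
/// the 32-register file.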
#[inline(always)]
#[doc(hidden)]
pub fn check_vs2_narrowing_alignment<Reg, Memory, PC, CustomError>(
    program_counter: &PC,
    vs2: VReg,
    group_regs: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    if group_regs > 4 {
        // The source group is twice as wide, so anything above 4 registers
        // would need a group larger than the maximum of 8.
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    let double_group = group_regs * 2;
    let vs2_idx = vs2.bits();
    if !vs2_idx.is_multiple_of(double_group) || vs2_idx + double_group > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: program_counter.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}
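
// --- Illustrative sketch (not part of the original file) -------------------
// A minimal test module showing the intended vxrm rounding behavior of
// `rounded_srl` on concrete values. It assumes the standard test harness and
// that `Vxrm` is reachable through `use super::*`, as in the module above.
#[cfg(test)]
mod rounding_behavior_sketch {
    use super::*;

    #[test]
    fn rounded_srl_matches_vxrm_semantics() {
        // 0b1011 >> 2: truncated result 0b10, dropped bits 0b11 (above half).
        assert_eq!(rounded_srl(0b1011, 2, Vxrm::Rnu), 3); // round half up
        assert_eq!(rounded_srl(0b1011, 2, Vxrm::Rne), 3); // above half: up
        assert_eq!(rounded_srl(0b1011, 2, Vxrm::Rdn), 2); // truncate
        assert_eq!(rounded_srl(0b1011, 2, Vxrm::Rod), 3); // force odd LSB

        // Exact halves round to even under Rne.
        assert_eq!(rounded_srl(0b0110, 2, Vxrm::Rne), 2); // 1.5 -> 2
        assert_eq!(rounded_srl(0b0010, 2, Vxrm::Rne), 0); // 0.5 -> 0
    }
}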