use crate::v::vector_registers::VectorRegistersExt;
pub use crate::v::zve64x::arith::zve64x_arith_helpers::{
    OpSrc, check_vreg_group_alignment as check_vd, check_vreg_group_alignment as check_vs, sew_mask,
};
use crate::v::zve64x::arith::zve64x_arith_helpers::{
    read_element_u64, sign_extend, write_element_u64,
};
use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, InterpreterState, ProgramCounter, VirtualMemory};
use ab_riscv_primitives::instructions::v::{Vsew, Vxrm};
use ab_riscv_primitives::registers::general_purpose::Register;
use ab_riscv_primitives::registers::vector::VReg;
use core::fmt;

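/// Rounding increment `r` for a right shift by `shift` bits under the RVV
/// fixed-point rounding mode `vxrm`, so that the rounded result is
/// `(val >> shift) + r`. `current_result_lsb` is the LSB of the truncated
/// result, which the `Rne` and `Rod` modes consult.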
#[inline(always)]
fn round_increment(val: u64, shift: u32, mode: Vxrm, current_result_lsb: u64) -> u64 {
    if shift == 0 {
        return 0;
    }
    // Most significant bit shifted out.
    let d_minus1_bit = (val >> (shift - 1)) & 1;
    // Set if any bit below it is also shifted out.
    let sticky = if shift >= 2 {
        (val & ((1u64 << (shift - 1)).wrapping_sub(1))) != 0
    } else {
        false
    };
    match mode {
        // Round-to-nearest-up: add the first bit shifted out.
        Vxrm::Rnu => d_minus1_bit,
        // Round-to-nearest-even: round up when above the halfway point, or
        // exactly halfway with an odd truncated result.
        Vxrm::Rne => {
            d_minus1_bit
                & (if sticky || current_result_lsb != 0 {
                    1
                } else {
                    0
                })
        }
        // Round-down: truncate.
        Vxrm::Rdn => 0,
        // Round-to-odd: force the result LSB to 1 if any bit was shifted out.
        Vxrm::Rod => {
            if current_result_lsb == 0 && (d_minus1_bit != 0 || sticky) {
                1
            } else {
                0
            }
        }
    }
}

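/// Logical shift right with rounding: `(val >> shift) + r`, where `r` follows
/// the `vxrm` rounding mode.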
#[inline(always)]
pub fn rounded_srl(val: u64, shift: u32, mode: Vxrm) -> u64 {
    let truncated = val >> shift;
    let r = round_increment(val, shift, mode, truncated & 1);
    truncated.wrapping_add(r)
}

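/// Arithmetic shift right with rounding: sign-extends `val` at `sew` before
/// shifting, then applies the `vxrm` rounding increment.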
#[inline(always)]
pub fn rounded_sra(val: u64, shift: u32, mode: Vxrm, sew: Vsew) -> u64 {
    let signed = sign_extend(val, sew);
    let truncated_signed = signed >> shift;
    let r = round_increment(val, shift, mode, truncated_signed.cast_unsigned() & 1);
    truncated_signed.cast_unsigned().wrapping_add(r)
}

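/// Saturating unsigned add at `sew`; clamps to the maximum value and sets
/// `vxsat` on overflow.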
#[inline(always)]
pub fn sat_addu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    let result = a_w.wrapping_add(b_w);
    // Unsigned overflow occurred iff the SEW-wide sum wrapped below an operand.
    if (result & mask) < a_w {
        *vxsat = true;
        mask
    } else {
        result & mask
    }
}

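/// Saturating signed add at `sew`; clamps to the signed range and sets
/// `vxsat` on overflow.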
#[inline(always)]
pub fn sat_add(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let sa = sign_extend(a, sew) as i128;
    let sb = sign_extend(b, sew) as i128;
    let result = sa.wrapping_add(sb);
    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits()));
    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits()));
    if result < min_val {
        *vxsat = true;
        (min_val as i64).cast_unsigned() & sew_mask(sew)
    } else if result > max_val {
        *vxsat = true;
        (max_val as i64).cast_unsigned() & sew_mask(sew)
    } else {
        (result as i64).cast_unsigned() & sew_mask(sew)
    }
}

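/// Saturating unsigned subtract at `sew`; clamps to zero and sets `vxsat`
/// on underflow.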
#[inline(always)]
pub fn sat_subu(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    if a_w < b_w {
        *vxsat = true;
        0
    } else {
        (a_w - b_w) & mask
    }
}

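/// Saturating signed subtract at `sew`; clamps to the signed range and sets
/// `vxsat` on overflow.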
#[inline(always)]
pub fn sat_sub(a: u64, b: u64, sew: Vsew, vxsat: &mut bool) -> u64 {
    let sa = sign_extend(a, sew) as i128;
    let sb = sign_extend(b, sew) as i128;
    let result = sa.wrapping_sub(sb);
    let min_val = i128::MIN >> (i128::BITS - u32::from(sew.bits()));
    let max_val = i128::MAX >> (i128::BITS - u32::from(sew.bits()));
    if result < min_val {
        *vxsat = true;
        (min_val as i64).cast_unsigned() & sew_mask(sew)
    } else if result > max_val {
        *vxsat = true;
        (max_val as i64).cast_unsigned() & sew_mask(sew)
    } else {
        (result as i64).cast_unsigned() & sew_mask(sew)
    }
}

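/// Unsigned averaging add: `(a + b) >> 1` computed over the full `SEW + 1`-bit
/// sum with `vxrm` rounding, so the carry is never lost.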
#[inline(always)]
pub fn avg_addu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    let sum = a_w.wrapping_add(b_w);
    // Carry out of the SEW-wide sum is bit SEW of the intermediate sum;
    // after the shift by one it lands in bit SEW-1.
    let carry = if (sum & mask) < a_w { 1u64 } else { 0u64 };
    let r = round_increment(sum & mask, 1, mode, (sum >> 1) & 1);
    let shifted = (carry << (sew.bits() as u32 - 1)) | ((sum & mask) >> 1);
    (shifted.wrapping_add(r)) & mask
}

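/// Signed averaging add: `(a + b) >> 1` with `vxrm` rounding, computed in
/// wider precision so the intermediate sum cannot overflow.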
#[inline(always)]
pub fn avg_add(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let sa = sign_extend(a, sew);
    let sb = sign_extend(b, sew);
    let sum = (sa as i128).wrapping_add(sb as i128);
    // Inline rounding for a shift of one: there are no sticky bits below d-1.
    let r = match mode {
        Vxrm::Rnu => (sum & 1).cast_unsigned() as u64,
        Vxrm::Rne => {
            let result_lsb = ((sum >> 1) & 1).cast_unsigned() as u64;
            ((sum & 1).cast_unsigned() as u64) & result_lsb
        }
        Vxrm::Rdn => 0,
        Vxrm::Rod => {
            let result_lsb = (sum >> 1) & 1;
            if result_lsb == 0 && (sum & 1) != 0 {
                1
            } else {
                0
            }
        }
    };
    let result = (sum >> 1) + r as i128;
    (result as i64).cast_unsigned() & sew_mask(sew)
}

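/// Unsigned averaging subtract: `(a - b) >> 1` with `vxrm` rounding; the
/// borrow acts as the sign of the `SEW + 1`-bit difference.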
#[inline(always)]
pub fn avg_subu(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let mask = sew_mask(sew);
    let a_w = a & mask;
    let b_w = b & mask;
    let diff = a_w.wrapping_sub(b_w);
    let borrow = if a_w < b_w { 1u64 } else { 0u64 };
    let r = round_increment(diff & mask, 1, mode, (diff >> 1) & 1);
    // A borrow makes the (SEW+1)-bit difference negative; replicate its sign
    // into bit SEW-1 after the arithmetic shift by one.
    let sign_fill = borrow.wrapping_neg();
    let shifted = (sign_fill << (sew.bits() as u32 - 1)) | ((diff & mask) >> 1);
    (shifted.wrapping_add(r)) & mask
}

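/// Signed averaging subtract: `(a - b) >> 1` with `vxrm` rounding, computed
/// in wider precision so the intermediate difference cannot overflow.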
#[inline(always)]
pub fn avg_sub(a: u64, b: u64, sew: Vsew, mode: Vxrm) -> u64 {
    let sa = sign_extend(a, sew);
    let sb = sign_extend(b, sew);
    let diff = (sa as i128).wrapping_sub(sb as i128);
    let r = match mode {
        Vxrm::Rnu => (diff & 1).cast_unsigned() as u64,
        Vxrm::Rne => {
            let result_lsb = ((diff >> 1) & 1).cast_unsigned() as u64;
            ((diff & 1).cast_unsigned() as u64) & result_lsb
        }
        Vxrm::Rdn => 0,
        Vxrm::Rod => {
            let result_lsb = (diff >> 1) & 1;
            if result_lsb == 0 && (diff & 1) != 0 {
                1
            } else {
                0
            }
        }
    };
    let result = (diff >> 1) + r as i128;
    (result as i64).cast_unsigned() & sew_mask(sew)
}

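/// Signed fractional multiply with rounding and saturation (`vsmul`-style):
/// `(a * b) >> (SEW - 1)` with `vxrm` rounding, clamped to the signed `sew`
/// range with `vxsat` set on saturation.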
#[inline(always)]
pub fn smul(a: u64, b: u64, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    let min_sew = i64::MIN >> (i64::BITS - u32::from(sew.bits()));
    let max_sew = i64::MAX >> (i64::BITS - u32::from(sew.bits()));
    let sa = i128::from(sign_extend(a, sew));
    let sb = i128::from(sign_extend(b, sew));
    // MIN * MIN is the only product whose doubled value overflows the
    // 2*SEW-bit range; saturate it to MAX up front.
    if sa == i128::from(min_sew) && sb == i128::from(min_sew) {
        *vxsat = true;
        return max_sew.cast_unsigned() & sew_mask(sew);
    }
    let product = sa * sb;
    let doubled = product << 1;
    let shift = u32::from(sew.bits());
    let low_bits = (doubled.cast_unsigned() & u128::from(sew_mask(sew))) as u64;
    let truncated = doubled >> shift;
    let r = round_increment(
        low_bits,
        shift.min(64),
        mode,
        (truncated.cast_unsigned() as u64) & 1,
    );
    let result = (truncated as i64).wrapping_add(r.cast_signed());
    if result < min_sew {
        *vxsat = true;
        min_sew.cast_unsigned() & sew_mask(sew)
    } else if result > max_sew {
        *vxsat = true;
        max_sew.cast_unsigned() & sew_mask(sew)
    } else {
        result.cast_unsigned() & sew_mask(sew)
    }
}

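/// Unsigned narrowing clip: shifts a `2 * SEW`-bit source element right with
/// rounding, then clamps to the unsigned `sew` range, setting `vxsat` on
/// saturation.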
#[inline(always)]
pub fn nclipu(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    let shifted = rounded_srl(vs2_elem, shamt, mode);
    let max_dst = sew_mask(sew);
    if shifted > max_dst {
        *vxsat = true;
        max_dst
    } else {
        shifted & max_dst
    }
}

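/// Signed narrowing clip: arithmetically shifts a `2 * SEW`-bit source
/// element right with rounding, then clamps to the signed `sew` range,
/// setting `vxsat` on saturation.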
#[inline(always)]
pub fn nclip(vs2_elem: u64, shamt: u32, sew: Vsew, mode: Vxrm, vxsat: &mut bool) -> u64 {
    let double_sew_bits = sew.bits() * 2;
    // Sign-extend the 2*SEW-bit source element to the full 64 bits.
    let shift_amt = i64::BITS - u32::from(double_sew_bits);
    let signed_wide = (vs2_elem.cast_signed() << shift_amt) >> shift_amt;
    let low_bits = signed_wide.cast_unsigned()
        & if double_sew_bits == 64 {
            u64::MAX
        } else {
            (1u64 << double_sew_bits) - 1
        };
    let truncated = signed_wide >> shamt;
    let r = round_increment(low_bits, shamt, mode, (truncated.cast_unsigned()) & 1);
    let rounded = truncated.wrapping_add(r.cast_signed());
    let min_dst = i64::MIN >> (i64::BITS - u32::from(sew.bits()));
    let max_dst = i64::MAX >> (i64::BITS - u32::from(sew.bits()));
    if rounded < min_dst {
        *vxsat = true;
        min_dst.cast_unsigned() & sew_mask(sew)
    } else if rounded > max_dst {
        *vxsat = true;
        max_dst.cast_unsigned() & sew_mask(sew)
    } else {
        rounded.cast_unsigned() & sew_mask(sew)
    }
}

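/// Reads element `elem_i` at twice the selected element width from the
/// register group starting at `base_reg`.
///
/// # Safety
///
/// The caller must guarantee `sew.bits() <= 32` (so the wide element fits in
/// a `u64`) and that the register-group offset for `elem_i` stays within the
/// 32 vector registers.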
#[inline(always)]
pub unsafe fn read_wide_element_u64<const VLENB: usize>(
    vreg: &[[u8; VLENB]; 32],
    base_reg: usize,
    elem_i: u32,
    sew: Vsew,
) -> u64 {
    let double_sew_bytes = usize::from(sew.bytes()) * 2;
    let elems_per_reg = VLENB / double_sew_bytes;
    let reg_off = elem_i as usize / elems_per_reg;
    let byte_off = (elem_i as usize % elems_per_reg) * double_sew_bytes;
    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
    let src = unsafe { reg.get_unchecked(byte_off..byte_off + double_sew_bytes) };
    let mut buf = [0u8; 8];
    unsafe { buf.get_unchecked_mut(..double_sew_bytes) }.copy_from_slice(src);
    u64::from_le_bytes(buf)
}

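/// Element-wise driver for single-width fixed-point ops (saturating and
/// averaging add/sub, `vsmul`): applies `op` to active elements and
/// accumulates saturation into `vxsat`.
///
/// # Safety
///
/// The caller must have validated `vd`/`vs2` (and any `vs1`) register-group
/// alignment and that `vl` is within bounds for `sew`.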
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_fixed_point_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    [(); Reg::N]:,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, Vsew, Vxrm, &mut bool) -> u64,
{
    let vxrm = state.ext_state.vxrm();
    // Snapshot the mask up front so writes to vd can safely overlap v0.
    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    let mut any_sat = false;
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        let a =
            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
        let b = match &src {
            OpSrc::Vreg(vs1_base) => {
                unsafe {
                    read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
                }
            }
            OpSrc::Scalar(val) => *val,
        };
        let result = op(a, b, sew, vxrm, &mut any_sat);
        unsafe {
            write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    if any_sat {
        state.ext_state.set_vxsat(true);
    }
    state.ext_state.mark_vs_dirty();
    state.ext_state.reset_vstart();
}

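/// Element-wise driver for narrowing clip ops (`vnclip`/`vnclipu`): reads
/// `2 * SEW`-bit source elements, masks the shift amount to the low
/// `log2(2 * SEW)` bits, and writes `SEW`-bit results.
///
/// # Safety
///
/// The caller must have validated the narrowing constraints (`sew <= 32`,
/// `vs2` aligned to the doubled group) and that `vl` is within bounds.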
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_narrowing_clip_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    [(); Reg::N]:,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    Memory: VirtualMemory,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
    CustomError: fmt::Debug,
    F: Fn(u64, u32, Vsew, Vxrm, &mut bool) -> u64,
{
    let vxrm = state.ext_state.vxrm();
    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
    let vd_base = vd.bits();
    let vs2_base = vs2.bits();
    let mut any_sat = false;
    // Narrowing shifts take their shift amount modulo 2*SEW.
    let shamt_mask = u64::from(sew.bits() * 2 - 1);
    for i in vstart..vl {
        if !mask_bit(&mask_buf, i) {
            continue;
        }
        let wide_a = unsafe {
            read_wide_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew)
        };
        let shamt = match &src {
            OpSrc::Vreg(vs1_base) => {
                let raw = unsafe {
                    read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
                };
                (raw & shamt_mask) as u32
            }
            OpSrc::Scalar(val) => (*val & shamt_mask) as u32,
        };
        let result = op(wide_a, shamt, sew, vxrm, &mut any_sat);
        unsafe {
            write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }
    if any_sat {
        state.ext_state.set_vxsat(true);
    }
    state.ext_state.mark_vs_dirty();
    state.ext_state.reset_vstart();
}

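/// Narrowing ops need a `2 * SEW`-bit source, so `sew` itself must not
/// exceed 32 bits; raises an illegal-instruction error otherwise.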
#[inline(always)]
pub fn check_narrowing_sew<Reg, ExtState, Memory, PC, IH, CustomError>(
    state: &InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    sew: Vsew,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    [(); Reg::N]:,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    if sew.bits() > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: state.instruction_fetcher.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}

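/// Checks that `vs2` names a register group aligned to and contained within
/// twice `group_regs` registers, since narrowing sources occupy a doubled
/// group; raises an illegal-instruction error otherwise.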
#[inline(always)]
pub fn check_vs2_narrowing_alignment<Reg, ExtState, Memory, PC, IH, CustomError>(
    state: &InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
    vs2: VReg,
    group_regs: u8,
) -> Result<(), ExecutionError<Reg::Type, CustomError>>
where
    Reg: Register,
    [(); Reg::N]:,
    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
{
    let double_group = group_regs.saturating_mul(2);
    let vs2_idx = vs2.bits();
    if !vs2_idx.is_multiple_of(double_group) || vs2_idx + double_group > 32 {
        return Err(ExecutionError::IllegalInstruction {
            address: state.instruction_fetcher.old_pc(INSTRUCTION_SIZE),
        });
    }
    Ok(())
}
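
// Minimal sanity checks added for illustration: they exercise only
// `rounded_srl`, which needs no `Vsew` value, and the expected results are
// derived by hand from the four `vxrm` rounding rules.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rounded_srl_matches_vxrm_rules() {
        // 11 >> 1 = 5 remainder 1: exactly halfway, no sticky bits.
        assert_eq!(rounded_srl(11, 1, Vxrm::Rnu), 6); // round half up
        assert_eq!(rounded_srl(11, 1, Vxrm::Rne), 6); // 5 is odd, round to even 6
        assert_eq!(rounded_srl(11, 1, Vxrm::Rdn), 5); // truncate
        assert_eq!(rounded_srl(11, 1, Vxrm::Rod), 5); // result already odd
        // 11 >> 2 = 2 remainder 3: above halfway, d-1 bit and sticky both set.
        assert_eq!(rounded_srl(11, 2, Vxrm::Rnu), 3);
        assert_eq!(rounded_srl(11, 2, Vxrm::Rne), 3);
        assert_eq!(rounded_srl(11, 2, Vxrm::Rdn), 2);
        assert_eq!(rounded_srl(11, 2, Vxrm::Rod), 3); // 2 is even and bits were lost
    }
}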