// ab_riscv_interpreter/v/zve64x/arith/zve64x_arith_helpers.rs
use crate::v::vector_registers::VectorRegistersExt;
use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
use crate::{ExecutionError, ProgramCounter};
use ab_riscv_primitives::prelude::*;
use core::fmt;

10#[inline(always)]
12#[doc(hidden)]
13pub fn check_vreg_group_alignment<Reg, Memory, PC, CustomError>(
14 program_counter: &PC,
15 vreg: VReg,
16 group_regs: u8,
17) -> Result<(), ExecutionError<Reg::Type, CustomError>>
18where
19 Reg: Register,
20 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
21{
22 let vreg_idx = vreg.bits();
23 if !vreg_idx.is_multiple_of(group_regs) || vreg_idx + group_regs > 32 {
24 return Err(ExecutionError::IllegalInstruction {
25 address: program_counter.old_pc(INSTRUCTION_SIZE),
26 });
27 }
28 Ok(())
29}
30
31#[inline(always)]
37#[doc(hidden)]
38pub fn check_mask_dest_no_overlap<Reg, Memory, PC, CustomError>(
39 program_counter: &PC,
40 vd: VReg,
41 src_base: VReg,
42 group_regs: u8,
43) -> Result<(), ExecutionError<Reg::Type, CustomError>>
44where
45 Reg: Register,
46 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
47{
48 if group_regs > 1 {
49 let vd_idx = vd.bits();
50 let src = src_base.bits();
51 if vd_idx >= src && vd_idx < src + group_regs {
52 return Err(ExecutionError::IllegalInstruction {
53 address: program_counter.old_pc(INSTRUCTION_SIZE),
54 });
55 }
56 }
57 Ok(())
58}
59
60#[inline(always)]
71pub(in super::super) unsafe fn read_element_u64<const VLENB: usize>(
72 vreg: &[[u8; VLENB]; 32],
73 base_reg: usize,
74 elem_i: u32,
75 sew: Vsew,
76) -> u64 {
77 let sew_bytes = usize::from(sew.bytes());
78 let elems_per_reg = VLENB / sew_bytes;
79 let reg_off = elem_i as usize / elems_per_reg;
80 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
81 let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
83 let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
86 let mut buf = [0u8; 8];
87 unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
89 u64::from_le_bytes(buf)
90}
91
92#[inline(always)]
98pub(in super::super) unsafe fn write_element_u64<const VLENB: usize>(
99 vreg: &mut [[u8; VLENB]; 32],
100 base_reg: u8,
101 elem_i: u32,
102 sew: Vsew,
103 value: u64,
104) {
105 let sew_bytes = usize::from(sew.bytes());
106 let elems_per_reg = VLENB / sew_bytes;
107 let reg_off = elem_i as usize / elems_per_reg;
108 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
109 let buf = value.to_le_bytes();
110 let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
112 let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
115 dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
117}
118
119#[inline(always)]
129pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
130 vreg: &mut [[u8; VLENB]; 32],
131 vd: VReg,
132 elem_i: u32,
133 result: bool,
134) {
135 let byte_idx = (elem_i / u8::BITS) as usize;
136 let bit_idx = elem_i % u8::BITS;
137 let byte = unsafe {
139 vreg.get_unchecked_mut(usize::from(vd.bits()))
140 .get_unchecked_mut(byte_idx)
141 };
142 if result {
143 *byte |= 1 << bit_idx;
144 } else {
145 *byte &= !(1 << bit_idx);
146 }
147}
148
/// Second operand of an element-wise vector operation: either the base index
/// of a `vs1` vector register group or an already-materialized scalar value.
//
// Derives `Clone`/`Copy`: the enum holds only small `Copy` payloads, so
// callers can pass it by value without matching through a reference.
#[derive(Debug, Clone, Copy)]
#[doc(hidden)]
pub enum OpSrc {
    /// Base register index of the `vs1` vector register group.
    Vreg(u8),
    /// Scalar operand value.
    Scalar(u64),
}
158
/// Applies a binary arithmetic operation element-wise over a vector operand.
///
/// For each element index `i` in `vstart..vl` whose snapshot mask bit is set,
/// reads operand `a` from the `vs2` register group, operand `b` from either a
/// `vs1` register group ([`OpSrc::Vreg`]) or a pre-read scalar
/// ([`OpSrc::Scalar`]), and writes `op(a, b, sew)` into the corresponding
/// element of the `vd` group. Elements whose mask bit is clear are skipped,
/// leaving the destination element undisturbed. Finally marks the vector
/// state dirty and resets `vstart`.
///
/// # Safety
///
/// The caller must guarantee that all `vl` elements of width `sew` lie within
/// the register file for the `vd`, `vs2` and (for [`OpSrc::Vreg`]) `vs1`
/// register groups — the element accessors below use unchecked indexing.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, Vsew) -> u64,
{
    // Capture the mask bits once up front; presumably this keeps masking
    // stable even if `vd` overlaps the mask register — TODO confirm against
    // `snapshot_mask`'s contract.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    let vd_base = vd.bits();
    let vs2_base = vs2.bits();

    for i in vstart..vl {
        // Inactive element: skip, leaving the destination undisturbed.
        if !mask_bit(&mask_buf, i) {
            continue;
        }

        // SAFETY: upheld by this function's own `# Safety` contract.
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };

        let b = match &src {
            OpSrc::Vreg(vs1_base) => {
                // SAFETY: upheld by this function's own `# Safety` contract.
                unsafe { read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew) }
            }
            OpSrc::Scalar(val) => *val,
        };

        let result = op(a, b, sew);

        // SAFETY: upheld by this function's own `# Safety` contract.
        unsafe {
            write_element_u64(ext_state.write_vreg(), vd_base, i, sew, result);
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
226
/// Applies a binary comparison element-wise, producing a mask in `vd`.
///
/// Mirrors [`execute_arith_op`], except `op` returns a `bool` and the result
/// is written as a single bit (mask layout) of `vd` rather than a full
/// element. Elements in `vstart..vl` whose snapshot mask bit is clear are
/// skipped, leaving the corresponding destination bit undisturbed. Finally
/// marks the vector state dirty and resets `vstart`.
///
/// # Safety
///
/// The caller must guarantee that all `vl` elements of width `sew` lie within
/// the register file for the `vs2` and (for [`OpSrc::Vreg`]) `vs1` register
/// groups, and that `vl / 8` stays within `vd`'s `VLENB` bytes — the element
/// and mask-bit accessors below use unchecked indexing.
#[inline(always)]
#[expect(clippy::too_many_arguments, reason = "Internal API")]
#[doc(hidden)]
pub unsafe fn execute_compare_op<Reg, ExtState, CustomError, F>(
    ext_state: &mut ExtState,
    vd: VReg,
    vs2: VReg,
    src: OpSrc,
    vm: bool,
    vl: u32,
    vstart: u32,
    sew: Vsew,
    op: F,
) where
    Reg: Register,
    ExtState: VectorRegistersExt<Reg, CustomError>,
    [(); ExtState::ELEN as usize]:,
    [(); ExtState::VLEN as usize]:,
    [(); ExtState::VLENB as usize]:,
    CustomError: fmt::Debug,
    F: Fn(u64, u64, Vsew) -> bool,
{
    // Capture the mask bits once up front; presumably this keeps masking
    // stable even though `vd` (a mask destination) may overlap the mask
    // register — TODO confirm against `snapshot_mask`'s contract.
    let mask_buf = unsafe { snapshot_mask(ext_state.read_vreg(), vm, vl) };

    let vs2_base = vs2.bits();

    for i in vstart..vl {
        // Inactive element: skip, leaving the destination bit undisturbed.
        if !mask_bit(&mask_buf, i) {
            continue;
        }

        // SAFETY: upheld by this function's own `# Safety` contract.
        let a = unsafe { read_element_u64(ext_state.read_vreg(), usize::from(vs2_base), i, sew) };

        let b = match &src {
            OpSrc::Vreg(vs1_base) => {
                // SAFETY: upheld by this function's own `# Safety` contract.
                unsafe { read_element_u64(ext_state.read_vreg(), usize::from(*vs1_base), i, sew) }
            }
            OpSrc::Scalar(val) => *val,
        };

        let result = op(a, b, sew);

        // SAFETY: upheld by this function's own `# Safety` contract.
        unsafe {
            write_mask_bit(ext_state.write_vreg(), vd, i, result);
        }
    }

    ext_state.mark_vs_dirty();
    ext_state.reset_vstart();
}
296
297#[inline(always)]
299#[doc(hidden)]
300pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
301 let shift = u64::BITS - u32::from(sew.bits());
302 (val.cast_signed() << shift) >> shift
303}
304
305#[inline(always)]
310#[doc(hidden)]
311pub fn sew_mask(sew: Vsew) -> u64 {
312 if u32::from(sew.bits()) == u64::BITS {
313 u64::MAX
314 } else {
315 (1u64 << sew.bits()) - 1
316 }
317}