ab_riscv_interpreter/v/zvexx/arith/
zvexx_arith_helpers.rs1use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9use core::hint::cold_path;
10
11#[inline(always)]
13#[doc(hidden)]
14pub fn check_vreg_group_alignment<Reg, Memory, PC, CustomError>(
15 program_counter: &PC,
16 vreg: VReg,
17 group_regs: u8,
18) -> Result<(), ExecutionError<Reg::Type, CustomError>>
19where
20 Reg: Register,
21 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
22{
23 let vreg_idx = vreg.to_bits();
24 if !vreg_idx.is_multiple_of(group_regs) || vreg_idx + group_regs > 32 {
25 cold_path();
26 return Err(ExecutionError::IllegalInstruction {
27 address: program_counter.old_pc(INSTRUCTION_SIZE),
28 });
29 }
30 Ok(())
31}
32
33#[inline(always)]
39#[doc(hidden)]
40pub fn check_mask_dest_no_overlap<Reg, Memory, PC, CustomError>(
41 program_counter: &PC,
42 vd: VReg,
43 src_base: VReg,
44 group_regs: u8,
45) -> Result<(), ExecutionError<Reg::Type, CustomError>>
46where
47 Reg: Register,
48 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
49{
50 if group_regs > 1 {
51 let vd_idx = vd.to_bits();
52 let src = src_base.to_bits();
53 if vd_idx >= src && vd_idx < src + group_regs {
54 cold_path();
55 return Err(ExecutionError::IllegalInstruction {
56 address: program_counter.old_pc(INSTRUCTION_SIZE),
57 });
58 }
59 }
60 Ok(())
61}
62
63#[inline(always)]
74pub(crate) unsafe fn read_element_u64<const VLENB: usize>(
75 vregs: &VectorRegisterFile<VLENB>,
76 base_reg: VReg,
77 elem_i: u32,
78 sew: Vsew,
79) -> u64 {
80 let sew_bytes = usize::from(sew.bytes_width());
81 let elems_per_reg = VLENB / sew_bytes;
82 let reg_off = elem_i as usize / elems_per_reg;
83 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
84 let reg = vregs
86 .get(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
87 let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
90 let mut buf = [0u8; 8];
91 unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
93 u64::from_le_bytes(buf)
94}
95
96#[inline(always)]
102pub(crate) unsafe fn write_element_u64<const VLENB: usize>(
103 vregs: &mut VectorRegisterFile<VLENB>,
104 base_reg: VReg,
105 elem_i: u32,
106 sew: Vsew,
107 value: u64,
108) {
109 let sew_bytes = usize::from(sew.bytes_width());
110 let elems_per_reg = VLENB / sew_bytes;
111 let reg_off = elem_i as usize / elems_per_reg;
112 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
113 let buf = value.to_le_bytes();
114 let reg = vregs
116 .get_mut(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
117 let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
120 dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
122}
123
124#[inline(always)]
134pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
135 vregs: &mut VectorRegisterFile<VLENB>,
136 vd: VReg,
137 elem_i: u32,
138 result: bool,
139) {
140 let byte_idx = (elem_i / u8::BITS) as usize;
141 let bit_idx = elem_i % u8::BITS;
142 let byte = unsafe { vregs.get_mut(vd).get_unchecked_mut(byte_idx) };
144 if result {
145 *byte |= 1 << bit_idx;
146 } else {
147 *byte &= !(1 << bit_idx);
148 }
149}
150
151#[derive(Debug)]
153#[doc(hidden)]
154pub enum OpSrc {
155 Vreg(VReg),
157 Scalar(u64),
159}
160
161#[inline(always)]
172#[doc(hidden)]
173pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
174 ext_state: &mut ExtState,
175 vd: VReg,
176 vs2: VReg,
177 src: OpSrc,
178 vm: bool,
179 sew: Vsew,
180 op: F,
181) where
182 Reg: Register,
183 ExtState: VectorRegistersExt<Reg, CustomError>,
184 [(); ExtState::ELEN as usize]:,
185 [(); ExtState::VLEN as usize]:,
186 [(); ExtState::VLENB as usize]:,
187 CustomError: fmt::Debug,
188 F: Fn(u64, u64, Vsew) -> u64,
189{
190 let vl = ext_state.vl();
191 let vstart = ext_state.vstart();
192 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
194
195 for i in u32::from(vstart)..vl {
196 if !mask_bit(&mask_buf, i) {
197 continue;
198 }
199
200 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
203
204 let b = match src {
205 OpSrc::Vreg(vs1_base) => {
206 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
208 }
209 OpSrc::Scalar(val) => val,
210 };
211
212 let result = op(a, b, sew);
213
214 unsafe {
217 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
218 }
219 }
220
221 ext_state.mark_vs_dirty();
222 ext_state.reset_vstart();
223}
224
225#[inline(always)]
239#[doc(hidden)]
240pub unsafe fn execute_compare_op<Reg, ExtState, CustomError, F>(
241 ext_state: &mut ExtState,
242 vd: VReg,
243 vs2: VReg,
244 src: OpSrc,
245 vm: bool,
246 sew: Vsew,
247 op: F,
248) where
249 Reg: Register,
250 ExtState: VectorRegistersExt<Reg, CustomError>,
251 [(); ExtState::ELEN as usize]:,
252 [(); ExtState::VLEN as usize]:,
253 [(); ExtState::VLENB as usize]:,
254 CustomError: fmt::Debug,
255 F: Fn(u64, u64, Vsew) -> bool,
256{
257 let vl = ext_state.vl();
258 let vstart = ext_state.vstart();
259 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
261
262 for i in u32::from(vstart)..vl {
263 if !mask_bit(&mask_buf, i) {
266 continue;
267 }
268
269 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
271
272 let b = match src {
273 OpSrc::Vreg(vs1_base) => {
274 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
276 }
277 OpSrc::Scalar(val) => val,
278 };
279
280 let result = op(a, b, sew);
281
282 unsafe {
284 write_mask_bit(ext_state.write_vregs(), vd, i, result);
285 }
286 }
287
288 ext_state.mark_vs_dirty();
289 ext_state.reset_vstart();
290}
291
292#[inline(always)]
294#[doc(hidden)]
295pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
296 let shift = u64::BITS - u32::from(sew.bits_width());
297 (val.cast_signed() << shift) >> shift
298}
299
300#[inline(always)]
305#[doc(hidden)]
306pub fn sew_mask(sew: Vsew) -> u64 {
307 if u32::from(sew.bits_width()) == u64::BITS {
308 u64::MAX
309 } else {
310 (1u64 << sew.bits_width()) - 1
311 }
312}