ab_riscv_interpreter/v/zvexx/arith/
zvexx_arith_helpers.rs1use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10#[inline(always)]
12#[doc(hidden)]
13pub fn check_vreg_group_alignment<Reg, Memory, PC, CustomError>(
14 program_counter: &PC,
15 vreg: VReg,
16 group_regs: u8,
17) -> Result<(), ExecutionError<Reg::Type, CustomError>>
18where
19 Reg: Register,
20 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
21{
22 let vreg_idx = vreg.to_bits();
23 if !vreg_idx.is_multiple_of(group_regs) || vreg_idx + group_regs > 32 {
24 return Err(ExecutionError::IllegalInstruction {
25 address: program_counter.old_pc(INSTRUCTION_SIZE),
26 });
27 }
28 Ok(())
29}
30
31#[inline(always)]
37#[doc(hidden)]
38pub fn check_mask_dest_no_overlap<Reg, Memory, PC, CustomError>(
39 program_counter: &PC,
40 vd: VReg,
41 src_base: VReg,
42 group_regs: u8,
43) -> Result<(), ExecutionError<Reg::Type, CustomError>>
44where
45 Reg: Register,
46 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
47{
48 if group_regs > 1 {
49 let vd_idx = vd.to_bits();
50 let src = src_base.to_bits();
51 if vd_idx >= src && vd_idx < src + group_regs {
52 return Err(ExecutionError::IllegalInstruction {
53 address: program_counter.old_pc(INSTRUCTION_SIZE),
54 });
55 }
56 }
57 Ok(())
58}
59
60#[inline(always)]
71pub(in super::super) unsafe fn read_element_u64<const VLENB: usize>(
72 vregs: &VectorRegisterFile<VLENB>,
73 base_reg: VReg,
74 elem_i: u32,
75 sew: Vsew,
76) -> u64 {
77 let sew_bytes = usize::from(sew.bytes_width());
78 let elems_per_reg = VLENB / sew_bytes;
79 let reg_off = elem_i as usize / elems_per_reg;
80 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
81 let reg = vregs
83 .get(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
84 let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
87 let mut buf = [0u8; 8];
88 unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
90 u64::from_le_bytes(buf)
91}
92
93#[inline(always)]
99pub(in super::super) unsafe fn write_element_u64<const VLENB: usize>(
100 vregs: &mut VectorRegisterFile<VLENB>,
101 base_reg: VReg,
102 elem_i: u32,
103 sew: Vsew,
104 value: u64,
105) {
106 let sew_bytes = usize::from(sew.bytes_width());
107 let elems_per_reg = VLENB / sew_bytes;
108 let reg_off = elem_i as usize / elems_per_reg;
109 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
110 let buf = value.to_le_bytes();
111 let reg = vregs
113 .get_mut(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
114 let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
117 dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
119}
120
121#[inline(always)]
131pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
132 vregs: &mut VectorRegisterFile<VLENB>,
133 vd: VReg,
134 elem_i: u32,
135 result: bool,
136) {
137 let byte_idx = (elem_i / u8::BITS) as usize;
138 let bit_idx = elem_i % u8::BITS;
139 let byte = unsafe { vregs.get_mut(vd).get_unchecked_mut(byte_idx) };
141 if result {
142 *byte |= 1 << bit_idx;
143 } else {
144 *byte &= !(1 << bit_idx);
145 }
146}
147
148#[derive(Debug)]
150#[doc(hidden)]
151pub enum OpSrc {
152 Vreg(VReg),
154 Scalar(u64),
156}
157
158#[inline(always)]
169#[doc(hidden)]
170pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
171 ext_state: &mut ExtState,
172 vd: VReg,
173 vs2: VReg,
174 src: OpSrc,
175 vm: bool,
176 sew: Vsew,
177 op: F,
178) where
179 Reg: Register,
180 ExtState: VectorRegistersExt<Reg, CustomError>,
181 [(); ExtState::ELEN as usize]:,
182 [(); ExtState::VLEN as usize]:,
183 [(); ExtState::VLENB as usize]:,
184 CustomError: fmt::Debug,
185 F: Fn(u64, u64, Vsew) -> u64,
186{
187 let vl = ext_state.vl();
188 let vstart = ext_state.vstart();
189 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
191
192 for i in u32::from(vstart)..vl {
193 if !mask_bit(&mask_buf, i) {
194 continue;
195 }
196
197 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
200
201 let b = match src {
202 OpSrc::Vreg(vs1_base) => {
203 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
205 }
206 OpSrc::Scalar(val) => val,
207 };
208
209 let result = op(a, b, sew);
210
211 unsafe {
214 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
215 }
216 }
217
218 ext_state.mark_vs_dirty();
219 ext_state.reset_vstart();
220}
221
222#[inline(always)]
236#[doc(hidden)]
237pub unsafe fn execute_compare_op<Reg, ExtState, CustomError, F>(
238 ext_state: &mut ExtState,
239 vd: VReg,
240 vs2: VReg,
241 src: OpSrc,
242 vm: bool,
243 sew: Vsew,
244 op: F,
245) where
246 Reg: Register,
247 ExtState: VectorRegistersExt<Reg, CustomError>,
248 [(); ExtState::ELEN as usize]:,
249 [(); ExtState::VLEN as usize]:,
250 [(); ExtState::VLENB as usize]:,
251 CustomError: fmt::Debug,
252 F: Fn(u64, u64, Vsew) -> bool,
253{
254 let vl = ext_state.vl();
255 let vstart = ext_state.vstart();
256 let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
258
259 for i in u32::from(vstart)..vl {
260 if !mask_bit(&mask_buf, i) {
263 continue;
264 }
265
266 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
268
269 let b = match src {
270 OpSrc::Vreg(vs1_base) => {
271 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
273 }
274 OpSrc::Scalar(val) => val,
275 };
276
277 let result = op(a, b, sew);
278
279 unsafe {
281 write_mask_bit(ext_state.write_vregs(), vd, i, result);
282 }
283 }
284
285 ext_state.mark_vs_dirty();
286 ext_state.reset_vstart();
287}
288
289#[inline(always)]
291#[doc(hidden)]
292pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
293 let shift = u64::BITS - u32::from(sew.bits_width());
294 (val.cast_signed() << shift) >> shift
295}
296
297#[inline(always)]
302#[doc(hidden)]
303pub fn sew_mask(sew: Vsew) -> u64 {
304 if u32::from(sew.bits_width()) == u64::BITS {
305 u64::MAX
306 } else {
307 (1u64 << sew.bits_width()) - 1
308 }
309}