// ab_riscv_interpreter/v/zve64x/arith/zve64x_arith_helpers.rs
use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, InterpreterState, ProgramCounter, VirtualMemory};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::registers::general_purpose::Register;
9use ab_riscv_primitives::registers::vector::VReg;
10use core::fmt;
11
12#[inline(always)]
14#[doc(hidden)]
15pub fn check_vreg_group_alignment<Reg, ExtState, Memory, PC, IH, CustomError>(
16 state: &InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
17 vreg: VReg,
18 group_regs: u8,
19) -> Result<(), ExecutionError<Reg::Type, CustomError>>
20where
21 Reg: Register,
22 [(); Reg::N]:,
23 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
24{
25 let vd_idx = vreg.bits();
26 if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
27 return Err(ExecutionError::IllegalInstruction {
28 address: state.instruction_fetcher.old_pc(INSTRUCTION_SIZE),
29 });
30 }
31 Ok(())
32}
33
34#[inline(always)]
45pub(in super::super) unsafe fn read_element_u64<const VLENB: usize>(
46 vreg: &[[u8; VLENB]; 32],
47 base_reg: usize,
48 elem_i: u32,
49 sew: Vsew,
50) -> u64 {
51 let sew_bytes = usize::from(sew.bytes());
52 let elems_per_reg = VLENB / sew_bytes;
53 let reg_off = elem_i as usize / elems_per_reg;
54 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
55 let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
57 let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
60 let mut buf = [0u8; 8];
61 unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
63 u64::from_le_bytes(buf)
64}
65
66#[inline(always)]
72pub(in super::super) unsafe fn write_element_u64<const VLENB: usize>(
73 vreg: &mut [[u8; VLENB]; 32],
74 base_reg: u8,
75 elem_i: u32,
76 sew: Vsew,
77 value: u64,
78) {
79 let sew_bytes = usize::from(sew.bytes());
80 let elems_per_reg = VLENB / sew_bytes;
81 let reg_off = elem_i as usize / elems_per_reg;
82 let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
83 let buf = value.to_le_bytes();
84 let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
86 let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
89 dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
91}
92
93#[inline(always)]
103pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
104 vreg: &mut [[u8; VLENB]; 32],
105 vd: VReg,
106 elem_i: u32,
107 result: bool,
108) {
109 let byte_idx = (elem_i / u8::BITS) as usize;
110 let bit_idx = elem_i % u8::BITS;
111 let byte = unsafe {
113 vreg.get_unchecked_mut(usize::from(vd.bits()))
114 .get_unchecked_mut(byte_idx)
115 };
116 if result {
117 *byte |= 1 << bit_idx;
118 } else {
119 *byte &= !(1 << bit_idx);
120 }
121}
122
/// Second-operand source for an element-wise vector operation: either another
/// vector register group (vector-vector form) or a fixed scalar value
/// (vector-scalar / vector-immediate forms).
#[derive(Debug)]
#[doc(hidden)]
pub enum OpSrc {
    // Base register index of the `vs1` operand group.
    Vreg(u8),
    // Scalar value used directly as the second operand for every element.
    Scalar(u64),
}
132
133#[inline(always)]
144#[expect(clippy::too_many_arguments, reason = "Internal API")]
145#[doc(hidden)]
146pub unsafe fn execute_arith_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
147 state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
148 vd: VReg,
149 vs2: VReg,
150 src: OpSrc,
151 vm: bool,
152 vl: u32,
153 vstart: u32,
154 sew: Vsew,
155 op: F,
156) where
157 Reg: Register,
158 [(); Reg::N]:,
159 ExtState: VectorRegistersExt<Reg, CustomError>,
160 [(); ExtState::ELEN as usize]:,
161 [(); ExtState::VLEN as usize]:,
162 [(); ExtState::VLENB as usize]:,
163 Memory: VirtualMemory,
164 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
165 CustomError: fmt::Debug,
166 F: Fn(u64, u64, Vsew) -> u64,
167{
168 let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
170
171 let vd_base = vd.bits();
172 let vs2_base = vs2.bits();
173
174 for i in vstart..vl {
175 if !mask_bit(&mask_buf, i) {
176 continue;
177 }
178
179 let a =
182 unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
183
184 let b = match &src {
185 OpSrc::Vreg(vs1_base) => {
186 unsafe {
188 read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
189 }
190 }
191 OpSrc::Scalar(val) => *val,
192 };
193
194 let result = op(a, b, sew);
195
196 unsafe {
199 write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
200 }
201 }
202
203 state.ext_state.mark_vs_dirty();
204 state.ext_state.reset_vstart();
205}
206
207#[inline(always)]
221#[expect(clippy::too_many_arguments, reason = "Internal API")]
222#[doc(hidden)]
223pub unsafe fn execute_compare_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
224 state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
225 vd: VReg,
226 vs2: VReg,
227 src: OpSrc,
228 vm: bool,
229 vl: u32,
230 vstart: u32,
231 sew: Vsew,
232 op: F,
233) where
234 Reg: Register,
235 [(); Reg::N]:,
236 ExtState: VectorRegistersExt<Reg, CustomError>,
237 [(); ExtState::ELEN as usize]:,
238 [(); ExtState::VLEN as usize]:,
239 [(); ExtState::VLENB as usize]:,
240 Memory: VirtualMemory,
241 PC: ProgramCounter<Reg::Type, Memory, CustomError>,
242 CustomError: fmt::Debug,
243 F: Fn(u64, u64, Vsew) -> bool,
244{
245 let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
247
248 let vs2_base = vs2.bits();
249
250 for i in vstart..vl {
251 if !mask_bit(&mask_buf, i) {
254 continue;
255 }
256
257 let a =
259 unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
260
261 let b = match &src {
262 OpSrc::Vreg(vs1_base) => {
263 unsafe {
265 read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
266 }
267 }
268 OpSrc::Scalar(val) => *val,
269 };
270
271 let result = op(a, b, sew);
272
273 unsafe {
275 write_mask_bit(state.ext_state.write_vreg(), vd, i, result);
276 }
277 }
278
279 state.ext_state.mark_vs_dirty();
280 state.ext_state.reset_vstart();
281}
282
283#[inline(always)]
285#[doc(hidden)]
286pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
287 let shift = u64::BITS - u32::from(sew.bits());
288 (val.cast_signed() << shift) >> shift
289}
290
291#[inline(always)]
296#[doc(hidden)]
297pub fn sew_mask(sew: Vsew) -> u64 {
298 if u32::from(sew.bits()) == u64::BITS {
299 u64::MAX
300 } else {
301 (1u64 << sew.bits()) - 1
302 }
303}