ab_riscv_interpreter/v/zvexx/carry/zvexx_carry_helpers.rs
1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5 OpSrc, check_mask_dest_no_overlap, check_vreg_group_alignment,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8 read_element_u64, sew_mask, write_element_u64, write_mask_bit,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
11use ab_riscv_primitives::prelude::*;
12use core::fmt;
13
14// TODO: Safety comment here doesn't make sense
15/// Read a single mask bit from vector register `v0` at element index `i`.
16///
17/// Used to retrieve the per-element carry-in or borrow-in for vadc/vsbc.
18///
19/// # Safety
20/// `i / 8 < VLENB` must hold, guaranteed when `i < vl <= VLEN`.
21#[inline(always)]
22pub(in super::super) unsafe fn carry_bit<const VLENB: usize>(
23 vregs: &VectorRegisterFile<VLENB>,
24 i: u32,
25) -> u64 {
26 let v0 = vregs.get(VReg::V0);
27 u64::from(mask_bit(v0, i))
28}
29
30/// Execute an element-wise add-with-carry over `vstart..vl`, writing SEW-wide data results into
31/// `vd`.
32///
33/// Carry-in for each element is read from `v0[i]` when `WITH_CARRY` is true. All elements in
34/// `vstart..vl` are processed unconditionally (no execution mask).
35///
36/// # Safety
37/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
38/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
39/// - `src` register satisfies the same alignment (verified by caller)
40/// - `vd.to_bits() != 0` (vd must not overlap v0, which holds the carry-in)
41/// - `vl <= group_regs * VLENB / sew_bytes`
42#[inline(always)]
43#[doc(hidden)]
44pub unsafe fn execute_carry_add<const WITH_CARRY: bool, Reg, ExtState, CustomError>(
45 ext_state: &mut ExtState,
46 vd: VReg,
47 vs2: VReg,
48 src: OpSrc,
49 sew: Vsew,
50) where
51 Reg: Register,
52 ExtState: VectorRegistersExt<Reg, CustomError>,
53 [(); ExtState::ELEN as usize]:,
54 [(); ExtState::VLEN as usize]:,
55 [(); ExtState::VLENB as usize]:,
56 CustomError: fmt::Debug,
57{
58 let vl = ext_state.vl();
59 let vstart = ext_state.vstart();
60 for i in u32::from(vstart)..vl {
61 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
62 // `i < vl <= group_regs * elems_per_reg`, so
63 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
64 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
65 let b = match src {
66 OpSrc::Vreg(vs1_base) => {
67 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
68 // constraint as vs2; the index argument is identical, so the same bound holds:
69 // `vs1_base + i / elems_per_reg < 32`
70 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
71 }
72 OpSrc::Scalar(val) => val,
73 };
74 let c = if WITH_CARRY {
75 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
76 unsafe { carry_bit(ext_state.read_vregs(), i) }
77 } else {
78 0
79 };
80
81 // Wrap naturally: write_element_u64 writes only the low sew_bytes
82 let result = a.wrapping_add(b).wrapping_add(c);
83 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
84 // `i < vl <= group_regs * elems_per_reg`, so
85 // `vd + i / elems_per_reg < vd + group_regs <= 32`
86 unsafe {
87 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
88 }
89 }
90
91 ext_state.mark_vs_dirty();
92 ext_state.reset_vstart();
93}
94
95/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing SEW-wide data results
96/// into `vd`.
97///
98/// Borrow-in for each element is read from `v0[i]` (always true for vsbc). All elements in
99/// `vstart..vl` are processed unconditionally.
100///
101/// # Safety
102/// Same as [`execute_carry_add()`].
103#[inline(always)]
104#[doc(hidden)]
105pub unsafe fn execute_carry_sub<Reg, ExtState, CustomError>(
106 ext_state: &mut ExtState,
107 vd: VReg,
108 vs2: VReg,
109 src: OpSrc,
110 sew: Vsew,
111) where
112 Reg: Register,
113 ExtState: VectorRegistersExt<Reg, CustomError>,
114 [(); ExtState::ELEN as usize]:,
115 [(); ExtState::VLEN as usize]:,
116 [(); ExtState::VLENB as usize]:,
117 CustomError: fmt::Debug,
118{
119 let vl = ext_state.vl();
120 let vstart = ext_state.vstart();
121 for i in u32::from(vstart)..vl {
122 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
123 // `i < vl <= group_regs * elems_per_reg`, so
124 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
125 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
126 let b = match src {
127 OpSrc::Vreg(vs1_base) => {
128 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
129 // constraint as vs2; the index argument is identical, so the same bound holds:
130 // `vs1_base + i / elems_per_reg < 32`
131 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
132 }
133 OpSrc::Scalar(val) => val,
134 };
135 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
136 let borrow = unsafe { carry_bit(ext_state.read_vregs(), i) };
137
138 let result = a.wrapping_sub(b).wrapping_sub(borrow);
139 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
140 // `i < vl <= group_regs * elems_per_reg`, so
141 // `vd + i / elems_per_reg < vd + group_regs <= 32`
142 unsafe {
143 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
144 }
145 }
146
147 ext_state.mark_vs_dirty();
148 ext_state.reset_vstart();
149}
150
151/// Execute an element-wise add-with-carry over `vstart..vl`, writing the carry-out as a single mask
152/// bit per element into `vd`.
153///
154/// When `WITH_CARRY` is true, carry-in for element `i` is read from `v0[i]`. When false, carry-in
155/// is treated as zero.
156///
157/// All elements are processed unconditionally (no execution mask).
158///
159/// Tail mask bits (indices `>= vl`) are left undisturbed per spec ยง5.3.
160///
161/// # Safety
162/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
163/// - `src` register satisfies the same alignment
164/// - `vl <= group_regs * VLENB / sew_bytes` and `vl <= VLEN`
165/// - vd overlap constraints checked by caller
166#[inline(always)]
167#[doc(hidden)]
168pub unsafe fn execute_carry_add_mask<const WITH_CARRY: bool, Reg, ExtState, CustomError>(
169 ext_state: &mut ExtState,
170 vd: VReg,
171 vs2: VReg,
172 src: OpSrc,
173 sew: Vsew,
174) where
175 Reg: Register,
176 ExtState: VectorRegistersExt<Reg, CustomError>,
177 [(); ExtState::ELEN as usize]:,
178 [(); ExtState::VLEN as usize]:,
179 [(); ExtState::VLENB as usize]:,
180 CustomError: fmt::Debug,
181{
182 let vl = ext_state.vl();
183 let vstart = ext_state.vstart();
184 let mask = sew_mask(sew);
185
186 for i in u32::from(vstart)..vl {
187 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
188 // `i < vl <= group_regs * elems_per_reg`, so
189 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
190 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
191 let b = match src {
192 OpSrc::Vreg(vs1_base) => {
193 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
194 // constraint as vs2; the index argument is identical, so the same bound holds:
195 // `vs1_base + i / elems_per_reg < 32`
196 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
197 }
198 OpSrc::Scalar(val) => val,
199 };
200 let c = if WITH_CARRY {
201 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
202 unsafe { carry_bit(ext_state.read_vregs(), i) }
203 } else {
204 0
205 };
206
207 // Use u128 to capture the carry-out bit beyond SEW
208 let sum = u128::from(a & mask) + u128::from(b & mask) + u128::from(c);
209 let carry_out = (sum >> sew.bits_width()) & 1 != 0;
210
211 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
212 unsafe {
213 write_mask_bit(ext_state.write_vregs(), vd, i, carry_out);
214 }
215 }
216
217 ext_state.mark_vs_dirty();
218 ext_state.reset_vstart();
219}
220
221/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing the borrow-out as a
222/// single mask bit per element into `vd`.
223///
224/// When `WITH_BORROW` is true, borrow-in for element `i` is read from `v0[i]`. When false,
225/// borrow-in is treated as zero.
226///
227/// Borrow-out is 1 when the subtraction underflows unsigned:
228/// `borrow_out = (b + borrow_in) > a` (compared as SEW-wide unsigned values).
229///
230/// # Safety
231/// Same as [`execute_carry_add_mask()`].
232#[inline(always)]
233#[doc(hidden)]
234pub unsafe fn execute_carry_sub_mask<const WITH_BORROW: bool, Reg, ExtState, CustomError>(
235 ext_state: &mut ExtState,
236 vd: VReg,
237 vs2: VReg,
238 src: OpSrc,
239 sew: Vsew,
240) where
241 Reg: Register,
242 ExtState: VectorRegistersExt<Reg, CustomError>,
243 [(); ExtState::ELEN as usize]:,
244 [(); ExtState::VLEN as usize]:,
245 [(); ExtState::VLENB as usize]:,
246 CustomError: fmt::Debug,
247{
248 let vl = ext_state.vl();
249 let vstart = ext_state.vstart();
250 let mask = sew_mask(sew);
251
252 for i in u32::from(vstart)..vl {
253 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
254 // `i < vl <= group_regs * elems_per_reg`, so
255 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
256 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
257 let b = match src {
258 OpSrc::Vreg(vs1_base) => {
259 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
260 // constraint as vs2; the index argument is identical, so the same bound holds:
261 // `vs1_base + i / elems_per_reg < 32`
262 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
263 }
264 OpSrc::Scalar(val) => val,
265 };
266 let borrow_in = if WITH_BORROW {
267 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
268 unsafe { carry_bit(ext_state.read_vregs(), i) }
269 } else {
270 0
271 };
272
273 let a_m = u128::from(a & mask);
274 let rhs = u128::from(b & mask) + u128::from(borrow_in);
275 let borrow_out = a_m < rhs;
276
277 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
278 unsafe {
279 write_mask_bit(ext_state.write_vregs(), vd, i, borrow_out);
280 }
281 }
282
283 ext_state.mark_vs_dirty();
284 ext_state.reset_vstart();
285}