ab_riscv_interpreter/v/zvexx/carry/zvexx_carry_helpers.rs
1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5 OpSrc, check_mask_dest_no_overlap, check_vreg_group_alignment,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8 read_element_u64, sew_mask, write_element_u64, write_mask_bit,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
11use ab_riscv_primitives::prelude::*;
12use core::fmt;
13
14// TODO: Safety comment here doesn't make sense
15/// Read a single mask bit from vector register `v0` at element index `i`.
16///
17/// Used to retrieve the per-element carry-in or borrow-in for vadc/vsbc.
18///
19/// # Safety
20/// `i / 8 < VLENB` must hold, guaranteed when `i < vl <= VLEN`.
21#[inline(always)]
22pub(in super::super) unsafe fn carry_bit<const VLENB: usize>(
23 vregs: &VectorRegisterFile<VLENB>,
24 i: u32,
25) -> u64 {
26 let v0 = vregs.get(VReg::V0);
27 u64::from(mask_bit(v0, i))
28}
29
30/// Execute an element-wise add-with-carry over `vstart..vl`, writing SEW-wide data results into
31/// `vd`.
32///
33/// Carry-in for each element is read from `v0[i]` when `with_carry` is true. All elements in
34/// `vstart..vl` are processed unconditionally (no execution mask).
35///
36/// # Safety
37/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
38/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
39/// - `src` register satisfies the same alignment (verified by caller)
40/// - `vd.to_bits() != 0` (vd must not overlap v0, which holds the carry-in)
41/// - `vl <= group_regs * VLENB / sew_bytes`
42#[inline(always)]
43#[doc(hidden)]
44pub unsafe fn execute_carry_add<Reg, ExtState, CustomError>(
45 ext_state: &mut ExtState,
46 vd: VReg,
47 vs2: VReg,
48 src: OpSrc,
49 with_carry: bool,
50 sew: Vsew,
51) where
52 Reg: Register,
53 ExtState: VectorRegistersExt<Reg, CustomError>,
54 [(); ExtState::ELEN as usize]:,
55 [(); ExtState::VLEN as usize]:,
56 [(); ExtState::VLENB as usize]:,
57 CustomError: fmt::Debug,
58{
59 let vl = ext_state.vl();
60 let vstart = ext_state.vstart();
61 for i in u32::from(vstart)..vl {
62 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
63 // `i < vl <= group_regs * elems_per_reg`, so
64 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
65 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
66 let b = match src {
67 OpSrc::Vreg(vs1_base) => {
68 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
69 // constraint as vs2; the index argument is identical, so the same bound holds:
70 // `vs1_base + i / elems_per_reg < 32`
71 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
72 }
73 OpSrc::Scalar(val) => val,
74 };
75 let c = if with_carry {
76 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
77 unsafe { carry_bit(ext_state.read_vregs(), i) }
78 } else {
79 0
80 };
81
82 // Wrap naturally: write_element_u64 writes only the low sew_bytes
83 let result = a.wrapping_add(b).wrapping_add(c);
84 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
85 // `i < vl <= group_regs * elems_per_reg`, so
86 // `vd + i / elems_per_reg < vd + group_regs <= 32`
87 unsafe {
88 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
89 }
90 }
91
92 ext_state.mark_vs_dirty();
93 ext_state.reset_vstart();
94}
95
96/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing SEW-wide data results
97/// into `vd`.
98///
99/// Borrow-in for each element is read from `v0[i]` (always true for vsbc). All elements in
100/// `vstart..vl` are processed unconditionally.
101///
102/// # Safety
103/// Same as [`execute_carry_add()`].
104#[inline(always)]
105#[doc(hidden)]
106pub unsafe fn execute_carry_sub<Reg, ExtState, CustomError>(
107 ext_state: &mut ExtState,
108 vd: VReg,
109 vs2: VReg,
110 src: OpSrc,
111 sew: Vsew,
112) where
113 Reg: Register,
114 ExtState: VectorRegistersExt<Reg, CustomError>,
115 [(); ExtState::ELEN as usize]:,
116 [(); ExtState::VLEN as usize]:,
117 [(); ExtState::VLENB as usize]:,
118 CustomError: fmt::Debug,
119{
120 let vl = ext_state.vl();
121 let vstart = ext_state.vstart();
122 for i in u32::from(vstart)..vl {
123 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
124 // `i < vl <= group_regs * elems_per_reg`, so
125 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
126 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
127 let b = match src {
128 OpSrc::Vreg(vs1_base) => {
129 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
130 // constraint as vs2; the index argument is identical, so the same bound holds:
131 // `vs1_base + i / elems_per_reg < 32`
132 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
133 }
134 OpSrc::Scalar(val) => val,
135 };
136 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
137 let borrow = unsafe { carry_bit(ext_state.read_vregs(), i) };
138
139 let result = a.wrapping_sub(b).wrapping_sub(borrow);
140 // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
141 // `i < vl <= group_regs * elems_per_reg`, so
142 // `vd + i / elems_per_reg < vd + group_regs <= 32`
143 unsafe {
144 write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
145 }
146 }
147
148 ext_state.mark_vs_dirty();
149 ext_state.reset_vstart();
150}
151
152/// Execute an element-wise add-with-carry over `vstart..vl`, writing the carry-out as a single mask
153/// bit per element into `vd`.
154///
155/// When `with_carry` is true, carry-in for element `i` is read from `v0[i]`. When false, carry-in
156/// is treated as zero.
157///
158/// All elements are processed unconditionally (no execution mask).
159///
160/// Tail mask bits (indices `>= vl`) are left undisturbed per spec ยง5.3.
161///
162/// # Safety
163/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
164/// - `src` register satisfies the same alignment
165/// - `vl <= group_regs * VLENB / sew_bytes` and `vl <= VLEN`
166/// - vd overlap constraints checked by caller
167#[inline(always)]
168#[doc(hidden)]
169pub unsafe fn execute_carry_add_mask<Reg, ExtState, CustomError>(
170 ext_state: &mut ExtState,
171 vd: VReg,
172 vs2: VReg,
173 src: OpSrc,
174 with_carry: bool,
175 sew: Vsew,
176) where
177 Reg: Register,
178 ExtState: VectorRegistersExt<Reg, CustomError>,
179 [(); ExtState::ELEN as usize]:,
180 [(); ExtState::VLEN as usize]:,
181 [(); ExtState::VLENB as usize]:,
182 CustomError: fmt::Debug,
183{
184 let vl = ext_state.vl();
185 let vstart = ext_state.vstart();
186 let mask = sew_mask(sew);
187
188 for i in u32::from(vstart)..vl {
189 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
190 // `i < vl <= group_regs * elems_per_reg`, so
191 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
192 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
193 let b = match src {
194 OpSrc::Vreg(vs1_base) => {
195 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
196 // constraint as vs2; the index argument is identical, so the same bound holds:
197 // `vs1_base + i / elems_per_reg < 32`
198 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
199 }
200 OpSrc::Scalar(val) => val,
201 };
202 let c = if with_carry {
203 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
204 unsafe { carry_bit(ext_state.read_vregs(), i) }
205 } else {
206 0
207 };
208
209 // Use u128 to capture the carry-out bit beyond SEW
210 let sum = u128::from(a & mask) + u128::from(b & mask) + u128::from(c);
211 let carry_out = (sum >> sew.bits_width()) & 1 != 0;
212
213 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
214 unsafe {
215 write_mask_bit(ext_state.write_vregs(), vd, i, carry_out);
216 }
217 }
218
219 ext_state.mark_vs_dirty();
220 ext_state.reset_vstart();
221}
222
223/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing the borrow-out as a
224/// single mask bit per element into `vd`.
225///
226/// When `with_borrow` is true, borrow-in for element `i` is read from `v0[i]`. When false,
227/// borrow-in is treated as zero.
228///
229/// Borrow-out is 1 when the subtraction underflows unsigned:
230/// `borrow_out = (b + borrow_in) > a` (compared as SEW-wide unsigned values).
231///
232/// # Safety
233/// Same as [`execute_carry_add_mask()`].
234#[inline(always)]
235#[doc(hidden)]
236pub unsafe fn execute_carry_sub_mask<Reg, ExtState, CustomError>(
237 ext_state: &mut ExtState,
238 vd: VReg,
239 vs2: VReg,
240 src: OpSrc,
241 with_borrow: bool,
242 sew: Vsew,
243) where
244 Reg: Register,
245 ExtState: VectorRegistersExt<Reg, CustomError>,
246 [(); ExtState::ELEN as usize]:,
247 [(); ExtState::VLEN as usize]:,
248 [(); ExtState::VLENB as usize]:,
249 CustomError: fmt::Debug,
250{
251 let vl = ext_state.vl();
252 let vstart = ext_state.vstart();
253 let mask = sew_mask(sew);
254
255 for i in u32::from(vstart)..vl {
256 // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
257 // `i < vl <= group_regs * elems_per_reg`, so
258 // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
259 let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
260 let b = match src {
261 OpSrc::Vreg(vs1_base) => {
262 // SAFETY: caller verified that the vs1 register group satisfies the same alignment
263 // constraint as vs2; the index argument is identical, so the same bound holds:
264 // `vs1_base + i / elems_per_reg < 32`
265 unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
266 }
267 OpSrc::Scalar(val) => val,
268 };
269 let borrow_in = if with_borrow {
270 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
271 unsafe { carry_bit(ext_state.read_vregs(), i) }
272 } else {
273 0
274 };
275
276 let a_m = u128::from(a & mask);
277 let rhs = u128::from(b & mask) + u128::from(borrow_in);
278 let borrow_out = a_m < rhs;
279
280 // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
281 unsafe {
282 write_mask_bit(ext_state.write_vregs(), vd, i, borrow_out);
283 }
284 }
285
286 ext_state.mark_vs_dirty();
287 ext_state.reset_vstart();
288}