Skip to main content

ab_riscv_interpreter/v/zvexx/carry/
zvexx_carry_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5    OpSrc, check_mask_dest_no_overlap, check_vreg_group_alignment,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8    read_element_u64, sew_mask, write_element_u64, write_mask_bit,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
11use ab_riscv_primitives::prelude::*;
12use core::fmt;
13
14// TODO: Safety comment here doesn't make sense
15/// Read a single mask bit from vector register `v0` at element index `i`.
16///
17/// Used to retrieve the per-element carry-in or borrow-in for vadc/vsbc.
18///
19/// # Safety
20/// `i / 8 < VLENB` must hold, guaranteed when `i < vl <= VLEN`.
21#[inline(always)]
22pub(in super::super) unsafe fn carry_bit<const VLENB: usize>(
23    vregs: &VectorRegisterFile<VLENB>,
24    i: u32,
25) -> u64 {
26    let v0 = vregs.get(VReg::V0);
27    u64::from(mask_bit(v0, i))
28}
29
30/// Execute an element-wise add-with-carry over `vstart..vl`, writing SEW-wide data results into
31/// `vd`.
32///
33/// Carry-in for each element is read from `v0[i]` when `with_carry` is true. All elements in
34/// `vstart..vl` are processed unconditionally (no execution mask).
35///
36/// # Safety
37/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
38/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
39/// - `src` register satisfies the same alignment (verified by caller)
40/// - `vd.to_bits() != 0` (vd must not overlap v0, which holds the carry-in)
41/// - `vl <= group_regs * VLENB / sew_bytes`
42#[inline(always)]
43#[doc(hidden)]
44pub unsafe fn execute_carry_add<Reg, ExtState, CustomError>(
45    ext_state: &mut ExtState,
46    vd: VReg,
47    vs2: VReg,
48    src: OpSrc,
49    with_carry: bool,
50    sew: Vsew,
51) where
52    Reg: Register,
53    ExtState: VectorRegistersExt<Reg, CustomError>,
54    [(); ExtState::ELEN as usize]:,
55    [(); ExtState::VLEN as usize]:,
56    [(); ExtState::VLENB as usize]:,
57    CustomError: fmt::Debug,
58{
59    let vl = ext_state.vl();
60    let vstart = ext_state.vstart();
61    for i in u32::from(vstart)..vl {
62        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
63        // `i < vl <= group_regs * elems_per_reg`, so
64        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
65        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
66        let b = match src {
67            OpSrc::Vreg(vs1_base) => {
68                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
69                // constraint as vs2; the index argument is identical, so the same bound holds:
70                // `vs1_base + i / elems_per_reg < 32`
71                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
72            }
73            OpSrc::Scalar(val) => val,
74        };
75        let c = if with_carry {
76            // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
77            unsafe { carry_bit(ext_state.read_vregs(), i) }
78        } else {
79            0
80        };
81
82        // Wrap naturally: write_element_u64 writes only the low sew_bytes
83        let result = a.wrapping_add(b).wrapping_add(c);
84        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
85        // `i < vl <= group_regs * elems_per_reg`, so
86        // `vd + i / elems_per_reg < vd + group_regs <= 32`
87        unsafe {
88            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
89        }
90    }
91
92    ext_state.mark_vs_dirty();
93    ext_state.reset_vstart();
94}
95
96/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing SEW-wide data results
97/// into `vd`.
98///
99/// Borrow-in for each element is read from `v0[i]` (always true for vsbc). All elements in
100/// `vstart..vl` are processed unconditionally.
101///
102/// # Safety
103/// Same as [`execute_carry_add()`].
104#[inline(always)]
105#[doc(hidden)]
106pub unsafe fn execute_carry_sub<Reg, ExtState, CustomError>(
107    ext_state: &mut ExtState,
108    vd: VReg,
109    vs2: VReg,
110    src: OpSrc,
111    sew: Vsew,
112) where
113    Reg: Register,
114    ExtState: VectorRegistersExt<Reg, CustomError>,
115    [(); ExtState::ELEN as usize]:,
116    [(); ExtState::VLEN as usize]:,
117    [(); ExtState::VLENB as usize]:,
118    CustomError: fmt::Debug,
119{
120    let vl = ext_state.vl();
121    let vstart = ext_state.vstart();
122    for i in u32::from(vstart)..vl {
123        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
124        // `i < vl <= group_regs * elems_per_reg`, so
125        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
126        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
127        let b = match src {
128            OpSrc::Vreg(vs1_base) => {
129                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
130                // constraint as vs2; the index argument is identical, so the same bound holds:
131                // `vs1_base + i / elems_per_reg < 32`
132                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
133            }
134            OpSrc::Scalar(val) => val,
135        };
136        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
137        let borrow = unsafe { carry_bit(ext_state.read_vregs(), i) };
138
139        let result = a.wrapping_sub(b).wrapping_sub(borrow);
140        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
141        // `i < vl <= group_regs * elems_per_reg`, so
142        // `vd + i / elems_per_reg < vd + group_regs <= 32`
143        unsafe {
144            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
145        }
146    }
147
148    ext_state.mark_vs_dirty();
149    ext_state.reset_vstart();
150}
151
152/// Execute an element-wise add-with-carry over `vstart..vl`, writing the carry-out as a single mask
153/// bit per element into `vd`.
154///
155/// When `with_carry` is true, carry-in for element `i` is read from `v0[i]`. When false, carry-in
156/// is treated as zero.
157///
158/// All elements are processed unconditionally (no execution mask).
159///
160/// Tail mask bits (indices `>= vl`) are left undisturbed per spec ยง5.3.
161///
162/// # Safety
163/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
164/// - `src` register satisfies the same alignment
165/// - `vl <= group_regs * VLENB / sew_bytes` and `vl <= VLEN`
166/// - vd overlap constraints checked by caller
167#[inline(always)]
168#[doc(hidden)]
169pub unsafe fn execute_carry_add_mask<Reg, ExtState, CustomError>(
170    ext_state: &mut ExtState,
171    vd: VReg,
172    vs2: VReg,
173    src: OpSrc,
174    with_carry: bool,
175    sew: Vsew,
176) where
177    Reg: Register,
178    ExtState: VectorRegistersExt<Reg, CustomError>,
179    [(); ExtState::ELEN as usize]:,
180    [(); ExtState::VLEN as usize]:,
181    [(); ExtState::VLENB as usize]:,
182    CustomError: fmt::Debug,
183{
184    let vl = ext_state.vl();
185    let vstart = ext_state.vstart();
186    let mask = sew_mask(sew);
187
188    for i in u32::from(vstart)..vl {
189        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
190        // `i < vl <= group_regs * elems_per_reg`, so
191        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
192        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
193        let b = match src {
194            OpSrc::Vreg(vs1_base) => {
195                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
196                // constraint as vs2; the index argument is identical, so the same bound holds:
197                // `vs1_base + i / elems_per_reg < 32`
198                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
199            }
200            OpSrc::Scalar(val) => val,
201        };
202        let c = if with_carry {
203            // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
204            unsafe { carry_bit(ext_state.read_vregs(), i) }
205        } else {
206            0
207        };
208
209        // Use u128 to capture the carry-out bit beyond SEW
210        let sum = u128::from(a & mask) + u128::from(b & mask) + u128::from(c);
211        let carry_out = (sum >> sew.bits_width()) & 1 != 0;
212
213        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
214        unsafe {
215            write_mask_bit(ext_state.write_vregs(), vd, i, carry_out);
216        }
217    }
218
219    ext_state.mark_vs_dirty();
220    ext_state.reset_vstart();
221}
222
223/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing the borrow-out as a
224/// single mask bit per element into `vd`.
225///
226/// When `with_borrow` is true, borrow-in for element `i` is read from `v0[i]`. When false,
227/// borrow-in is treated as zero.
228///
229/// Borrow-out is 1 when the subtraction underflows unsigned:
230/// `borrow_out = (b + borrow_in) > a` (compared as SEW-wide unsigned values).
231///
232/// # Safety
233/// Same as [`execute_carry_add_mask()`].
234#[inline(always)]
235#[doc(hidden)]
236pub unsafe fn execute_carry_sub_mask<Reg, ExtState, CustomError>(
237    ext_state: &mut ExtState,
238    vd: VReg,
239    vs2: VReg,
240    src: OpSrc,
241    with_borrow: bool,
242    sew: Vsew,
243) where
244    Reg: Register,
245    ExtState: VectorRegistersExt<Reg, CustomError>,
246    [(); ExtState::ELEN as usize]:,
247    [(); ExtState::VLEN as usize]:,
248    [(); ExtState::VLENB as usize]:,
249    CustomError: fmt::Debug,
250{
251    let vl = ext_state.vl();
252    let vstart = ext_state.vstart();
253    let mask = sew_mask(sew);
254
255    for i in u32::from(vstart)..vl {
256        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
257        // `i < vl <= group_regs * elems_per_reg`, so
258        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
259        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
260        let b = match src {
261            OpSrc::Vreg(vs1_base) => {
262                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
263                // constraint as vs2; the index argument is identical, so the same bound holds:
264                // `vs1_base + i / elems_per_reg < 32`
265                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
266            }
267            OpSrc::Scalar(val) => val,
268        };
269        let borrow_in = if with_borrow {
270            // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
271            unsafe { carry_bit(ext_state.read_vregs(), i) }
272        } else {
273            0
274        };
275
276        let a_m = u128::from(a & mask);
277        let rhs = u128::from(b & mask) + u128::from(borrow_in);
278        let borrow_out = a_m < rhs;
279
280        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
281        unsafe {
282            write_mask_bit(ext_state.write_vregs(), vd, i, borrow_out);
283        }
284    }
285
286    ext_state.mark_vs_dirty();
287    ext_state.reset_vstart();
288}