Skip to main content

ab_riscv_interpreter/v/zvexx/carry/
zvexx_carry_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4pub use crate::v::zvexx::arith::zvexx_arith_helpers::{
5    OpSrc, check_mask_dest_no_overlap, check_vreg_group_alignment,
6};
7use crate::v::zvexx::arith::zvexx_arith_helpers::{
8    read_element_u64, sew_mask, write_element_u64, write_mask_bit,
9};
10use crate::v::zvexx::load::zvexx_load_helpers::mask_bit;
11use ab_riscv_primitives::prelude::*;
12use core::fmt;
13
14// TODO: Safety comment here doesn't make sense
15/// Read a single mask bit from vector register `v0` at element index `i`.
16///
17/// Used to retrieve the per-element carry-in or borrow-in for vadc/vsbc.
18///
19/// # Safety
20/// `i / 8 < VLENB` must hold, guaranteed when `i < vl <= VLEN`.
21#[inline(always)]
22pub(in super::super) unsafe fn carry_bit<const VLENB: usize>(
23    vregs: &VectorRegisterFile<VLENB>,
24    i: u32,
25) -> u64 {
26    let v0 = vregs.get(VReg::V0);
27    u64::from(mask_bit(v0, i))
28}
29
30/// Execute an element-wise add-with-carry over `vstart..vl`, writing SEW-wide data results into
31/// `vd`.
32///
33/// Carry-in for each element is read from `v0[i]` when `WITH_CARRY` is true. All elements in
34/// `vstart..vl` are processed unconditionally (no execution mask).
35///
36/// # Safety
37/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32`
38/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
39/// - `src` register satisfies the same alignment (verified by caller)
40/// - `vd.to_bits() != 0` (vd must not overlap v0, which holds the carry-in)
41/// - `vl <= group_regs * VLENB / sew_bytes`
42#[inline(always)]
43#[doc(hidden)]
44pub unsafe fn execute_carry_add<const WITH_CARRY: bool, Reg, ExtState, CustomError>(
45    ext_state: &mut ExtState,
46    vd: VReg,
47    vs2: VReg,
48    src: OpSrc,
49    sew: Vsew,
50) where
51    Reg: Register,
52    ExtState: VectorRegistersExt<Reg, CustomError>,
53    [(); ExtState::ELEN as usize]:,
54    [(); ExtState::VLEN as usize]:,
55    [(); ExtState::VLENB as usize]:,
56    CustomError: fmt::Debug,
57{
58    let vl = ext_state.vl();
59    let vstart = ext_state.vstart();
60    for i in u32::from(vstart)..vl {
61        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
62        // `i < vl <= group_regs * elems_per_reg`, so
63        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
64        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
65        let b = match src {
66            OpSrc::Vreg(vs1_base) => {
67                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
68                // constraint as vs2; the index argument is identical, so the same bound holds:
69                // `vs1_base + i / elems_per_reg < 32`
70                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
71            }
72            OpSrc::Scalar(val) => val,
73        };
74        let c = if WITH_CARRY {
75            // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
76            unsafe { carry_bit(ext_state.read_vregs(), i) }
77        } else {
78            0
79        };
80
81        // Wrap naturally: write_element_u64 writes only the low sew_bytes
82        let result = a.wrapping_add(b).wrapping_add(c);
83        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
84        // `i < vl <= group_regs * elems_per_reg`, so
85        // `vd + i / elems_per_reg < vd + group_regs <= 32`
86        unsafe {
87            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
88        }
89    }
90
91    ext_state.mark_vs_dirty();
92    ext_state.reset_vstart();
93}
94
95/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing SEW-wide data results
96/// into `vd`.
97///
98/// Borrow-in for each element is read from `v0[i]` (always true for vsbc). All elements in
99/// `vstart..vl` are processed unconditionally.
100///
101/// # Safety
102/// Same as [`execute_carry_add()`].
103#[inline(always)]
104#[doc(hidden)]
105pub unsafe fn execute_carry_sub<Reg, ExtState, CustomError>(
106    ext_state: &mut ExtState,
107    vd: VReg,
108    vs2: VReg,
109    src: OpSrc,
110    sew: Vsew,
111) where
112    Reg: Register,
113    ExtState: VectorRegistersExt<Reg, CustomError>,
114    [(); ExtState::ELEN as usize]:,
115    [(); ExtState::VLEN as usize]:,
116    [(); ExtState::VLENB as usize]:,
117    CustomError: fmt::Debug,
118{
119    let vl = ext_state.vl();
120    let vstart = ext_state.vstart();
121    for i in u32::from(vstart)..vl {
122        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
123        // `i < vl <= group_regs * elems_per_reg`, so
124        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
125        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
126        let b = match src {
127            OpSrc::Vreg(vs1_base) => {
128                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
129                // constraint as vs2; the index argument is identical, so the same bound holds:
130                // `vs1_base + i / elems_per_reg < 32`
131                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
132            }
133            OpSrc::Scalar(val) => val,
134        };
135        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
136        let borrow = unsafe { carry_bit(ext_state.read_vregs(), i) };
137
138        let result = a.wrapping_sub(b).wrapping_sub(borrow);
139        // SAFETY: `vd % group_regs == 0` and `vd + group_regs <= 32` (caller precondition);
140        // `i < vl <= group_regs * elems_per_reg`, so
141        // `vd + i / elems_per_reg < vd + group_regs <= 32`
142        unsafe {
143            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
144        }
145    }
146
147    ext_state.mark_vs_dirty();
148    ext_state.reset_vstart();
149}
150
151/// Execute an element-wise add-with-carry over `vstart..vl`, writing the carry-out as a single mask
152/// bit per element into `vd`.
153///
154/// When `WITH_CARRY` is true, carry-in for element `i` is read from `v0[i]`. When false, carry-in
155/// is treated as zero.
156///
157/// All elements are processed unconditionally (no execution mask).
158///
159/// Tail mask bits (indices `>= vl`) are left undisturbed per spec ยง5.3.
160///
161/// # Safety
162/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32`
163/// - `src` register satisfies the same alignment
164/// - `vl <= group_regs * VLENB / sew_bytes` and `vl <= VLEN`
165/// - vd overlap constraints checked by caller
166#[inline(always)]
167#[doc(hidden)]
168pub unsafe fn execute_carry_add_mask<const WITH_CARRY: bool, Reg, ExtState, CustomError>(
169    ext_state: &mut ExtState,
170    vd: VReg,
171    vs2: VReg,
172    src: OpSrc,
173    sew: Vsew,
174) where
175    Reg: Register,
176    ExtState: VectorRegistersExt<Reg, CustomError>,
177    [(); ExtState::ELEN as usize]:,
178    [(); ExtState::VLEN as usize]:,
179    [(); ExtState::VLENB as usize]:,
180    CustomError: fmt::Debug,
181{
182    let vl = ext_state.vl();
183    let vstart = ext_state.vstart();
184    let mask = sew_mask(sew);
185
186    for i in u32::from(vstart)..vl {
187        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
188        // `i < vl <= group_regs * elems_per_reg`, so
189        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
190        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
191        let b = match src {
192            OpSrc::Vreg(vs1_base) => {
193                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
194                // constraint as vs2; the index argument is identical, so the same bound holds:
195                // `vs1_base + i / elems_per_reg < 32`
196                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
197            }
198            OpSrc::Scalar(val) => val,
199        };
200        let c = if WITH_CARRY {
201            // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
202            unsafe { carry_bit(ext_state.read_vregs(), i) }
203        } else {
204            0
205        };
206
207        // Use u128 to capture the carry-out bit beyond SEW
208        let sum = u128::from(a & mask) + u128::from(b & mask) + u128::from(c);
209        let carry_out = (sum >> sew.bits_width()) & 1 != 0;
210
211        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
212        unsafe {
213            write_mask_bit(ext_state.write_vregs(), vd, i, carry_out);
214        }
215    }
216
217    ext_state.mark_vs_dirty();
218    ext_state.reset_vstart();
219}
220
221/// Execute an element-wise subtract-with-borrow over `vstart..vl`, writing the borrow-out as a
222/// single mask bit per element into `vd`.
223///
224/// When `WITH_BORROW` is true, borrow-in for element `i` is read from `v0[i]`. When false,
225/// borrow-in is treated as zero.
226///
227/// Borrow-out is 1 when the subtraction underflows unsigned:
228/// `borrow_out = (b + borrow_in) > a` (compared as SEW-wide unsigned values).
229///
230/// # Safety
231/// Same as [`execute_carry_add_mask()`].
232#[inline(always)]
233#[doc(hidden)]
234pub unsafe fn execute_carry_sub_mask<const WITH_BORROW: bool, Reg, ExtState, CustomError>(
235    ext_state: &mut ExtState,
236    vd: VReg,
237    vs2: VReg,
238    src: OpSrc,
239    sew: Vsew,
240) where
241    Reg: Register,
242    ExtState: VectorRegistersExt<Reg, CustomError>,
243    [(); ExtState::ELEN as usize]:,
244    [(); ExtState::VLEN as usize]:,
245    [(); ExtState::VLENB as usize]:,
246    CustomError: fmt::Debug,
247{
248    let vl = ext_state.vl();
249    let vstart = ext_state.vstart();
250    let mask = sew_mask(sew);
251
252    for i in u32::from(vstart)..vl {
253        // SAFETY: `vs2 % group_regs == 0` and `vs2 + group_regs <= 32` (caller precondition);
254        // `i < vl <= group_regs * elems_per_reg`, so
255        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
256        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
257        let b = match src {
258            OpSrc::Vreg(vs1_base) => {
259                // SAFETY: caller verified that the vs1 register group satisfies the same alignment
260                // constraint as vs2; the index argument is identical, so the same bound holds:
261                // `vs1_base + i / elems_per_reg < 32`
262                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
263            }
264            OpSrc::Scalar(val) => val,
265        };
266        let borrow_in = if WITH_BORROW {
267            // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
268            unsafe { carry_bit(ext_state.read_vregs(), i) }
269        } else {
270            0
271        };
272
273        let a_m = u128::from(a & mask);
274        let rhs = u128::from(b & mask) + u128::from(borrow_in);
275        let borrow_out = a_m < rhs;
276
277        // SAFETY: `i < vl <= VLEN`, so `i / 8 < VLENB`
278        unsafe {
279            write_mask_bit(ext_state.write_vregs(), vd, i, borrow_out);
280        }
281    }
282
283    ext_state.mark_vs_dirty();
284    ext_state.reset_vstart();
285}