Skip to main content

ab_riscv_interpreter/v/zve64x/
reduction.rs

//! Zve64x integer reduction instructions: single-width `vred*` and widening `vwredsum`/`vwredsumu`
2
3#[cfg(test)]
4mod tests;
5pub mod zve64x_reduction_helpers;
6
7use crate::v::vector_registers::VectorRegistersExt;
8use crate::v::zve64x::arith::zve64x_arith_helpers;
9use crate::v::zve64x::zve64x_helpers;
10use crate::{
11    ExecutableInstruction, ExecutionError, InterpreterState, ProgramCounter, VirtualMemory,
12};
13use ab_riscv_macros::instruction_execution;
14use ab_riscv_primitives::instructions::v::zve64x::reduction::Zve64xReductionInstruction;
15use ab_riscv_primitives::registers::general_purpose::Register;
16use core::fmt;
17use core::ops::ControlFlow;
18
19#[instruction_execution]
20impl<Reg, ExtState, Memory, PC, InstructionHandler, CustomError>
21    ExecutableInstruction<
22        InterpreterState<Reg, ExtState, Memory, PC, InstructionHandler, CustomError>,
23        CustomError,
24    > for Zve64xReductionInstruction<Reg>
25where
26    Reg: Register,
27    [(); Reg::N]:,
28    ExtState: VectorRegistersExt<Reg, CustomError>,
29    [(); ExtState::ELEN as usize]:,
30    [(); ExtState::VLEN as usize]:,
31    [(); ExtState::VLENB as usize]:,
32    Memory: VirtualMemory,
33    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
34    CustomError: fmt::Debug,
35{
36    #[inline(always)]
37    fn execute(
38        self,
39        state: &mut InterpreterState<Reg, ExtState, Memory, PC, InstructionHandler, CustomError>,
40    ) -> Result<ControlFlow<()>, ExecutionError<Reg::Type, CustomError>> {
41        match self {
42            Self::Vredsum { vd, vs2, vs1, vm } => {
43                if !state.ext_state.vector_instructions_allowed() {
44                    Err(ExecutionError::IllegalInstruction {
45                        address: state
46                            .instruction_fetcher
47                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
48                    })?;
49                }
50                let vtype = state
51                    .ext_state
52                    .vtype()
53                    .ok_or(ExecutionError::IllegalInstruction {
54                        address: state
55                            .instruction_fetcher
56                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
57                    })?;
58                let group_regs = vtype.vlmul().register_count();
59                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
60                let sew = vtype.vsew();
61                let vl = state.ext_state.vl();
62                let vstart = u32::from(state.ext_state.vstart());
63                // SAFETY: `vs2` alignment checked; `vs1` and `vd` are single-register scalar
64                // operands with no LMUL alignment constraint; `vl <= VLMAX`
65                unsafe {
66                    zve64x_reduction_helpers::execute_reduce_op(
67                        state,
68                        vd,
69                        vs2,
70                        vs1,
71                        vm,
72                        vl,
73                        vstart,
74                        sew,
75                        |acc, elem, _sew| acc.wrapping_add(elem),
76                    );
77                }
78            }
79            Self::Vredand { vd, vs2, vs1, vm } => {
80                if !state.ext_state.vector_instructions_allowed() {
81                    Err(ExecutionError::IllegalInstruction {
82                        address: state
83                            .instruction_fetcher
84                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
85                    })?;
86                }
87                let vtype = state
88                    .ext_state
89                    .vtype()
90                    .ok_or(ExecutionError::IllegalInstruction {
91                        address: state
92                            .instruction_fetcher
93                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
94                    })?;
95                let group_regs = vtype.vlmul().register_count();
96                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
97                let sew = vtype.vsew();
98                let vl = state.ext_state.vl();
99                let vstart = u32::from(state.ext_state.vstart());
100                // SAFETY: see `Vredsum`
101                unsafe {
102                    zve64x_reduction_helpers::execute_reduce_op(
103                        state,
104                        vd,
105                        vs2,
106                        vs1,
107                        vm,
108                        vl,
109                        vstart,
110                        sew,
111                        |acc, elem, _sew| acc & elem,
112                    );
113                }
114            }
115            Self::Vredor { vd, vs2, vs1, vm } => {
116                if !state.ext_state.vector_instructions_allowed() {
117                    Err(ExecutionError::IllegalInstruction {
118                        address: state
119                            .instruction_fetcher
120                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
121                    })?;
122                }
123                let vtype = state
124                    .ext_state
125                    .vtype()
126                    .ok_or(ExecutionError::IllegalInstruction {
127                        address: state
128                            .instruction_fetcher
129                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
130                    })?;
131                let group_regs = vtype.vlmul().register_count();
132                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
133                let sew = vtype.vsew();
134                let vl = state.ext_state.vl();
135                let vstart = u32::from(state.ext_state.vstart());
136                // SAFETY: see `Vredsum`
137                unsafe {
138                    zve64x_reduction_helpers::execute_reduce_op(
139                        state,
140                        vd,
141                        vs2,
142                        vs1,
143                        vm,
144                        vl,
145                        vstart,
146                        sew,
147                        |acc, elem, _sew| acc | elem,
148                    );
149                }
150            }
151            Self::Vredxor { vd, vs2, vs1, vm } => {
152                if !state.ext_state.vector_instructions_allowed() {
153                    Err(ExecutionError::IllegalInstruction {
154                        address: state
155                            .instruction_fetcher
156                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
157                    })?;
158                }
159                let vtype = state
160                    .ext_state
161                    .vtype()
162                    .ok_or(ExecutionError::IllegalInstruction {
163                        address: state
164                            .instruction_fetcher
165                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
166                    })?;
167                let group_regs = vtype.vlmul().register_count();
168                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
169                let sew = vtype.vsew();
170                let vl = state.ext_state.vl();
171                let vstart = u32::from(state.ext_state.vstart());
172                // SAFETY: see `Vredsum`
173                unsafe {
174                    zve64x_reduction_helpers::execute_reduce_op(
175                        state,
176                        vd,
177                        vs2,
178                        vs1,
179                        vm,
180                        vl,
181                        vstart,
182                        sew,
183                        |acc, elem, _sew| acc ^ elem,
184                    );
185                }
186            }
187            Self::Vredminu { vd, vs2, vs1, vm } => {
188                if !state.ext_state.vector_instructions_allowed() {
189                    Err(ExecutionError::IllegalInstruction {
190                        address: state
191                            .instruction_fetcher
192                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
193                    })?;
194                }
195                let vtype = state
196                    .ext_state
197                    .vtype()
198                    .ok_or(ExecutionError::IllegalInstruction {
199                        address: state
200                            .instruction_fetcher
201                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
202                    })?;
203                let group_regs = vtype.vlmul().register_count();
204                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
205                let sew = vtype.vsew();
206                let vl = state.ext_state.vl();
207                let vstart = u32::from(state.ext_state.vstart());
208                // SAFETY: see `Vredsum`
209                unsafe {
210                    zve64x_reduction_helpers::execute_reduce_op(
211                        state,
212                        vd,
213                        vs2,
214                        vs1,
215                        vm,
216                        vl,
217                        vstart,
218                        sew,
219                        |acc, elem, sew| {
220                            let mask = zve64x_arith_helpers::sew_mask(sew);
221                            if elem & mask < acc & mask { elem } else { acc }
222                        },
223                    );
224                }
225            }
226            Self::Vredmin { vd, vs2, vs1, vm } => {
227                if !state.ext_state.vector_instructions_allowed() {
228                    Err(ExecutionError::IllegalInstruction {
229                        address: state
230                            .instruction_fetcher
231                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
232                    })?;
233                }
234                let vtype = state
235                    .ext_state
236                    .vtype()
237                    .ok_or(ExecutionError::IllegalInstruction {
238                        address: state
239                            .instruction_fetcher
240                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
241                    })?;
242                let group_regs = vtype.vlmul().register_count();
243                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
244                let sew = vtype.vsew();
245                let vl = state.ext_state.vl();
246                let vstart = u32::from(state.ext_state.vstart());
247                // SAFETY: see `Vredsum`
248                unsafe {
249                    zve64x_reduction_helpers::execute_reduce_op(
250                        state,
251                        vd,
252                        vs2,
253                        vs1,
254                        vm,
255                        vl,
256                        vstart,
257                        sew,
258                        |acc, elem, sew| {
259                            if zve64x_arith_helpers::sign_extend(elem, sew)
260                                < zve64x_arith_helpers::sign_extend(acc, sew)
261                            {
262                                elem
263                            } else {
264                                acc
265                            }
266                        },
267                    );
268                }
269            }
270            Self::Vredmaxu { vd, vs2, vs1, vm } => {
271                if !state.ext_state.vector_instructions_allowed() {
272                    Err(ExecutionError::IllegalInstruction {
273                        address: state
274                            .instruction_fetcher
275                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
276                    })?;
277                }
278                let vtype = state
279                    .ext_state
280                    .vtype()
281                    .ok_or(ExecutionError::IllegalInstruction {
282                        address: state
283                            .instruction_fetcher
284                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
285                    })?;
286                let group_regs = vtype.vlmul().register_count();
287                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
288                let sew = vtype.vsew();
289                let vl = state.ext_state.vl();
290                let vstart = u32::from(state.ext_state.vstart());
291                // SAFETY: see `Vredsum`
292                unsafe {
293                    zve64x_reduction_helpers::execute_reduce_op(
294                        state,
295                        vd,
296                        vs2,
297                        vs1,
298                        vm,
299                        vl,
300                        vstart,
301                        sew,
302                        |acc, elem, sew| {
303                            let mask = zve64x_arith_helpers::sew_mask(sew);
304                            if elem & mask > acc & mask { elem } else { acc }
305                        },
306                    );
307                }
308            }
309            Self::Vredmax { vd, vs2, vs1, vm } => {
310                if !state.ext_state.vector_instructions_allowed() {
311                    Err(ExecutionError::IllegalInstruction {
312                        address: state
313                            .instruction_fetcher
314                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
315                    })?;
316                }
317                let vtype = state
318                    .ext_state
319                    .vtype()
320                    .ok_or(ExecutionError::IllegalInstruction {
321                        address: state
322                            .instruction_fetcher
323                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
324                    })?;
325                let group_regs = vtype.vlmul().register_count();
326                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
327                let sew = vtype.vsew();
328                let vl = state.ext_state.vl();
329                let vstart = u32::from(state.ext_state.vstart());
330                // SAFETY: see `Vredsum`
331                unsafe {
332                    zve64x_reduction_helpers::execute_reduce_op(
333                        state,
334                        vd,
335                        vs2,
336                        vs1,
337                        vm,
338                        vl,
339                        vstart,
340                        sew,
341                        |acc, elem, sew| {
342                            if zve64x_arith_helpers::sign_extend(elem, sew)
343                                > zve64x_arith_helpers::sign_extend(acc, sew)
344                            {
345                                elem
346                            } else {
347                                acc
348                            }
349                        },
350                    );
351                }
352            }
353            Self::Vwredsumu { vd, vs2, vs1, vm } => {
354                if !state.ext_state.vector_instructions_allowed() {
355                    Err(ExecutionError::IllegalInstruction {
356                        address: state
357                            .instruction_fetcher
358                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
359                    })?;
360                }
361                let vtype = state
362                    .ext_state
363                    .vtype()
364                    .ok_or(ExecutionError::IllegalInstruction {
365                        address: state
366                            .instruction_fetcher
367                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
368                    })?;
369                // Widening: 2*SEW must fit in ELEN
370                if u32::from(vtype.vsew().bits()) * 2 > ExtState::ELEN {
371                    Err(ExecutionError::IllegalInstruction {
372                        address: state
373                            .instruction_fetcher
374                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
375                    })?;
376                }
377                let group_regs = vtype.vlmul().register_count();
378                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
379                let sew = vtype.vsew();
380                let vl = state.ext_state.vl();
381                let vstart = u32::from(state.ext_state.vstart());
382                // SAFETY: `vs2` alignment checked; widening SEW constraint checked above;
383                // `vd` and `vs1` are single-register 2*SEW scalar operands
384                unsafe {
385                    zve64x_reduction_helpers::execute_widening_reduce_op(
386                        state,
387                        vd,
388                        vs2,
389                        vs1,
390                        vm,
391                        vl,
392                        vstart,
393                        sew,
394                        // Zero-extend vs2 elements then accumulate
395                        |acc, elem, _sew| acc.wrapping_add(elem),
396                        false,
397                    );
398                }
399            }
400            Self::Vwredsum { vd, vs2, vs1, vm } => {
401                if !state.ext_state.vector_instructions_allowed() {
402                    Err(ExecutionError::IllegalInstruction {
403                        address: state
404                            .instruction_fetcher
405                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
406                    })?;
407                }
408                let vtype = state
409                    .ext_state
410                    .vtype()
411                    .ok_or(ExecutionError::IllegalInstruction {
412                        address: state
413                            .instruction_fetcher
414                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
415                    })?;
416                // Widening: 2*SEW must fit in ELEN
417                if u32::from(vtype.vsew().bits()) * 2 > ExtState::ELEN {
418                    Err(ExecutionError::IllegalInstruction {
419                        address: state
420                            .instruction_fetcher
421                            .old_pc(zve64x_helpers::INSTRUCTION_SIZE),
422                    })?;
423                }
424                let group_regs = vtype.vlmul().register_count();
425                zve64x_arith_helpers::check_vreg_group_alignment(state, vs2, group_regs)?;
426                let sew = vtype.vsew();
427                let vl = state.ext_state.vl();
428                let vstart = u32::from(state.ext_state.vstart());
429                // SAFETY: see `Vwredsumu`
430                unsafe {
431                    zve64x_reduction_helpers::execute_widening_reduce_op(
432                        state,
433                        vd,
434                        vs2,
435                        vs1,
436                        vm,
437                        vl,
438                        vstart,
439                        sew,
440                        // Sign-extend vs2 elements then accumulate
441                        |acc, elem, _sew| acc.wrapping_add(elem),
442                        true,
443                    );
444                }
445            }
446            Self::PhantomZve64xReduction(_) => unreachable!("Never constructed"),
447        }
448
449        Ok(ControlFlow::Continue(()))
450    }
451}