Skip to main content

ab_riscv_interpreter/v/zvexx/
reduction.rs

1//! ZveXx integer reduction instructions
2
3#[cfg(test)]
4mod tests;
5pub mod zvexx_reduction_helpers;
6
7use crate::v::vector_registers::VectorRegistersExt;
8use crate::v::zvexx::arith::zvexx_arith_helpers;
9use crate::v::zvexx::zvexx_helpers;
10use crate::{
11    ExecutableInstruction, ExecutableInstructionCsr, ExecutableInstructionOperands, ExecutionError,
12    ProgramCounter, RegisterFile, Rs1Rs2OperandValues, Rs1Rs2Operands, VirtualMemory,
13};
14use ab_riscv_macros::instruction_execution;
15use ab_riscv_primitives::prelude::*;
16use core::fmt;
17use core::ops::ControlFlow;
18
// Operand-handling impl for the integer reduction instructions.
// The body is intentionally empty: nothing is overridden here, so behavior comes from whatever
// the trait and/or the `#[instruction_execution]` attribute provide.
// NOTE(review): the trait definition and macro expansion are not visible here - confirm.
19#[instruction_execution]
20impl<Reg> ExecutableInstructionOperands for ZveXxReductionInstruction<Reg> where Reg: Register {}
21
// CSR-related impl for the integer reduction instructions, generic over the extension state
// and custom error types.
// The body is intentionally empty: no items are overridden for this instruction set.
// NOTE(review): presumably this means reductions need no instruction-specific CSR handling -
// confirm against the trait definition.
22#[instruction_execution]
23impl<Reg, ExtState, CustomError> ExecutableInstructionCsr<ExtState, CustomError>
24    for ZveXxReductionInstruction<Reg>
25where
26    Reg: Register,
27{
28}
29
// Execution logic for the ZveXx integer reduction instructions (`Vred*` and `Vwredsum[u]`).
//
// Every match arm below runs the same sequence of guards before touching vector state. Each
// guard fails with `ExecutionError::IllegalInstruction` carrying the address returned by
// `program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE)`:
//   1. `ext_state.vector_instructions_allowed()` must be true
//   2. `ext_state.vtype()` must be `Some`
//   3. `ext_state.vstart()` must be zero
//   4. widening arms only: `2 * SEW` (in bits) must not exceed `ExtState::ELEN`
// Afterwards `vs2` is validated with `check_vreg_group_alignment` against the register count
// derived from `vlmul` (its error is propagated with `?`), and the reduction itself is
// delegated to a helper in `zvexx_reduction_helpers`, parameterized by a per-instruction
// combining closure `(acc, elem, sew) -> acc`.
//
// NOTE(review): the arms are left fully expanded rather than sharing a guard helper;
// `#[instruction_execution]` may depend on this shape - confirm before refactoring.
30#[instruction_execution]
31impl<Reg, Regs, ExtState, Memory, PC, InstructionHandler, CustomError>
32    ExecutableInstruction<Regs, ExtState, Memory, PC, InstructionHandler, CustomError>
33    for ZveXxReductionInstruction<Reg>
34where
35    Reg: Register,
36    Regs: RegisterFile<Reg>,
37    ExtState: VectorRegistersExt<Reg, CustomError>,
38    [(); ExtState::ELEN as usize]:,
39    [(); ExtState::VLEN as usize]:,
40    [(); ExtState::VLENB as usize]:,
41    Memory: VirtualMemory,
42    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
43    CustomError: fmt::Debug,
44{
    // Executes a single reduction instruction against `ext_state`.
    //
    // The scalar operand values, `_regs`, `_memory` and `_system_instruction_handler` are not
    // used by any arm. On success the function always returns `ControlFlow::Continue` holding
    // the `Default::default()` value of the `(register, value)` tuple; failures are the
    // illegal-instruction guards and the `vs2` alignment check described on the impl.
45    #[inline(always)]
46    fn execute(
47        self,
48        Rs1Rs2OperandValues {
49            rs1_value: _,
50            rs2_value: _,
51        }: Rs1Rs2OperandValues<<Self::Reg as Register>::Type>,
52        _regs: &mut Regs,
53        ext_state: &mut ExtState,
54        _memory: &mut Memory,
55        program_counter: &mut PC,
56        _system_instruction_handler: &mut InstructionHandler,
57    ) -> Result<
58        ControlFlow<(), (Self::Reg, <Self::Reg as Register>::Type)>,
59        ExecutionError<Reg::Type, CustomError>,
60    > {
61        match self {
            // Sum reduction: the combining closure is a wrapping addition
62            Self::Vredsum { vd, vs2, vs1, vm } => {
63                if !ext_state.vector_instructions_allowed() {
64                    ::core::hint::cold_path();
65                    return Err(ExecutionError::IllegalInstruction {
66                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
67                    });
68                }
69                let Some(vtype) = ext_state.vtype() else {
70                    ::core::hint::cold_path();
71                    return Err(ExecutionError::IllegalInstruction {
72                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
73                    });
74                };
75                // Spec §14: reductions with vstart > 0 are reserved; raise illegal instruction
76                if ext_state.vstart() != 0 {
77                    ::core::hint::cold_path();
78                    return Err(ExecutionError::IllegalInstruction {
79                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
80                    });
81                }
                // Only `vs2` is a register group here, so only it is alignment-checked
82                let group_regs = vtype.vlmul().register_count();
83                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
84                    program_counter,
85                    vs2,
86                    group_regs,
87                )?;
88                let sew = vtype.vsew();
89                let vl = ext_state.vl();
90                // SAFETY: `vs2` alignment checked; `vstart == 0` checked;
91                // `vs1` and `vd` are single-register scalar operands
92                unsafe {
93                    zvexx_reduction_helpers::execute_reduce_op(
94                        ext_state,
95                        vd,
96                        vs2,
97                        vs1,
98                        vm,
99                        vl,
100                        sew,
101                        |acc, elem, _sew| acc.wrapping_add(elem),
102                    );
103                }
104            }
            // Bitwise AND reduction; same guards as `Vredsum`
105            Self::Vredand { vd, vs2, vs1, vm } => {
106                if !ext_state.vector_instructions_allowed() {
107                    ::core::hint::cold_path();
108                    return Err(ExecutionError::IllegalInstruction {
109                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
110                    });
111                }
112                let Some(vtype) = ext_state.vtype() else {
113                    ::core::hint::cold_path();
114                    return Err(ExecutionError::IllegalInstruction {
115                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
116                    });
117                };
118                if ext_state.vstart() != 0 {
119                    ::core::hint::cold_path();
120                    return Err(ExecutionError::IllegalInstruction {
121                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
122                    });
123                }
124                let group_regs = vtype.vlmul().register_count();
125                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
126                    program_counter,
127                    vs2,
128                    group_regs,
129                )?;
130                let sew = vtype.vsew();
131                let vl = ext_state.vl();
132                // SAFETY: see `Vredsum`
133                unsafe {
134                    zvexx_reduction_helpers::execute_reduce_op(
135                        ext_state,
136                        vd,
137                        vs2,
138                        vs1,
139                        vm,
140                        vl,
141                        sew,
142                        |acc, elem, _sew| acc & elem,
143                    );
144                }
145            }
            // Bitwise OR reduction; same guards as `Vredsum`
146            Self::Vredor { vd, vs2, vs1, vm } => {
147                if !ext_state.vector_instructions_allowed() {
148                    ::core::hint::cold_path();
149                    return Err(ExecutionError::IllegalInstruction {
150                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
151                    });
152                }
153                let Some(vtype) = ext_state.vtype() else {
154                    ::core::hint::cold_path();
155                    return Err(ExecutionError::IllegalInstruction {
156                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
157                    });
158                };
159                if ext_state.vstart() != 0 {
160                    ::core::hint::cold_path();
161                    return Err(ExecutionError::IllegalInstruction {
162                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
163                    });
164                }
165                let group_regs = vtype.vlmul().register_count();
166                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
167                    program_counter,
168                    vs2,
169                    group_regs,
170                )?;
171                let sew = vtype.vsew();
172                let vl = ext_state.vl();
173                // SAFETY: see `Vredsum`
174                unsafe {
175                    zvexx_reduction_helpers::execute_reduce_op(
176                        ext_state,
177                        vd,
178                        vs2,
179                        vs1,
180                        vm,
181                        vl,
182                        sew,
183                        |acc, elem, _sew| acc | elem,
184                    );
185                }
186            }
            // Bitwise XOR reduction; same guards as `Vredsum`
187            Self::Vredxor { vd, vs2, vs1, vm } => {
188                if !ext_state.vector_instructions_allowed() {
189                    ::core::hint::cold_path();
190                    return Err(ExecutionError::IllegalInstruction {
191                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
192                    });
193                }
194                let Some(vtype) = ext_state.vtype() else {
195                    ::core::hint::cold_path();
196                    return Err(ExecutionError::IllegalInstruction {
197                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
198                    });
199                };
200                if ext_state.vstart() != 0 {
201                    ::core::hint::cold_path();
202                    return Err(ExecutionError::IllegalInstruction {
203                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
204                    });
205                }
206                let group_regs = vtype.vlmul().register_count();
207                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
208                    program_counter,
209                    vs2,
210                    group_regs,
211                )?;
212                let sew = vtype.vsew();
213                let vl = ext_state.vl();
214                // SAFETY: see `Vredsum`
215                unsafe {
216                    zvexx_reduction_helpers::execute_reduce_op(
217                        ext_state,
218                        vd,
219                        vs2,
220                        vs1,
221                        vm,
222                        vl,
223                        sew,
224                        |acc, elem, _sew| acc ^ elem,
225                    );
226                }
227            }
            // Unsigned minimum reduction; same guards as `Vredsum`
228            Self::Vredminu { vd, vs2, vs1, vm } => {
229                if !ext_state.vector_instructions_allowed() {
230                    ::core::hint::cold_path();
231                    return Err(ExecutionError::IllegalInstruction {
232                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
233                    });
234                }
235                let Some(vtype) = ext_state.vtype() else {
236                    ::core::hint::cold_path();
237                    return Err(ExecutionError::IllegalInstruction {
238                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
239                    });
240                };
241                if ext_state.vstart() != 0 {
242                    ::core::hint::cold_path();
243                    return Err(ExecutionError::IllegalInstruction {
244                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
245                    });
246                }
247                let group_regs = vtype.vlmul().register_count();
248                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
249                    program_counter,
250                    vs2,
251                    group_regs,
252                )?;
253                let sew = vtype.vsew();
254                let vl = ext_state.vl();
255                // SAFETY: see `Vredsum`
256                unsafe {
257                    zvexx_reduction_helpers::execute_reduce_op(
258                        ext_state,
259                        vd,
260                        vs2,
261                        vs1,
262                        vm,
263                        vl,
264                        sew,
                        // Both sides are masked with `sew_mask(sew)` before the comparison
                        // (`&` binds tighter than `<`); the accumulator is kept on ties
265                        |acc, elem, sew| {
266                            let mask = zvexx_arith_helpers::sew_mask(sew);
267                            if elem & mask < acc & mask { elem } else { acc }
268                        },
269                    );
270                }
271            }
            // Signed minimum reduction; same guards as `Vredsum`
272            Self::Vredmin { vd, vs2, vs1, vm } => {
273                if !ext_state.vector_instructions_allowed() {
274                    ::core::hint::cold_path();
275                    return Err(ExecutionError::IllegalInstruction {
276                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
277                    });
278                }
279                let Some(vtype) = ext_state.vtype() else {
280                    ::core::hint::cold_path();
281                    return Err(ExecutionError::IllegalInstruction {
282                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
283                    });
284                };
285                if ext_state.vstart() != 0 {
286                    ::core::hint::cold_path();
287                    return Err(ExecutionError::IllegalInstruction {
288                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
289                    });
290                }
291                let group_regs = vtype.vlmul().register_count();
292                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
293                    program_counter,
294                    vs2,
295                    group_regs,
296                )?;
297                let sew = vtype.vsew();
298                let vl = ext_state.vl();
299                // SAFETY: see `Vredsum`
300                unsafe {
301                    zvexx_reduction_helpers::execute_reduce_op(
302                        ext_state,
303                        vd,
304                        vs2,
305                        vs1,
306                        vm,
307                        vl,
308                        sew,
                        // Compares the `sign_extend(.., sew)` forms of both values but returns
                        // the original (non-extended) one; the accumulator is kept on ties
309                        |acc, elem, sew| {
310                            if zvexx_arith_helpers::sign_extend(elem, sew)
311                                < zvexx_arith_helpers::sign_extend(acc, sew)
312                            {
313                                elem
314                            } else {
315                                acc
316                            }
317                        },
318                    );
319                }
320            }
            // Unsigned maximum reduction; same guards as `Vredsum`
321            Self::Vredmaxu { vd, vs2, vs1, vm } => {
322                if !ext_state.vector_instructions_allowed() {
323                    ::core::hint::cold_path();
324                    return Err(ExecutionError::IllegalInstruction {
325                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
326                    });
327                }
328                let Some(vtype) = ext_state.vtype() else {
329                    ::core::hint::cold_path();
330                    return Err(ExecutionError::IllegalInstruction {
331                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
332                    });
333                };
334                if ext_state.vstart() != 0 {
335                    ::core::hint::cold_path();
336                    return Err(ExecutionError::IllegalInstruction {
337                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
338                    });
339                }
340                let group_regs = vtype.vlmul().register_count();
341                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
342                    program_counter,
343                    vs2,
344                    group_regs,
345                )?;
346                let sew = vtype.vsew();
347                let vl = ext_state.vl();
348                // SAFETY: see `Vredsum`
349                unsafe {
350                    zvexx_reduction_helpers::execute_reduce_op(
351                        ext_state,
352                        vd,
353                        vs2,
354                        vs1,
355                        vm,
356                        vl,
357                        sew,
                        // Both sides are masked with `sew_mask(sew)` before the comparison
                        // (`&` binds tighter than `>`); the accumulator is kept on ties
358                        |acc, elem, sew| {
359                            let mask = zvexx_arith_helpers::sew_mask(sew);
360                            if elem & mask > acc & mask { elem } else { acc }
361                        },
362                    );
363                }
364            }
            // Signed maximum reduction; same guards as `Vredsum`
365            Self::Vredmax { vd, vs2, vs1, vm } => {
366                if !ext_state.vector_instructions_allowed() {
367                    ::core::hint::cold_path();
368                    return Err(ExecutionError::IllegalInstruction {
369                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
370                    });
371                }
372                let Some(vtype) = ext_state.vtype() else {
373                    ::core::hint::cold_path();
374                    return Err(ExecutionError::IllegalInstruction {
375                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
376                    });
377                };
378                if ext_state.vstart() != 0 {
379                    ::core::hint::cold_path();
380                    return Err(ExecutionError::IllegalInstruction {
381                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
382                    });
383                }
384                let group_regs = vtype.vlmul().register_count();
385                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
386                    program_counter,
387                    vs2,
388                    group_regs,
389                )?;
390                let sew = vtype.vsew();
391                let vl = ext_state.vl();
392                // SAFETY: see `Vredsum`
393                unsafe {
394                    zvexx_reduction_helpers::execute_reduce_op(
395                        ext_state,
396                        vd,
397                        vs2,
398                        vs1,
399                        vm,
400                        vl,
401                        sew,
                        // Compares the `sign_extend(.., sew)` forms of both values but returns
                        // the original (non-extended) one; the accumulator is kept on ties
402                        |acc, elem, sew| {
403                            if zvexx_arith_helpers::sign_extend(elem, sew)
404                                > zvexx_arith_helpers::sign_extend(acc, sew)
405                            {
406                                elem
407                            } else {
408                                acc
409                            }
410                        },
411                    );
412                }
413            }
            // Widening unsigned sum reduction: adds the extra `2 * SEW <= ELEN` guard and uses
            // the widening helper with its first const generic argument set to `false`
414            Self::Vwredsumu { vd, vs2, vs1, vm } => {
415                if !ext_state.vector_instructions_allowed() {
416                    ::core::hint::cold_path();
417                    return Err(ExecutionError::IllegalInstruction {
418                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
419                    });
420                }
421                let Some(vtype) = ext_state.vtype() else {
422                    ::core::hint::cold_path();
423                    return Err(ExecutionError::IllegalInstruction {
424                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
425                    });
426                };
427                if ext_state.vstart() != 0 {
428                    ::core::hint::cold_path();
429                    return Err(ExecutionError::IllegalInstruction {
430                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
431                    });
432                }
433                // Widening: 2*SEW must fit in ELEN
434                if u32::from(vtype.vsew().bits_width()) * 2 > ExtState::ELEN {
435                    ::core::hint::cold_path();
436                    return Err(ExecutionError::IllegalInstruction {
437                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
438                    });
439                }
440                let group_regs = vtype.vlmul().register_count();
441                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
442                    program_counter,
443                    vs2,
444                    group_regs,
445                )?;
446                let sew = vtype.vsew();
447                let vl = ext_state.vl();
448                // SAFETY: `vs2` alignment checked; widening SEW constraint checked above;
449                // `vstart == 0` checked; `vd` and `vs1` are single-register 2*SEW scalar operands
450                unsafe {
451                    zvexx_reduction_helpers::execute_widening_reduce_op::<false, _, _, _, _>(
452                        ext_state,
453                        vd,
454                        vs2,
455                        vs1,
456                        vm,
457                        vl,
458                        sew,
                        // NOTE(review): the extension itself is not visible in this closure;
                        // presumably the helper performs it based on the `false` const
                        // argument - confirm in `zvexx_reduction_helpers`
459                        // Zero-extend vs2 elements then accumulate
460                        |acc, elem, _sew| acc.wrapping_add(elem),
461                    );
462                }
463            }
            // Widening signed sum reduction: identical to `Vwredsumu` except that the helper's
            // first const generic argument is `true`
464            Self::Vwredsum { vd, vs2, vs1, vm } => {
465                if !ext_state.vector_instructions_allowed() {
466                    ::core::hint::cold_path();
467                    return Err(ExecutionError::IllegalInstruction {
468                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
469                    });
470                }
471                let Some(vtype) = ext_state.vtype() else {
472                    ::core::hint::cold_path();
473                    return Err(ExecutionError::IllegalInstruction {
474                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
475                    });
476                };
477                if ext_state.vstart() != 0 {
478                    ::core::hint::cold_path();
479                    return Err(ExecutionError::IllegalInstruction {
480                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
481                    });
482                }
                // Widening: 2*SEW must fit in ELEN
483                if u32::from(vtype.vsew().bits_width()) * 2 > ExtState::ELEN {
484                    ::core::hint::cold_path();
485                    return Err(ExecutionError::IllegalInstruction {
486                        address: program_counter.old_pc(zvexx_helpers::INSTRUCTION_SIZE),
487                    });
488                }
489                let group_regs = vtype.vlmul().register_count();
490                zvexx_arith_helpers::check_vreg_group_alignment::<Reg, _, _, _>(
491                    program_counter,
492                    vs2,
493                    group_regs,
494                )?;
495                let sew = vtype.vsew();
496                let vl = ext_state.vl();
497                // SAFETY: see `Vwredsumu`
498                unsafe {
499                    zvexx_reduction_helpers::execute_widening_reduce_op::<true, _, _, _, _>(
500                        ext_state,
501                        vd,
502                        vs2,
503                        vs1,
504                        vm,
505                        vl,
506                        sew,
                        // NOTE(review): the extension itself is not visible in this closure;
                        // presumably the helper performs it based on the `true` const
                        // argument - confirm in `zvexx_reduction_helpers`
507                        // Sign-extend vs2 elements then accumulate
508                        |acc, elem, _sew| acc.wrapping_add(elem),
509                    );
510                }
511            }
512        }
513
        // No arm produces a register write-back or alters control flow, so a default value is
        // returned for the `(register, value)` pair
514        Ok(ControlFlow::Continue(Default::default()))
515    }
516}