Skip to main content

ab_riscv_interpreter/rv32/zk/zbkx/
rv32_zbkx_helpers.rs

1//! Opaque helpers for Zbkx extension
2
3#[inline(always)]
4#[doc(hidden)]
5pub fn xperm4(rs1: u32, rs2: u32) -> u32 {
6    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
7    cfg_select! {
8        all(not(miri), target_arch = "riscv32", target_feature = "zbkx") => {
9            unsafe { core::arch::riscv32::xperm4(rs1 as usize, rs2 as usize) as u32 }
10        }
11        _ => {
12            use core::simd::num::SimdUint;
13            use core::simd::{simd_swizzle, u32x8};
14
15            const SHIFT: u32x8 = u32x8::from_array([0, 4, 8, 12, 16, 20, 24, 28]);
16            const MASK: u32x8 = u32x8::splat(0xf);
17
18            // Unpack nibbles of rs1 into bytes via SIMD: broadcast, shift per-lane, mask
19            let lut = (u32x8::splat(rs1) >> SHIFT) & MASK;
20            // Unpack nibbles of rs2 into byte indices via SIMD
21            let idx = (u32x8::splat(rs2) >> SHIFT) & MASK;
22            // For each nibble of rs2, look up from lut (out-of-bounds -> 0 via swizzle_dyn)
23            let nibbles = lut.cast().swizzle_dyn(idx.cast());
24            // Pack nibbles back: interleave even/odd lanes and fold into bytes
25            let lo = simd_swizzle!(nibbles, [0, 2, 4, 6]);
26            let hi = simd_swizzle!(nibbles, [1, 3, 5, 7]);
27            u32::from_le_bytes((lo | (hi << 4)).to_array())
28        }
29    }
30}
31
32#[inline(always)]
33#[doc(hidden)]
34pub fn xperm8(rs1: u32, rs2: u32) -> u32 {
35    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
36    cfg_select! {
37        all(not(miri), target_arch = "riscv32", target_feature = "zbkx") => {
38            unsafe { core::arch::riscv32::xperm8(rs1 as usize, rs2 as usize) as u32 }
39        }
40        _ => {
41            use core::simd::u8x4;
42
43            let lut = u8x4::from_array(rs1.to_le_bytes());
44            let idx = u8x4::from_array(rs2.to_le_bytes());
45
46            let result = lut.swizzle_dyn(idx);
47
48            u32::from_le_bytes(result.to_array())
49        }
50    }
51}