Skip to main content

ab_riscv_interpreter/rv64/zk/zbkx/
rv64_zbkx_helpers.rs

//! Opaque helpers for the Zbkx (crossbar permutation) extension
2
3#[inline(always)]
4#[doc(hidden)]
5pub fn xperm4(rs1: u64, rs2: u64) -> u64 {
6    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
7    cfg_select! {
8        all(not(miri), target_arch = "riscv64", target_feature = "zbkx") => {
9            unsafe { core::arch::riscv64::xperm4(rs1 as usize, rs2 as usize) as u64 }
10        }
11        _ => {
12            use core::simd::num::SimdUint;
13            use core::simd::{simd_swizzle, u64x16};
14
15            const SHIFT: u64x16 =
16                u64x16::from_array([0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]);
17            const MASK: u64x16 = u64x16::splat(0xf);
18
19            // Unpack nibbles of rs1 into bytes via SIMD: broadcast, shift per-lane, mask
20            let lut = (u64x16::splat(rs1) >> SHIFT) & MASK;
21            // Unpack nibbles of rs2 into byte indices via SIMD
22            let idx = (u64x16::splat(rs2) >> SHIFT) & MASK;
23            // For each nibble of rs2, look up directly from lut; all indices 0–15 are in-bounds
24            let nibbles = lut.cast().swizzle_dyn(idx.cast());
25            // Pack nibbles back: interleave even/odd lanes and fold into bytes
26            let lo = simd_swizzle!(nibbles, [0, 2, 4, 6, 8, 10, 12, 14]);
27            let hi = simd_swizzle!(nibbles, [1, 3, 5, 7, 9, 11, 13, 15]);
28            u64::from_le_bytes((lo | (hi << 4)).to_array())
29        }
30    }
31}
32
33#[inline(always)]
34#[doc(hidden)]
35pub fn xperm8(rs1: u64, rs2: u64) -> u64 {
36    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
37    cfg_select! {
38        all(not(miri), target_arch = "riscv64", target_feature = "zbkx") => {
39            unsafe { core::arch::riscv64::xperm8(rs1 as usize, rs2 as usize) as u64 }
40        }
41        _ => {
42            use core::simd::u8x8;
43
44            let lut = u8x8::from_array(rs1.to_le_bytes());
45            let idx = u8x8::from_array(rs2.to_le_bytes());
46
47            let result = lut.swizzle_dyn(idx);
48
49            u64::from_le_bytes(result.to_array())
50        }
51    }
52}