ab_aligned_buffer/lib.rs

#![feature(box_vec_non_null, pointer_is_aligned_to, ptr_as_ref_unchecked)]
#![no_std]

#[cfg(test)]
mod tests;

extern crate alloc;

use ab_io_type::MAX_ALIGNMENT;
use alloc::alloc::realloc;
use alloc::boxed::Box;
use core::alloc::Layout;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use core::slice;
use core::sync::atomic::{AtomicU32, Ordering};
use stable_deref_trait::{CloneStableDeref, StableDeref};
use yoke::CloneableCart;

const _: () = {
    assert!(
        align_of::<u128>() == size_of::<u128>(),
        "Size and alignment are both 16 bytes"
    );
    assert!(
        align_of::<u128>() == MAX_ALIGNMENT as usize,
        "Alignment of u128 is a max alignment"
    );
    assert!(size_of::<u128>() >= size_of::<AtomicU32>());
    assert!(align_of::<u128>() >= align_of::<AtomicU32>());
};

#[repr(C, align(16))]
struct ConstInnerBuffer {
    strong_count: AtomicU32,
}

const _: () = {
    assert!(align_of::<ConstInnerBuffer>() == align_of::<u128>());
    assert!(size_of::<ConstInnerBuffer>() == size_of::<u128>());
};

static EMPTY_SHARED_ALIGNED_BUFFER: SharedAlignedBuffer = SharedAlignedBuffer {
    inner: InnerBuffer {
        buffer: NonNull::from_ref({
            static BUFFER: MaybeUninit<ConstInnerBuffer> = MaybeUninit::new(ConstInnerBuffer {
                strong_count: AtomicU32::new(1),
            });

            &BUFFER
        })
        .cast::<MaybeUninit<u128>>(),
        capacity: 0,
        len: 0,
    },
};

#[derive(Debug)]
struct InnerBuffer {
    // The first bytes are allocated for `strong_count`
    buffer: NonNull<MaybeUninit<u128>>,
    capacity: u32,
    len: u32,
}

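// Illustration of the layout math shared by `allocate()` and `Drop` below (a compile-time sketch,
// not part of the original API): the allocation holds one `u128` word reserved for `strong_count`
// followed by `ceil(capacity / 16)` words of data.
const _: () = {
    // E.g. a 40-byte capacity needs 1 + ceil(40 / 16) = 4 `u128` words, i.e. 64 bytes in total
    let example_capacity: usize = 40;
    assert!(1 + example_capacity.div_ceil(size_of::<u128>()) == 4);
};
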
// SAFETY: Heap-allocated memory buffer can be used from any thread
unsafe impl Send for InnerBuffer {}
// SAFETY: Heap-allocated memory buffer can be used from any thread
unsafe impl Sync for InnerBuffer {}

impl Default for InnerBuffer {
    #[inline(always)]
    fn default() -> Self {
        EMPTY_SHARED_ALIGNED_BUFFER.inner.clone()
    }
}

impl Clone for InnerBuffer {
    #[inline(always)]
    fn clone(&self) -> Self {
        self.strong_count_ref().fetch_add(1, Ordering::AcqRel);

        Self {
            buffer: self.buffer,
            capacity: self.capacity,
            len: self.len,
        }
    }
}

impl Drop for InnerBuffer {
    #[inline(always)]
    fn drop(&mut self) {
        if self.strong_count_ref().fetch_sub(1, Ordering::AcqRel) == 1 {
            // SAFETY: Created from `Box` in constructor
            let _ = unsafe {
                Box::from_non_null(NonNull::slice_from_raw_parts(
                    self.buffer,
                    1 + (self.capacity as usize).div_ceil(size_of::<u128>()),
                ))
            };
        }
    }
}

impl InnerBuffer {
    /// Allocates a new buffer + one `u128` worth of memory at the beginning for
    /// `strong_count` in case it is later converted to [`SharedAlignedBuffer`].
    ///
    /// `strong_count` field is automatically initialized as `1`.
    #[inline(always)]
    fn allocate(capacity: u32) -> Self {
        let buffer = Box::into_non_null(Box::<[u128]>::new_uninit_slice(
            1 + (capacity as usize).div_ceil(size_of::<u128>()),
        ));
        // SAFETY: The first bytes are allocated for `strong_count`, which is a correctly aligned
        // copy type
        unsafe { buffer.cast::<AtomicU32>().write(AtomicU32::new(1)) };
        Self {
            buffer: buffer.cast::<MaybeUninit<u128>>(),
            capacity,
            len: 0,
        }
    }

    #[inline(always)]
    fn resize(&mut self, capacity: u32) {
        // SAFETY: Non-null correctly aligned pointer, correct size
        let layout = Layout::for_value(unsafe {
            slice::from_raw_parts(
                self.buffer.as_ptr(),
                1 + (self.capacity as usize).div_ceil(size_of::<u128>()),
            )
        });

        // `size_of::<u128>()` is added because the first bytes are allocated for `strong_count`
        let new_size = size_of::<u128>() + (capacity as usize).next_multiple_of(layout.align());

        // SAFETY: Allocated with global allocator, correct layout, non-zero size that is a
        // multiple of alignment
        let new_ptr = unsafe {
            realloc(self.buffer.as_ptr().cast::<u8>(), layout, new_size).cast::<MaybeUninit<u128>>()
        };
        let Some(new_ptr) = NonNull::new(new_ptr) else {
            panic!("Realloc from {} to {new_size} has failed", self.capacity());
        };

        self.buffer = new_ptr;
        self.capacity = capacity;
    }

    #[inline(always)]
    const fn len(&self) -> u32 {
        self.len
    }

    /// `len` bytes must be initialized
    #[inline(always)]
    unsafe fn set_len(&mut self, len: u32) {
        debug_assert!(
            len <= self.capacity(),
            "Too many bytes {} > {}",
            len,
            self.capacity()
        );
        self.len = len;
    }

    #[inline(always)]
    const fn capacity(&self) -> u32 {
        self.capacity
    }

    #[inline(always)]
    const fn strong_count_ref(&self) -> &AtomicU32 {
        // SAFETY: The first bytes are allocated for `strong_count`, which is a correctly aligned
        // copy type initialized in the constructor
        unsafe { self.buffer.as_ptr().cast::<AtomicU32>().as_ref_unchecked() }
    }

    #[inline(always)]
    const fn as_slice(&self) -> &[u8] {
        let len = self.len() as usize;
        // SAFETY: Not null and length is a protected invariant of the implementation
        unsafe { slice::from_raw_parts(self.as_ptr(), len) }
    }

    #[inline(always)]
    const fn as_mut_slice(&mut self) -> &mut [u8] {
        let len = self.len() as usize;
        // SAFETY: Not null and length is a protected invariant of the implementation
        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), len) }
    }

    #[inline(always)]
    const fn as_ptr(&self) -> *const u8 {
        // SAFETY: Constructor allocates the first element for `strong_count`
        unsafe { self.buffer.as_ptr().cast_const().add(1).cast::<u8>() }
    }

    #[inline(always)]
    const fn as_mut_ptr(&mut self) -> *mut u8 {
        // SAFETY: Constructor allocates the first element for `strong_count`
        unsafe { self.buffer.as_ptr().add(1).cast::<u8>() }
    }
}

/// Owned aligned buffer for executor purposes.
///
/// See [`SharedAlignedBuffer`] for a version that can be cheaply cloned while reusing the original
/// allocation.
///
/// Data is aligned to 16 bytes (128 bits), which is the largest alignment required by primitive
/// types and by extension any type that implements `TrivialType`/`IoType`.
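///
/// # Example
///
/// A minimal usage sketch (kept as `ignore` since the crate relies on nightly features and
/// `no_std`):
///
/// ```ignore
/// use ab_aligned_buffer::OwnedAlignedBuffer;
///
/// let mut buffer = OwnedAlignedBuffer::with_capacity(8);
/// buffer.copy_from_slice(b"hello");
/// assert_eq!(buffer.as_slice(), b"hello");
/// // Contents are always 16-byte aligned
/// assert_eq!(buffer.as_ptr() as usize % 16, 0);
/// ```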
#[derive(Debug)]
pub struct OwnedAlignedBuffer {
    inner: InnerBuffer,
}

impl Deref for OwnedAlignedBuffer {
    type Target = [u8];

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl DerefMut for OwnedAlignedBuffer {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

// SAFETY: Heap-allocated data structure, points to the same memory if moved
unsafe impl StableDeref for OwnedAlignedBuffer {}

impl Clone for OwnedAlignedBuffer {
    #[inline(always)]
    fn clone(&self) -> Self {
        let mut new_instance = Self::with_capacity(self.capacity());
        new_instance.copy_from_slice(self.as_slice());
        new_instance
    }
}

impl OwnedAlignedBuffer {
    /// Create a new instance with at least specified capacity.
    ///
    /// NOTE: Actual capacity might be larger due to alignment requirements.
    #[inline(always)]
    pub fn with_capacity(capacity: u32) -> Self {
        Self {
            inner: InnerBuffer::allocate(capacity),
        }
    }

    /// Create a new instance from provided bytes.
    ///
    /// # Panics
    /// If `bytes.len()` doesn't fit into `u32`
    #[inline(always)]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let mut instance = Self::with_capacity(0);
        instance.copy_from_slice(bytes);
        instance
    }

    #[inline(always)]
    pub const fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    #[inline(always)]
    pub const fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    #[inline(always)]
    pub const fn as_ptr(&self) -> *const u8 {
        self.inner.as_ptr()
    }

    #[inline(always)]
    pub const fn as_mut_ptr(&mut self) -> *mut u8 {
        self.inner.as_mut_ptr()
    }

    #[inline(always)]
    pub fn into_shared(self) -> SharedAlignedBuffer {
        SharedAlignedBuffer { inner: self.inner }
    }

    /// Ensure capacity of the buffer is at least `capacity`.
    ///
    /// Will re-allocate if necessary.
    #[inline(always)]
    pub fn ensure_capacity(&mut self, capacity: u32) {
        if capacity > self.capacity() {
            self.inner.resize(capacity)
        }
    }

    /// Will re-allocate if capacity is not enough to store provided bytes.
    ///
    /// # Panics
    /// If `bytes.len()` doesn't fit into `u32`
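    ///
    /// A sketch of the replace-contents behavior (kept as `ignore`, see the type-level example):
    ///
    /// ```ignore
    /// let mut buffer = OwnedAlignedBuffer::with_capacity(0);
    /// buffer.copy_from_slice(b"abc");
    /// assert_eq!(buffer.as_slice(), b"abc");
    /// buffer.copy_from_slice(b"de");
    /// // Contents are replaced rather than appended
    /// assert_eq!(buffer.as_slice(), b"de");
    /// ```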
    #[inline(always)]
    pub fn copy_from_slice(&mut self, bytes: &[u8]) {
        let Ok(len) = u32::try_from(bytes.len()) else {
            panic!("Too many bytes {}", bytes.len());
        };

        if len > self.capacity() {
            self.inner
                .resize(len.max(self.capacity().saturating_mul(2)));
        }

        // SAFETY: Sufficient capacity guaranteed above, natural alignment of bytes is 1 for input
        // and output, non-overlapping allocations guaranteed by the type system
        unsafe {
            self.as_mut_ptr()
                .copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());

            self.inner.set_len(len);
        }
    }

    /// Will re-allocate if capacity is not enough to store provided bytes.
    ///
    /// Returns `false` if `self.len() + bytes.len()` doesn't fit into `u32`.
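    ///
    /// A usage sketch (kept as `ignore`, see the type-level example):
    ///
    /// ```ignore
    /// let mut buffer = OwnedAlignedBuffer::from_bytes(b"ab");
    /// assert!(buffer.append(b"cd"));
    /// assert_eq!(buffer.as_slice(), b"abcd");
    /// ```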
    #[inline(always)]
    #[must_use]
    pub fn append(&mut self, bytes: &[u8]) -> bool {
        let Ok(len) = u32::try_from(bytes.len()) else {
            return false;
        };

        let Some(new_len) = self.len().checked_add(len) else {
            return false;
        };

        if new_len > self.capacity() {
            self.inner
                .resize(new_len.max(self.capacity().saturating_mul(2)));
        }

        // SAFETY: Sufficient capacity guaranteed above, natural alignment of bytes is 1 for input
        // and output, non-overlapping allocations guaranteed by the type system
        unsafe {
            self.as_mut_ptr()
                .add(self.len() as usize)
                .copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());

            self.inner.set_len(new_len);
        }

        true
    }

    #[inline(always)]
    pub const fn is_empty(&self) -> bool {
        self.inner.len() == 0
    }

    #[inline(always)]
    pub const fn len(&self) -> u32 {
        self.inner.len()
    }

    #[inline(always)]
    pub const fn capacity(&self) -> u32 {
        self.inner.capacity()
    }

    /// Set the length of the useful data to a specified value.
    ///
    /// # Safety
    /// There must be `new_len` initialized bytes in the buffer, and `new_len` must not exceed the
    /// capacity of the buffer.
    #[inline(always)]
    pub unsafe fn set_len(&mut self, new_len: u32) {
        // SAFETY: Guaranteed by method contract
        unsafe {
            self.inner.set_len(new_len);
        }
    }
}

/// Shared aligned buffer for executor purposes.
///
/// See [`OwnedAlignedBuffer`] for a version that can be mutated.
///
/// Data is aligned to 16 bytes (128 bits), which is the largest alignment required by primitive
/// types and by extension any type that implements `TrivialType`/`IoType`.
///
/// NOTE: Counter for the number of shared instances is `u32` and will wrap around if exceeded,
/// breaking internal invariants (which is extremely unlikely, but still).
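///
/// # Example
///
/// A minimal sketch of the cheap-clone behavior (kept as `ignore` since the crate relies on
/// nightly features and `no_std`):
///
/// ```ignore
/// use ab_aligned_buffer::SharedAlignedBuffer;
///
/// let shared = SharedAlignedBuffer::from_bytes(b"data");
/// let clone = shared.clone();
/// // Both instances reuse the same allocation
/// assert_eq!(shared.as_ptr(), clone.as_ptr());
/// assert_eq!(clone.as_slice(), b"data");
/// ```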
#[derive(Debug, Default, Clone)]
pub struct SharedAlignedBuffer {
    inner: InnerBuffer,
}

impl Deref for SharedAlignedBuffer {
    type Target = [u8];

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

// SAFETY: Heap-allocated data structure, points to the same memory if moved
unsafe impl StableDeref for SharedAlignedBuffer {}
// SAFETY: Inner buffer is exactly the same and points to the same memory after clone
unsafe impl CloneStableDeref for SharedAlignedBuffer {}
// SAFETY: Inner buffer is exactly the same and points to the same memory after clone
unsafe impl CloneableCart for SharedAlignedBuffer {}

impl SharedAlignedBuffer {
    /// Static reference to an empty buffer
    #[inline(always)]
    pub const fn empty_ref() -> &'static Self {
        &EMPTY_SHARED_ALIGNED_BUFFER
    }

    /// Create a new instance from provided bytes.
    ///
    /// # Panics
    /// If `bytes.len()` doesn't fit into `u32`
    #[inline(always)]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        OwnedAlignedBuffer::from_bytes(bytes).into_shared()
    }

    /// Convert into owned buffer.
    ///
    /// If this is the last shared instance, then the allocation will be reused, otherwise the
    /// contents will be copied into a new allocation.
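    ///
    /// A usage sketch (kept as `ignore`, see the type-level example):
    ///
    /// ```ignore
    /// let shared = SharedAlignedBuffer::from_bytes(b"data");
    /// let other = shared.clone();
    /// // `other` still shares the allocation, so the bytes are copied into a fresh buffer
    /// let owned = shared.into_owned();
    /// assert_eq!(owned.as_slice(), b"data");
    /// ```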
    #[inline(always)]
    pub fn into_owned(self) -> OwnedAlignedBuffer {
        if self.inner.strong_count_ref().load(Ordering::Acquire) == 1 {
            OwnedAlignedBuffer { inner: self.inner }
        } else {
            OwnedAlignedBuffer::from_bytes(self.as_slice())
        }
    }

    #[inline(always)]
    pub const fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    #[inline(always)]
    pub const fn as_ptr(&self) -> *const u8 {
        self.inner.as_ptr()
    }

    #[inline(always)]
    pub const fn is_empty(&self) -> bool {
        self.inner.len() == 0
    }

    #[inline(always)]
    pub const fn len(&self) -> u32 {
        self.inner.len()
    }
}