ab_aligned_buffer/
lib.rs

#![feature(box_vec_non_null, pointer_is_aligned_to, ptr_as_ref_unchecked)]
#![no_std]

#[cfg(test)]
mod tests;

extern crate alloc;

use ab_io_type::MAX_ALIGNMENT;
use alloc::alloc::realloc;
use alloc::boxed::Box;
use core::alloc::Layout;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use core::slice;
use core::sync::atomic::{AtomicU32, Ordering};
use stable_deref_trait::{CloneStableDeref, StableDeref};
use yoke::CloneableCart;

const _: () = {
    assert!(
        align_of::<u128>() == size_of::<u128>(),
        "Size and alignment are both 16 bytes"
    );
    assert!(
        align_of::<u128>() == MAX_ALIGNMENT as usize,
        "Alignment of u128 is max alignment"
    );
    assert!(size_of::<u128>() >= size_of::<AtomicU32>());
    assert!(align_of::<u128>() >= align_of::<AtomicU32>());
};

#[repr(C, align(16))]
struct ConstInnerBuffer {
    strong_count: AtomicU32,
}

const _: () = {
    assert!(align_of::<ConstInnerBuffer>() == align_of::<u128>());
    assert!(size_of::<ConstInnerBuffer>() == size_of::<u128>());
};

static EMPTY_SHARED_ALIGNED_BUFFER: SharedAlignedBuffer = SharedAlignedBuffer {
    inner: InnerBuffer {
        buffer: NonNull::from_ref({
            static BUFFER: MaybeUninit<ConstInnerBuffer> = MaybeUninit::new(ConstInnerBuffer {
                strong_count: AtomicU32::new(1),
            });

            &BUFFER
        })
        .cast::<MaybeUninit<u128>>(),
        capacity: 0,
        len: 0,
    },
};

#[derive(Debug)]
struct InnerBuffer {
    // The first bytes are allocated for `strong_count`
    buffer: NonNull<MaybeUninit<u128>>,
    capacity: u32,
    len: u32,
}

// SAFETY: Heap-allocated memory buffer can be used from any thread
unsafe impl Send for InnerBuffer {}
// SAFETY: Heap-allocated memory buffer can be used from any thread
unsafe impl Sync for InnerBuffer {}

impl Default for InnerBuffer {
    #[inline(always)]
    fn default() -> Self {
        EMPTY_SHARED_ALIGNED_BUFFER.inner.clone()
    }
}

impl Clone for InnerBuffer {
    #[inline(always)]
    fn clone(&self) -> Self {
        self.strong_count_ref().fetch_add(1, Ordering::AcqRel);

        Self {
            buffer: self.buffer,
            capacity: self.capacity,
            len: self.len,
        }
    }
}

impl Drop for InnerBuffer {
    #[inline(always)]
    fn drop(&mut self) {
        if self.strong_count_ref().fetch_sub(1, Ordering::AcqRel) == 1 {
            // SAFETY: Created from `Box` in constructor
            let _ = unsafe {
                Box::from_non_null(NonNull::slice_from_raw_parts(
                    self.buffer,
                    1 + (self.capacity as usize).div_ceil(size_of::<u128>()),
                ))
            };
        }
    }
}

impl InnerBuffer {
    /// Allocates a new buffer + one `u128` worth of memory at the beginning for
    /// `strong_count` in case it is later converted to [`SharedAlignedBuffer`].
    ///
    /// `strong_count` field is automatically initialized as `1`.
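    ///
    /// The resulting allocation holds one `u128` slot for `strong_count` at the start, followed
    /// by `capacity` bytes of data rounded up to a whole number of `u128`s.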
    #[inline(always)]
    fn allocate(capacity: u32) -> Self {
        let buffer = Box::into_non_null(Box::<[u128]>::new_uninit_slice(
            1 + (capacity as usize).div_ceil(size_of::<u128>()),
        ));
        // SAFETY: The first bytes are allocated for `strong_count`, which is a correctly aligned
        // copy type
        unsafe { buffer.cast::<AtomicU32>().write(AtomicU32::new(1)) };
        Self {
            buffer: buffer.cast::<MaybeUninit<u128>>(),
            capacity,
            len: 0,
        }
    }

    #[inline(always)]
    fn resize(&mut self, capacity: u32) {
        // SAFETY: Non-null correctly aligned pointer, correct size
        let layout = Layout::for_value(unsafe {
            slice::from_raw_parts(
                self.buffer.as_ptr(),
                1 + (self.capacity as usize).div_ceil(size_of::<u128>()),
            )
        });

        // `size_of::<u128>()` is added because the first bytes are allocated for `strong_count`
        let new_size = size_of::<u128>() + (capacity as usize).next_multiple_of(layout.align());

        // SAFETY: Allocated with global allocator, correct layout, non-zero size that is a
        // multiple of alignment
        let new_ptr = unsafe {
            realloc(self.buffer.as_ptr().cast::<u8>(), layout, new_size).cast::<MaybeUninit<u128>>()
        };
        let Some(new_ptr) = NonNull::new(new_ptr) else {
            panic!("Realloc from {} to {new_size} has failed", self.capacity());
        };

        self.buffer = new_ptr;
        self.capacity = capacity;
    }

    #[inline(always)]
    fn len(&self) -> u32 {
        self.len
    }

    /// `len` bytes must be initialized
    #[inline(always)]
    unsafe fn set_len(&mut self, len: u32) {
        self.len = len;
    }

    #[inline(always)]
    fn capacity(&self) -> u32 {
        self.capacity
    }

    #[inline(always)]
    fn strong_count_ref(&self) -> &AtomicU32 {
        // SAFETY: The first bytes are allocated for `strong_count`, which is a correctly aligned
        // copy type initialized in the constructor
        unsafe { self.buffer.as_ptr().cast::<AtomicU32>().as_ref_unchecked() }
    }

    #[inline(always)]
    fn as_slice(&self) -> &[u8] {
        let len = self.len() as usize;
        // SAFETY: Not null and length is a protected invariant of the implementation
        unsafe { slice::from_raw_parts(self.as_ptr(), len) }
    }

    #[inline(always)]
    fn as_mut_slice(&mut self) -> &mut [u8] {
        let len = self.len() as usize;
        // SAFETY: Not null and length is a protected invariant of the implementation
        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), len) }
    }

    #[inline(always)]
    fn as_ptr(&self) -> *const u8 {
        // SAFETY: Constructor allocates the first element for `strong_count`
        unsafe { self.buffer.as_ptr().cast_const().add(1).cast::<u8>() }
    }

    #[inline(always)]
    fn as_mut_ptr(&mut self) -> *mut u8 {
        // SAFETY: Constructor allocates the first element for `strong_count`
        unsafe { self.buffer.as_ptr().add(1).cast::<u8>() }
    }
}

/// Owned aligned buffer for executor purposes.
///
/// See [`SharedAlignedBuffer`] for a version that can be cheaply cloned, while reusing the original
/// allocation.
///
/// Data is aligned to 16 bytes (128 bits), which is the largest alignment required by primitive
/// types and by extension any type that implements `TrivialType`/`IoType`.
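///
/// # Example
///
/// A minimal usage sketch (assuming this crate is importable as `ab_aligned_buffer`):
///
/// ```
/// use ab_aligned_buffer::OwnedAlignedBuffer;
///
/// let mut buffer = OwnedAlignedBuffer::with_capacity(0);
/// buffer.copy_from_slice(b"hello");
/// assert_eq!(buffer.as_slice(), b"hello");
/// assert_eq!(buffer.len(), 5);
/// ```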
#[derive(Debug)]
pub struct OwnedAlignedBuffer {
    inner: InnerBuffer,
}

impl Deref for OwnedAlignedBuffer {
    type Target = [u8];

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl DerefMut for OwnedAlignedBuffer {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

// SAFETY: Heap-allocated data structure, points to the same memory if moved
unsafe impl StableDeref for OwnedAlignedBuffer {}

impl Clone for OwnedAlignedBuffer {
    #[inline(always)]
    fn clone(&self) -> Self {
        let mut new_instance = Self::with_capacity(self.capacity());
        new_instance.copy_from_slice(self.as_slice());
        new_instance
    }
}

impl OwnedAlignedBuffer {
    /// Create a new instance with at least specified capacity.
    ///
    /// NOTE: Actual capacity might be larger due to alignment requirements.
    #[inline(always)]
    pub fn with_capacity(capacity: u32) -> Self {
        Self {
            inner: InnerBuffer::allocate(capacity),
        }
    }

    /// Create a new instance from provided bytes.
    ///
    /// # Panics
    /// If `bytes.len()` doesn't fit into `u32`
    #[inline(always)]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let mut instance = Self::with_capacity(0);
        instance.copy_from_slice(bytes);
        instance
    }

    #[inline(always)]
    pub fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    #[inline(always)]
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const u8 {
        self.inner.as_ptr()
    }

    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.inner.as_mut_ptr()
    }

    #[inline(always)]
    pub fn into_shared(self) -> SharedAlignedBuffer {
        SharedAlignedBuffer { inner: self.inner }
    }

    /// Ensure capacity of the buffer is at least `capacity`.
    ///
    /// Will re-allocate if necessary.
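    ///
    /// # Example
    ///
    /// A minimal sketch (crate path `ab_aligned_buffer` is assumed):
    ///
    /// ```
    /// use ab_aligned_buffer::OwnedAlignedBuffer;
    ///
    /// let mut buffer = OwnedAlignedBuffer::with_capacity(8);
    /// buffer.ensure_capacity(1024);
    /// assert!(buffer.capacity() >= 1024);
    /// ```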
    #[inline(always)]
    pub fn ensure_capacity(&mut self, capacity: u32) {
        if capacity > self.capacity() {
            self.inner.resize(capacity)
        }
    }

    /// Replace the contents of the buffer with provided bytes.
    ///
    /// Will re-allocate if capacity is not enough to store provided bytes.
    ///
    /// # Panics
    /// If `bytes.len()` doesn't fit into `u32`
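    ///
    /// # Example
    ///
    /// A minimal sketch (crate path `ab_aligned_buffer` is assumed):
    ///
    /// ```
    /// use ab_aligned_buffer::OwnedAlignedBuffer;
    ///
    /// let mut buffer = OwnedAlignedBuffer::from_bytes(b"old contents");
    /// buffer.copy_from_slice(b"new");
    /// assert_eq!(buffer.as_slice(), b"new");
    /// ```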
    #[inline(always)]
    pub fn copy_from_slice(&mut self, bytes: &[u8]) {
        let Ok(len) = u32::try_from(bytes.len()) else {
            panic!("Too many bytes {}", bytes.len());
        };

        if len > self.capacity() {
            self.inner
                .resize(len.max(self.capacity().saturating_mul(2)));
        }

        // SAFETY: Sufficient capacity guaranteed above, natural alignment of bytes is 1 for input
        // and output, non-overlapping allocations guaranteed by type system
        unsafe {
            self.as_mut_ptr()
                .copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());

            self.inner.set_len(len);
        }
    }

    /// Append provided bytes to the end of the buffer.
    ///
    /// Will re-allocate if capacity is not enough to store provided bytes.
    ///
    /// Returns `false` if `self.len() + bytes.len()` doesn't fit into `u32`.
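    ///
    /// # Example
    ///
    /// A minimal sketch (crate path `ab_aligned_buffer` is assumed):
    ///
    /// ```
    /// use ab_aligned_buffer::OwnedAlignedBuffer;
    ///
    /// let mut buffer = OwnedAlignedBuffer::from_bytes(b"hello");
    /// assert!(buffer.append(b", world"));
    /// assert_eq!(buffer.as_slice(), b"hello, world");
    /// ```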
    #[inline(always)]
    #[must_use]
    pub fn append(&mut self, bytes: &[u8]) -> bool {
        let Ok(len) = u32::try_from(bytes.len()) else {
            return false;
        };

        let Some(new_len) = self.len().checked_add(len) else {
            return false;
        };

        if new_len > self.capacity() {
            self.inner
                .resize(new_len.max(self.capacity().saturating_mul(2)));
        }

        // SAFETY: Sufficient capacity guaranteed above, natural alignment of bytes is 1 for input
        // and output, non-overlapping allocations guaranteed by type system
        unsafe {
            self.as_mut_ptr()
                .add(self.len() as usize)
                .copy_from_nonoverlapping(bytes.as_ptr(), bytes.len());

            self.inner.set_len(new_len);
        }

        true
    }

    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.inner.len() == 0
    }

    #[inline(always)]
    pub fn len(&self) -> u32 {
        self.inner.len()
    }

    #[inline(always)]
    pub fn capacity(&self) -> u32 {
        self.inner.capacity()
    }

    /// Set the length of the useful data to specified value.
    ///
    /// # Safety
    /// There must be `new_len` bytes initialized in the buffer.
    ///
    /// # Panics
    /// In debug builds, if `new_len` exceeds the buffer's capacity.
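    ///
    /// # Example
    ///
    /// A sketch of writing through the raw pointer and then committing the length (crate path
    /// `ab_aligned_buffer` is assumed):
    ///
    /// ```
    /// use ab_aligned_buffer::OwnedAlignedBuffer;
    ///
    /// let mut buffer = OwnedAlignedBuffer::with_capacity(16);
    /// // SAFETY: 4 bytes are initialized below before the length is set to 4
    /// unsafe {
    ///     buffer.as_mut_ptr().copy_from_nonoverlapping([1u8, 2, 3, 4].as_ptr(), 4);
    ///     buffer.set_len(4);
    /// }
    /// assert_eq!(buffer.as_slice(), &[1, 2, 3, 4]);
    /// ```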
    #[inline(always)]
    pub unsafe fn set_len(&mut self, new_len: u32) {
        debug_assert!(
            new_len <= self.capacity(),
            "Too many bytes {} > {}",
            new_len,
            self.capacity()
        );
        // SAFETY: Guaranteed by method contract
        unsafe {
            self.inner.set_len(new_len);
        }
    }
}

/// Shared aligned buffer for executor purposes.
///
/// See [`OwnedAlignedBuffer`] for a version that can be mutated.
///
/// Data is aligned to 16 bytes (128 bits), which is the largest alignment required by primitive
/// types and by extension any type that implements `TrivialType`/`IoType`.
///
/// NOTE: The counter of shared instances is a `u32`; if it overflows, it will wrap around and
/// break internal invariants (extremely unlikely in practice, but worth noting).
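///
/// # Example
///
/// A minimal usage sketch (assuming this crate is importable as `ab_aligned_buffer`):
///
/// ```
/// use ab_aligned_buffer::SharedAlignedBuffer;
///
/// let shared = SharedAlignedBuffer::from_bytes(b"payload");
/// // Cloning is cheap: both instances point to the same allocation
/// let second = shared.clone();
/// assert_eq!(shared.as_slice(), second.as_slice());
/// assert_eq!(shared.as_ptr(), second.as_ptr());
/// ```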
#[derive(Debug, Default, Clone)]
pub struct SharedAlignedBuffer {
    inner: InnerBuffer,
}

impl Deref for SharedAlignedBuffer {
    type Target = [u8];

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

// SAFETY: Heap-allocated data structure, points to the same memory if moved
unsafe impl StableDeref for SharedAlignedBuffer {}
// SAFETY: Inner buffer is exactly the same and points to the same memory after clone
unsafe impl CloneStableDeref for SharedAlignedBuffer {}
// SAFETY: Inner buffer is exactly the same and points to the same memory after clone
unsafe impl CloneableCart for SharedAlignedBuffer {}

impl SharedAlignedBuffer {
    /// Static reference to an empty buffer
    #[inline(always)]
    pub fn empty_ref() -> &'static Self {
        &EMPTY_SHARED_ALIGNED_BUFFER
    }

    /// Create a new instance from provided bytes.
    ///
    /// # Panics
    /// If `bytes.len()` doesn't fit into `u32`
    #[inline(always)]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        OwnedAlignedBuffer::from_bytes(bytes).into_shared()
    }

    /// Convert into an owned buffer.
    ///
    /// If this is the last shared instance, the allocation will be reused; otherwise the contents
    /// will be copied into a new allocation.
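    ///
    /// # Example
    ///
    /// A minimal sketch (crate path `ab_aligned_buffer` is assumed):
    ///
    /// ```
    /// use ab_aligned_buffer::SharedAlignedBuffer;
    ///
    /// let shared = SharedAlignedBuffer::from_bytes(b"data");
    /// let owned = shared.into_owned();
    /// assert_eq!(owned.as_slice(), b"data");
    /// ```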
    #[inline(always)]
    pub fn into_owned(self) -> OwnedAlignedBuffer {
        if self.inner.strong_count_ref().load(Ordering::Acquire) == 1 {
            OwnedAlignedBuffer { inner: self.inner }
        } else {
            OwnedAlignedBuffer::from_bytes(self.as_slice())
        }
    }

    #[inline(always)]
    pub fn as_slice(&self) -> &[u8] {
        self.inner.as_slice()
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const u8 {
        self.inner.as_ptr()
    }

    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.inner.len() == 0
    }

    #[inline(always)]
    pub fn len(&self) -> u32 {
        self.inner.len()
    }
}