Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "bitcode"
authors = [ "Cai Bear", "Finn Bear" ]
version = "0.6.9"
edition = "2021"
rust-version = "1.70"
rust-version = "1.88"
license = "MIT OR Apache-2.0"
repository = "https://github.com/SoftbearStudios/bitcode"
description = "bitcode is a bitwise binary serializer"
Expand Down
107 changes: 64 additions & 43 deletions src/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View};
use crate::consume::consume_byte_arrays;
use crate::fast::{FastSlice, NextUnchecked, PushUnchecked, VecImpl};
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::num::NonZeroUsize;

#[derive(Default)]
Expand All @@ -20,57 +19,76 @@ impl Encoder<f32> for F32Encoder {
}
}

/// [`bytemuck`] doesn't implement [`MaybeUninit`] casts. Slightly different from
/// [`bytemuck::cast_slice_mut`] in that it will truncate partial elements instead of panicking.
fn chunks_uninit<A, B>(m: &mut [MaybeUninit<A>]) -> &mut [MaybeUninit<B>] {
use core::mem::{align_of, size_of};
assert_eq!(align_of::<B>(), align_of::<A>());
assert_eq!(0, size_of::<B>() % size_of::<A>());
let divisor = size_of::<B>() / size_of::<A>();
// Safety: `align_of<B> == align_of<A>` and `size_of<B>()` is a multiple of `size_of<A>()`
unsafe {
core::slice::from_raw_parts_mut(m.as_mut_ptr() as *mut MaybeUninit<B>, m.len() / divisor)
/// Number of `f32` values in each fixed-size chunk handed to `encode_chunk`.
pub const CHUNK_SIZE: usize = 16;

/// Encodes exactly one full chunk of `CHUNK_SIZE` floats by forwarding to `encode_tail`.
///
/// This exists as a separate, never-inlined function purely for code generation:
/// `CHUNK_SIZE = 16` with `#[inline(never)]` seems to be the sweet spot for both x86_64 and
/// x86_64 with `target-cpu=native`. With a larger chunk the compiler starts loading `chunk`
/// multiple times; with a smaller one it doesn't vectorize as well. Removing
/// `#[inline(never)]` makes the autovectorization inconsistent.
///
/// # Safety
///
/// Same requirements as `encode_tail`, with `tail.len()` equal to `CHUNK_SIZE`.
#[inline(never)]
unsafe fn encode_chunk(chunk: &[f32; CHUNK_SIZE], mantissa: *mut [u8; 3], sign_exp: *mut u8) {
encode_tail(chunk, mantissa, sign_exp)
}

/// Splits every `f32` in `tail` into its first three little-endian bytes, stored at a 3-byte
/// stride starting at `mantissa`, and its last little-endian byte, stored at a 1-byte stride
/// starting at `sign_exp`.
///
/// # Safety
///
/// - `mantissa` must be valid for writes of `tail.len() * 3 + 1` bytes whenever `tail` is not
///   empty: each element is stored as a full 4-byte word, so the final store spills one byte
///   past the last 3-byte slot.
/// - `sign_exp` must be valid for writes of `tail.len()` bytes. It is allowed to overlap the
///   region written through `mantissa`, which is why both destinations are raw pointers
///   instead of `&mut` slices.
unsafe fn encode_tail(tail: &[f32], mantissa: *mut [u8; 3], sign_exp: *mut u8) {
    // Pass 1: mantissas. Storing overlapping 4-byte words in a loop that is separate from the
    // sign/exponent loop was measured to be 70% faster than splitting chunks of 4 f32 with
    // bit shifts, and ~3.3x faster than the scalar solution.
    for (index, value) in tail.iter().enumerate() {
        let word: [u8; 4] = value.to_le_bytes();
        // SAFETY: the caller guarantees `tail.len() * 3 + 1` writable bytes behind `mantissa`
        // (`tail` is not empty here), which covers the 4-byte store at byte offset `index * 3`.
        // `[u8; 4]` has alignment 1, so the store is always aligned.
        unsafe { mantissa.add(index).cast::<[u8; 4]>().write(word) };
    }

    // Pass 2: sign/exponent bytes. This runs strictly after pass 1, so a byte of this region
    // that pass 1 spilled into is overwritten with the correct value.
    for (index, value) in tail.iter().enumerate() {
        let [.., top] = value.to_le_bytes();
        // SAFETY: the caller guarantees `tail.len()` writable bytes behind `sign_exp`; only raw
        // pointers are used, so overlapping the `mantissa` region is permitted.
        unsafe { sign_exp.add(index).write(top) };
    }
}

impl Buffer for F32Encoder {
fn collect_into(&mut self, out: &mut Vec<u8>) {
let floats = self.0.as_slice();
let Some(first_float) = floats.get(0).copied() else {
return;
};
let byte_len = core::mem::size_of_val(floats);
out.reserve(byte_len);
let uninit = &mut out.spare_capacity_mut()[..byte_len];

let (mantissa, sign_exp) = uninit.split_at_mut(floats.len() * 3);
let mantissa: &mut [MaybeUninit<[u8; 3]>] = chunks_uninit(mantissa);

// TODO SIMD version with PSHUFB.
const CHUNK_SIZE: usize = 4;
let chunks_len = floats.len() / CHUNK_SIZE;
let chunks_floats = chunks_len * CHUNK_SIZE;
let chunks: &[[u32; CHUNK_SIZE]] = bytemuck::cast_slice(&floats[..chunks_floats]);
let mantissa_chunks: &mut [MaybeUninit<[[u8; 4]; 3]>] = chunks_uninit(mantissa);
let sign_exp_chunks: &mut [MaybeUninit<[u8; 4]>] = chunks_uninit(sign_exp);

for ci in 0..chunks_len {
let [a, b, c, d] = chunks[ci];

let m0 = a & 0xFF_FF_FF | (b << 24);
let m1 = ((b >> 8) & 0xFF_FF) | (c << 16);
let m2 = (c >> 16) & 0xFF | (d << 8);
let mantissa_chunk = &mut mantissa_chunks[ci];
mantissa_chunk.write([m0.to_le_bytes(), m1.to_le_bytes(), m2.to_le_bytes()]);

let se = (a >> 24) | ((b >> 24) << 8) | ((c >> 24) << 16) | ((d >> 24) << 24);
let sign_exp_chunk = &mut sign_exp_chunks[ci];
sign_exp_chunk.write(se.to_le_bytes());
let mantissa_start = out.spare_capacity_mut().as_mut_ptr() as *mut [u8; 3];

// Safety: we've reserved floats.len() * 4 bytes of spare capacity past the end of `out`.
// Therefore, the pointer at byte floats.len() * 3 is not past the end of the allocation.
let sign_exp_chunks = unsafe { mantissa_start.add(floats.len()) as *mut [u8; CHUNK_SIZE] };
let mantissa_chunks = mantissa_start as *mut [[u8; 3]; CHUNK_SIZE];

let (chunks, tail) = floats.as_chunks::<CHUNK_SIZE>();
for (i, chunk) in chunks.iter().enumerate() {
// Safety:
// `mantissa`: We've allocated floats.len() * 4 bytes, so `floats.len() * 4` bytes are valid for writes.
// `floats.len() * 3 + 1 (if tail not empty)` is always <= floats.len() * 4.
// `sign_exp`: We've allocated floats.len() * 4 bytes so the pointer starting at floats.len() * 3 has floats.len() valid bytes.
// We keep everything as raw pointers so the aliasing with mantissa's last byte is valid.
unsafe {
let mantissa = mantissa_chunks.add(i) as *mut [u8; 3];
let sign_exp = sign_exp_chunks.add(i) as *mut u8;
encode_chunk(chunk, mantissa, sign_exp);
}
}

for i in chunks_floats..floats.len() {
let [m @ .., se] = floats[i].to_le_bytes();
mantissa[i].write(m);
sign_exp[i].write(se);
// Safety: same as above call to encode_chunk.
unsafe {
let mantissa = mantissa_chunks.add(chunks.len()) as *mut [u8; 3];
let sign_exp = sign_exp_chunks.add(chunks.len()) as *mut u8;
encode_tail(tail, mantissa, sign_exp);
}

// Fix up the sign_exp killed by the last 3 byte mantissa writing 4 bytes (technically only required if !chunks.is_empty()).
// Safety: sign_exp_chunks is not past the end of the allocation.
// Additionally floats.len() * 3 < floats.len() * 4 because we've ensured
// floats isn't empty, so this 1 byte u8 pointer is inside the allocation.
unsafe { *(sign_exp_chunks as *mut u8) = first_float.to_le_bytes()[3] };

// Safety: We just initialized these elements in the loops above.
unsafe { out.set_len(out.len() + byte_len) };
self.0.clear();
Expand Down Expand Up @@ -124,12 +142,15 @@ mod tests {

#[test]
fn test() {
for i in 1..16 {
// CHUNK_SIZE * 3 exhibits all sizes of the tail, the first chunk and the second chunk.
for i in 0..CHUNK_SIZE * 3 {
let mut rng = ChaCha20Rng::from_seed(Default::default());
let floats: Vec<_> = (0..i).map(|_| f32::from_bits(rng.gen())).collect();

let mut encoder = F32Encoder::default();
encoder.reserve(NonZeroUsize::new(floats.len()).unwrap());
if let Some(additional) = NonZeroUsize::new(floats.len()) {
encoder.reserve(additional);
}
for &f in &floats {
encoder.encode(&f);
}
Expand Down
Loading