diff --git a/Cargo.toml b/Cargo.toml index 46cc43e..7b1197f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ name = "bitcode" authors = [ "Cai Bear", "Finn Bear" ] version = "0.6.9" edition = "2021" -rust-version = "1.70" +rust-version = "1.88" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode" description = "bitcode is a bitwise binary serializer" diff --git a/src/f32.rs b/src/f32.rs index aa56ad4..90ca132 100644 --- a/src/f32.rs +++ b/src/f32.rs @@ -2,7 +2,6 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; use crate::consume::consume_byte_arrays; use crate::fast::{FastSlice, NextUnchecked, PushUnchecked, VecImpl}; use alloc::vec::Vec; -use core::mem::MaybeUninit; use core::num::NonZeroUsize; #[derive(Default)] @@ -20,57 +19,76 @@ impl Encoder for F32Encoder { } } -/// [`bytemuck`] doesn't implement [`MaybeUninit`] casts. Slightly different from -/// [`bytemuck::cast_slice_mut`] in that it will truncate partial elements instead of panicking. -fn chunks_uninit(m: &mut [MaybeUninit]) -> &mut [MaybeUninit] { - use core::mem::{align_of, size_of}; - assert_eq!(align_of::(), align_of::()); - assert_eq!(0, size_of::() % size_of::()); - let divisor = size_of::() / size_of::(); - // Safety: `align_of == align_of` and `size_of()` is a multiple of `size_of()` - unsafe { - core::slice::from_raw_parts_mut(m.as_mut_ptr() as *mut MaybeUninit, m.len() / divisor) +pub const CHUNK_SIZE: usize = 16; + +// CHUNK_SIZE = 16 with #[inline(never)] seems to be the sweet spot for both x86_64 and x86_64 target-cpu=native. +// Larger and it starts loading `chunk` multiple times. Smaller and it doesn't vectorize as well. +// Removing #[inline(never)] makes this autovectorization inconsistent. +// Safety: Same as `encode_tail`. 
+#[inline(never)] +unsafe fn encode_chunk(chunk: &[f32; CHUNK_SIZE], mantissa: *mut [u8; 3], sign_exp: *mut u8) { + encode_tail(chunk, mantissa, sign_exp) +} + +// Safety: +// `mantissa` must have `tail.len() * 3 + 1 (if tail not empty)` bytes valid for writes. +// `sign_exp` must have `tail.len()` bytes valid for writes (maybe aliasing with mantissa so ptrs are required). +unsafe fn encode_tail(tail: &[f32], mantissa: *mut [u8; 3], sign_exp: *mut u8) { + for (i, &f) in tail.iter().enumerate() { + let little_endian = f.to_le_bytes(); + // Writing overlapping 4 byte mantissas in a separate loop from sign_exp is 70% faster + // than splitting chunks of 4 f32 with bitshifts and ~3.3x faster than the scalar solution. + // Safety: `mantissa` has `tail.len() * 3 + 1 (tail is not empty)` bytes valid for writes. + *(mantissa.add(i) as *mut [u8; 4]) = little_endian; + } + + for (i, &f) in tail.iter().enumerate() { + // Safety: `sign_exp` has `tail.len()` bytes valid for writes (maybe aliasing with mantissa so ptrs are used). + *sign_exp.add(i) = f.to_le_bytes()[3]; } } impl Buffer for F32Encoder { fn collect_into(&mut self, out: &mut Vec<u8>) { let floats = self.0.as_slice(); + let Some(first_float) = floats.get(0).copied() else { + return; + }; let byte_len = core::mem::size_of_val(floats); out.reserve(byte_len); - let uninit = &mut out.spare_capacity_mut()[..byte_len]; - - let (mantissa, sign_exp) = uninit.split_at_mut(floats.len() * 3); - let mantissa: &mut [MaybeUninit<[u8; 3]>] = chunks_uninit(mantissa); - - // TODO SIMD version with PSHUFB. 
- const CHUNK_SIZE: usize = 4; - let chunks_len = floats.len() / CHUNK_SIZE; - let chunks_floats = chunks_len * CHUNK_SIZE; - let chunks: &[[u32; CHUNK_SIZE]] = bytemuck::cast_slice(&floats[..chunks_floats]); - let mantissa_chunks: &mut [MaybeUninit<[[u8; 4]; 3]>] = chunks_uninit(mantissa); - let sign_exp_chunks: &mut [MaybeUninit<[u8; 4]>] = chunks_uninit(sign_exp); - - for ci in 0..chunks_len { - let [a, b, c, d] = chunks[ci]; - - let m0 = a & 0xFF_FF_FF | (b << 24); - let m1 = ((b >> 8) & 0xFF_FF) | (c << 16); - let m2 = (c >> 16) & 0xFF | (d << 8); - let mantissa_chunk = &mut mantissa_chunks[ci]; - mantissa_chunk.write([m0.to_le_bytes(), m1.to_le_bytes(), m2.to_le_bytes()]); - - let se = (a >> 24) | ((b >> 24) << 8) | ((c >> 24) << 16) | ((d >> 24) << 24); - let sign_exp_chunk = &mut sign_exp_chunks[ci]; - sign_exp_chunk.write(se.to_le_bytes()); + let mantissa_start = out.spare_capacity_mut().as_mut_ptr() as *mut [u8; 3]; + + // Safety: we've allocated floats.len() * 4 bytes past the end of out. + // Therefore, the pointer at byte floats.len() * 3 is not past the end of the allocation. + let sign_exp_chunks = unsafe { mantissa_start.add(floats.len()) as *mut [u8; CHUNK_SIZE] }; + let mantissa_chunks = mantissa_start as *mut [[u8; 3]; CHUNK_SIZE]; + + let (chunks, tail) = floats.as_chunks::<CHUNK_SIZE>(); + for (i, chunk) in chunks.iter().enumerate() { + // Safety: + // `mantissa`: We've allocated floats.len() * 4 bytes, so `floats.len() * 4` bytes are valid for writes. + // `floats.len() * 3 + 1 (if tail not empty)` is always <= floats.len() * 4. + // `sign_exp`: We've allocated floats.len() * 4 bytes so the pointer starting at floats.len() * 3 has floats.len() valid bytes. + // We keep everything as raw pointers so the aliasing with mantissa's last byte is valid. 
+ unsafe { + let mantissa = mantissa_chunks.add(i) as *mut [u8; 3]; + let sign_exp = sign_exp_chunks.add(i) as *mut u8; + encode_chunk(chunk, mantissa, sign_exp); + } } - - for i in chunks_floats..floats.len() { - let [m @ .., se] = floats[i].to_le_bytes(); - mantissa[i].write(m); - sign_exp[i].write(se); + // Safety: same as above call to encode_chunk. + unsafe { + let mantissa = mantissa_chunks.add(chunks.len()) as *mut [u8; 3]; + let sign_exp = sign_exp_chunks.add(chunks.len()) as *mut u8; + encode_tail(tail, mantissa, sign_exp); } + // Restore the first sign_exp byte, which was overwritten when the last 3 byte mantissa was written as 4 bytes (technically only required if !chunks.is_empty()). + // Safety: sign_exp_chunks is not past the end of the allocation. + // Additionally floats.len() * 3 < floats.len() * 4 because we've ensured + // floats isn't empty, so this 1 byte u8 pointer is inside the allocation. + unsafe { *(sign_exp_chunks as *mut u8) = first_float.to_le_bytes()[3] }; + // Safety: We just initialized these elements in the loops above. unsafe { out.set_len(out.len() + byte_len) }; self.0.clear(); @@ -124,12 +142,15 @@ mod tests { #[test] fn test() { - for i in 1..16 { + // CHUNK_SIZE * 3 exhibits all sizes of the tail, the first chunk and the second chunk. + for i in 0..CHUNK_SIZE * 3 { let mut rng = ChaCha20Rng::from_seed(Default::default()); let floats: Vec<_> = (0..i).map(|_| f32::from_bits(rng.gen())).collect(); let mut encoder = F32Encoder::default(); - encoder.reserve(NonZeroUsize::new(floats.len()).unwrap()); + if let Some(additional) = NonZeroUsize::new(floats.len()) { + encoder.reserve(additional); + } for &f in &floats { encoder.encode(&f); }