Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions bitcode_derive/src/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,40 @@ enum BitcodeAttr {
BoundType(Type),
CrateName(Path),
Skip,
/// `#[bitcode(with = "LocalType")]`, shorthand for both `encode_with` and `decode_with`.
With(Type),
/// `#[bitcode(encode_with = "LocalType")]`.
EncodeWith(Type),
/// `#[bitcode(decode_with = "LocalType")]`.
DecodeWith(Type),
}

/// Parses a `#[bitcode(name = "Type")]` string literal value into a [`Type`].
fn parse_type_attr(nested: &Meta) -> Result<Type> {
match nested {
Meta::NameValue(name_value) => {
let expr = &name_value.value;
let str_lit = match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(v), ..
}) => v,
_ => return err(&expr, "expected string e.g. \"LocalType\""),
};
let value = TokenStream::from_str(&str_lit.value()).unwrap();
parse2(value).map_err(|e| error(str_lit, &format!("{e}")))
}
_ => err(&nested, "expected name value"),
}
}

impl BitcodeAttr {
fn new(nested: &Meta) -> Result<Self> {
let path = path_ident_string(nested.path(), &nested)?;
match path.as_str() {
"bound_type" => match nested {
Meta::NameValue(name_value) => {
let expr = &name_value.value;
let str_lit = match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(v), ..
}) => v,
_ => return err(&expr, "expected string e.g. \"T\""),
};

let value = TokenStream::from_str(&str_lit.value()).unwrap();
Ok(Self::BoundType(
parse2(value).map_err(|e| error(str_lit, &format!("{e}")))?,
))
}
_ => err(&nested, "expected name value"),
},
"bound_type" => Ok(Self::BoundType(parse_type_attr(nested)?)),
"with" => Ok(Self::With(parse_type_attr(nested)?)),
"encode_with" => Ok(Self::EncodeWith(parse_type_attr(nested)?)),
"decode_with" => Ok(Self::DecodeWith(parse_type_attr(nested)?)),
"crate" => match nested {
Meta::NameValue(name_value) => {
let expr = &name_value.value;
Expand Down Expand Up @@ -97,6 +108,28 @@ impl BitcodeAttr {
err(nested, "can only apply to fields")
}
}
Self::With(ty) => {
if let BitcodeAnyAttrs::Field(field) = attrs {
set_if_not_duplicate(&mut field.encode_with, Some(ty.clone()), nested)?;
set_if_not_duplicate(&mut field.decode_with, Some(ty), nested)
} else {
err(nested, "can only apply to fields")
}
}
Self::EncodeWith(ty) => {
if let BitcodeAnyAttrs::Field(field) = attrs {
set_if_not_duplicate(&mut field.encode_with, Some(ty), nested)
} else {
err(nested, "can only apply to fields")
}
}
Self::DecodeWith(ty) => {
if let BitcodeAnyAttrs::Field(field) = attrs {
set_if_not_duplicate(&mut field.decode_with, Some(ty), nested)
} else {
err(nested, "can only apply to fields")
}
}
}
}
}
Expand Down Expand Up @@ -159,6 +192,10 @@ pub struct BitcodeFieldAttrs<'a> {
parent: BitcodeDeriveOrVariantAttrs<'a>,
pub bound_type: Option<Type>,
pub skip: bool,
/// Encode this field as a different (local) type via `From`/`Into`.
pub encode_with: Option<Type>,
/// Decode this field as a different (local) type via `From`/`Into`.
pub decode_with: Option<Type>,
}
impl<'a> std::ops::Deref for BitcodeFieldAttrs<'a> {
type Target = BitcodeDeriveOrVariantAttrs<'a>;
Expand All @@ -172,6 +209,8 @@ impl<'a> BitcodeFieldAttrs<'a> {
parent,
bound_type: Default::default(),
skip: Default::default(),
encode_with: Default::default(),
decode_with: Default::default(),
};
BitcodeAnyAttrs::Field(&mut ret).parse_inner(attrs)?;
Ok(ret)
Expand Down
18 changes: 17 additions & 1 deletion bitcode_derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ impl crate::shared::Item for Item {
) -> TokenStream {
match self {
Self::Type => {
let mut de_type = replace_lifetimes(field_type, DE_LIFETIME).to_token_stream();
// `#[bitcode(decode_with = "Local")]` stores the decoder for `Local`, not the field.
let base_type = attrs.decode_with.as_ref().unwrap_or(field_type);
let mut de_type = replace_lifetimes(base_type, DE_LIFETIME).to_token_stream();
if attrs.skip {
de_type = quote! { ::core::marker::PhantomData<#de_type> };
}
Expand All @@ -61,10 +63,15 @@ impl crate::shared::Item for Item {
},
// Only used by enum variants.
Self::Decode => {
let de_type = replace_lifetimes(field_type, DE_LIFETIME);
let value = if attrs.skip {
quote! {
Default::default()
}
} else if attrs.decode_with.is_some() {
quote! {
::core::convert::Into::<#de_type>::into(self.#global_field_name.decode())
}
} else {
quote! {
self.#global_field_name.decode()
Expand All @@ -84,6 +91,11 @@ impl crate::shared::Item for Item {
quote! {{
(#target).write(Default::default());
}}
} else if attrs.decode_with.is_some() {
quote! {{
let __local = self.#global_field_name.decode();
(#target).write(::core::convert::Into::into(__local));
}}
} else {
quote! {
self.#global_field_name.decode_in_place(#target);
Expand Down Expand Up @@ -261,6 +273,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Decode {
Some(parse_quote!(Default))
}

fn with_type(&self, field_attrs: &BitcodeFieldAttrs) -> Option<Type> {
field_attrs.decode_with.clone()
}

fn derive_impl(
&self,
attrs: &BitcodeDeriveAttrs,
Expand Down
45 changes: 44 additions & 1 deletion bitcode_derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ impl crate::shared::Item for Item {
) -> TokenStream {
match self {
Self::Type => {
let mut static_type = replace_lifetimes(field_type, "static").to_token_stream();
// `#[bitcode(encode_with = "Local")]` stores the encoder for `Local`, not the field.
let base_type = attrs.encode_with.as_ref().unwrap_or(field_type);
let mut static_type = replace_lifetimes(base_type, "static").to_token_stream();
if attrs.skip {
static_type = quote! { ::core::marker::PhantomData<#static_type> };
}
Expand All @@ -48,6 +50,43 @@ impl crate::shared::Item for Item {
#global_field_name: Default::default(),
},
Self::Encode | Self::EncodeVectored => {
// `#[bitcode(encode_with = "Local")]`: convert `&Field` to `Local`, then encode `Local`.
// The conversion produces an owned `Local`, so we can't yield `&Local` into the
// vectored fast-path; instead the vectored case stamps out a loop calling `encode`.
if let Some(local) = attrs.encode_with.as_ref().filter(|_| !attrs.skip) {
let local_static = replace_lifetimes(local, "static");
let needs_transmute = &local_static != local;
let encode_one = |field_ref: TokenStream| {
if needs_transmute {
// HACK: Since encoders don't have lifetimes we can't reference <T<'a> as Encode>::Encoder since 'a
// does not exist. Instead we replace this with <T<'static> as Encode>::Encoder and transmute it to
// T<'a>. No encoder actually encodes T<'static> any differently from T<'a> so this is sound.
let local_underscore = replace_lifetimes(local, "_");
quote! {{
let __local: #local_underscore = ::core::convert::From::from(#field_ref);
self.#global_field_name.encode(unsafe {
::core::mem::transmute::<&#local_underscore, &#local_static>(&__local)
});
}}
} else {
quote! {{
let __local: #local_static = ::core::convert::From::from(#field_ref);
self.#global_field_name.encode(&__local);
}}
}
};
return if matches!(self, Self::EncodeVectored) {
let encode_one = encode_one(quote! { &__me.#real_field_name });
quote! {
for __me in i.clone() {
#encode_one
}
}
} else {
encode_one(quote! { #field_name })
};
}

let static_type = replace_lifetimes(field_type, "static");
let value = if attrs.skip {
quote! {
Expand Down Expand Up @@ -246,6 +285,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Encode {
None
}

fn with_type(&self, field_attrs: &BitcodeFieldAttrs) -> Option<Type> {
field_attrs.encode_with.clone()
}

fn derive_impl(
&self,
attrs: &BitcodeDeriveAttrs,
Expand Down
11 changes: 10 additions & 1 deletion bitcode_derive/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ pub trait Derive<const ITEM_COUNT: usize> {
/// Bound for skipped fields, e.g. `Default`
fn skip_bound(&self) -> Option<Path>;

/// For `#[bitcode(with)]` fields, the local type whose generic parameters should be bounded
/// instead of the field's own type (which need not implement `Encode`/`Decode`).
fn with_type(&self, field_attrs: &BitcodeFieldAttrs) -> Option<Type>;

/// Generates the derive implementation.
fn derive_impl(
&self,
Expand All @@ -156,7 +160,12 @@ pub trait Derive<const ITEM_COUNT: usize> {
Some(self.bound(&attrs))
};
if let Some(bound) = bound {
bounds.add_bound_type(field.clone(), &field_attrs, bound);
let mut field = field.clone();
if let Some(with_type) = self.with_type(&field_attrs) {
// Bound the local type instead of the field's own (remote) type.
field.ty = with_type;
}
bounds.add_bound_type(field, &field_attrs, bound);
}
Ok(field_attrs)
})
Expand Down
148 changes: 148 additions & 0 deletions src/derive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,151 @@ mod tests {
}
}
}

#[cfg(test)]
mod with_tests {
use crate::{decode, encode, Decode, Encode};
use alloc::string::String;

// "Remote" types that intentionally don't implement Encode/Decode themselves.
#[derive(Debug, PartialEq, Clone)]
struct RemoteStr(String);

#[derive(Debug, PartialEq, Clone, Copy)]
struct Meters(f32);

// Local proxy types used via `#[bitcode(with = ...)]`.
#[derive(Encode, Decode)]
struct LocalStr<'a>(&'a str);

#[derive(Encode, Decode)]
struct LocalMeters(f32);

impl<'a> From<&'a RemoteStr> for LocalStr<'a> {
fn from(v: &'a RemoteStr) -> Self {
LocalStr(v.0.as_str())
}
}
impl From<LocalStr<'_>> for RemoteStr {
fn from(v: LocalStr<'_>) -> Self {
RemoteStr(String::from(v.0))
}
}
impl From<&Meters> for LocalMeters {
fn from(v: &Meters) -> Self {
LocalMeters(v.0)
}
}
impl From<LocalMeters> for Meters {
fn from(v: LocalMeters) -> Self {
Meters(v.0)
}
}

#[derive(Encode, Decode, Debug, PartialEq)]
struct User {
// Borrowing proxy via the `with` shorthand.
#[bitcode(with = "LocalStr<'a>")]
first_name: RemoteStr,
// Same, but with the encode/decode sides spelled separately.
#[bitcode(encode_with = "LocalStr<'a>", decode_with = "LocalStr<'a>")]
last_name: RemoteStr,
// Non-borrowing proxy (no lifetime, so no transmute on encode).
#[bitcode(with = "LocalMeters")]
height: Meters,
}

#[derive(Encode, Decode, Debug, PartialEq)]
enum Shape {
Point,
Line(#[bitcode(with = "LocalMeters")] Meters),
Named {
#[bitcode(with = "LocalStr<'a>")]
name: RemoteStr,
},
}

#[test]
fn test_with_struct() {
let user = User {
first_name: RemoteStr("Ada".into()),
last_name: RemoteStr("Lovelace".into()),
height: Meters(1.7),
};
assert_eq!(decode::<User>(&encode(&user)).unwrap(), user);
}

#[test]
fn test_with_enum() {
for shape in [
Shape::Point,
Shape::Line(Meters(42.0)),
Shape::Named {
name: RemoteStr("triangle".into()),
},
] {
assert_eq!(decode::<Shape>(&encode(&shape)).unwrap(), shape);
}
}

/// Encoding via a `with` proxy must be wire-identical to encoding the proxy type directly.
#[test]
fn test_with_wire_compatible() {
#[derive(Encode)]
struct Direct<'a> {
first_name: LocalStr<'a>,
last_name: LocalStr<'a>,
height: LocalMeters,
}

let user = User {
first_name: RemoteStr("Ada".into()),
last_name: RemoteStr("Lovelace".into()),
height: Meters(1.7),
};
let direct = Direct {
first_name: LocalStr("Ada"),
last_name: LocalStr("Lovelace"),
height: LocalMeters(1.7),
};
assert_eq!(encode(&user), encode(&direct));
}

// A generic "remote" type that intentionally doesn't implement Encode/Decode.
#[derive(Debug, PartialEq)]
struct RemoteCell<A>(A);

#[derive(Encode, Decode)]
struct LocalCell<A>(A);

impl<A: Clone> From<&RemoteCell<A>> for LocalCell<A> {
fn from(v: &RemoteCell<A>) -> Self {
LocalCell(v.0.clone())
}
}
impl<A> From<LocalCell<A>> for RemoteCell<A> {
fn from(v: LocalCell<A>) -> Self {
RemoteCell(v.0)
}
}

#[derive(Encode, Decode, Debug, PartialEq)]
struct GenericContainer<A: Clone> {
// The proxy `LocalCell<A>` uses the struct's generic parameter `A`.
#[bitcode(with = "LocalCell<A>")]
value: RemoteCell<A>,
plain: u8,
}

#[test]
fn test_with_generic() {
let gc = GenericContainer {
value: RemoteCell(String::from("generic")),
plain: 9,
};
assert_eq!(
decode::<GenericContainer<String>>(&encode(&gc)).unwrap(),
gc
);
}
}
Loading