diff --git a/bitcode_derive/src/attribute.rs b/bitcode_derive/src/attribute.rs index fb10acd..5f85bc6 100644 --- a/bitcode_derive/src/attribute.rs +++ b/bitcode_derive/src/attribute.rs @@ -10,29 +10,40 @@ enum BitcodeAttr { BoundType(Type), CrateName(Path), Skip, + /// `#[bitcode(with = "LocalType")]`, shorthand for both `encode_with` and `decode_with`. + With(Type), + /// `#[bitcode(encode_with = "LocalType")]`. + EncodeWith(Type), + /// `#[bitcode(decode_with = "LocalType")]`. + DecodeWith(Type), +} + +/// Parses a `#[bitcode(name = "Type")]` string literal value into a [`Type`]. +fn parse_type_attr(nested: &Meta) -> Result { + match nested { + Meta::NameValue(name_value) => { + let expr = &name_value.value; + let str_lit = match expr { + Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }) => v, + _ => return err(&expr, "expected string e.g. \"LocalType\""), + }; + let value = TokenStream::from_str(&str_lit.value()).unwrap(); + parse2(value).map_err(|e| error(str_lit, &format!("{e}"))) + } + _ => err(&nested, "expected name value"), + } } impl BitcodeAttr { fn new(nested: &Meta) -> Result { let path = path_ident_string(nested.path(), &nested)?; match path.as_str() { - "bound_type" => match nested { - Meta::NameValue(name_value) => { - let expr = &name_value.value; - let str_lit = match expr { - Expr::Lit(ExprLit { - lit: Lit::Str(v), .. - }) => v, - _ => return err(&expr, "expected string e.g. \"T\""), - }; - - let value = TokenStream::from_str(&str_lit.value()).unwrap(); - Ok(Self::BoundType( - parse2(value).map_err(|e| error(str_lit, &format!("{e}")))?, - )) - } - _ => err(&nested, "expected name value"), - }, + "bound_type" => Ok(Self::BoundType(parse_type_attr(nested)?)), + "with" => Ok(Self::With(parse_type_attr(nested)?)), + "encode_with" => Ok(Self::EncodeWith(parse_type_attr(nested)?)), + "decode_with" => Ok(Self::DecodeWith(parse_type_attr(nested)?)), "crate" => match nested { Meta::NameValue(name_value) => { let expr = &name_value.value; @@ -97,6 +108,28 @@ impl BitcodeAttr { err(nested, "can only apply to fields") } } + Self::With(ty) => { + if let BitcodeAnyAttrs::Field(field) = attrs { + set_if_not_duplicate(&mut field.encode_with, Some(ty.clone()), nested)?; + set_if_not_duplicate(&mut field.decode_with, Some(ty), nested) + } else { + err(nested, "can only apply to fields") + } + } + Self::EncodeWith(ty) => { + if let BitcodeAnyAttrs::Field(field) = attrs { + set_if_not_duplicate(&mut field.encode_with, Some(ty), nested) + } else { + err(nested, "can only apply to fields") + } + } + Self::DecodeWith(ty) => { + if let BitcodeAnyAttrs::Field(field) = attrs { + set_if_not_duplicate(&mut field.decode_with, Some(ty), nested) + } else { + err(nested, "can only apply to fields") + } + } } } } @@ -159,6 +192,10 @@ pub struct BitcodeFieldAttrs<'a> { parent: BitcodeDeriveOrVariantAttrs<'a>, pub bound_type: Option, pub skip: bool, + /// Encode this field as a different (local) type via `From`/`Into`. + pub encode_with: Option, + /// Decode this field as a different (local) type via `From`/`Into`. + pub decode_with: Option, } impl<'a> std::ops::Deref for BitcodeFieldAttrs<'a> { type Target = BitcodeDeriveOrVariantAttrs<'a>; @@ -172,6 +209,8 @@ impl<'a> BitcodeFieldAttrs<'a> { parent, bound_type: Default::default(), skip: Default::default(), + encode_with: Default::default(), + decode_with: Default::default(), }; BitcodeAnyAttrs::Field(&mut ret).parse_inner(attrs)?; Ok(ret) diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index fedf2c8..03f874e 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -43,7 +43,9 @@ impl crate::shared::Item for Item { ) -> TokenStream { match self { Self::Type => { - let mut de_type = replace_lifetimes(field_type, DE_LIFETIME).to_token_stream(); + // `#[bitcode(decode_with = "Local")]` stores the decoder for `Local`, not the field. + let base_type = attrs.decode_with.as_ref().unwrap_or(field_type); + let mut de_type = replace_lifetimes(base_type, DE_LIFETIME).to_token_stream(); if attrs.skip { de_type = quote! { ::core::marker::PhantomData<#de_type> }; } @@ -61,10 +63,15 @@ impl crate::shared::Item for Item { }, // Only used by enum variants. Self::Decode => { + let de_type = replace_lifetimes(field_type, DE_LIFETIME); let value = if attrs.skip { quote! { Default::default() } + } else if attrs.decode_with.is_some() { + quote! { + ::core::convert::Into::<#de_type>::into(self.#global_field_name.decode()) + } } else { quote! { self.#global_field_name.decode() @@ -84,6 +91,11 @@ impl crate::shared::Item for Item { quote! {{ (#target).write(Default::default()); }} + } else if attrs.decode_with.is_some() { + quote! {{ + let __local = self.#global_field_name.decode(); + (#target).write(::core::convert::Into::into(__local)); + }} } else { quote! { self.#global_field_name.decode_in_place(#target); @@ -261,6 +273,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Decode { Some(parse_quote!(Default)) } + fn with_type(&self, field_attrs: &BitcodeFieldAttrs) -> Option { + field_attrs.decode_with.clone() + } + fn derive_impl( &self, attrs: &BitcodeDeriveAttrs, diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 26e0c58..228350a 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -35,7 +35,9 @@ impl crate::shared::Item for Item { ) -> TokenStream { match self { Self::Type => { - let mut static_type = replace_lifetimes(field_type, "static").to_token_stream(); + // `#[bitcode(encode_with = "Local")]` stores the encoder for `Local`, not the field. + let base_type = attrs.encode_with.as_ref().unwrap_or(field_type); + let mut static_type = replace_lifetimes(base_type, "static").to_token_stream(); if attrs.skip { static_type = quote! { ::core::marker::PhantomData<#static_type> }; } @@ -48,6 +50,43 @@ impl crate::shared::Item for Item { #global_field_name: Default::default(), }, Self::Encode | Self::EncodeVectored => { + // `#[bitcode(encode_with = "Local")]`: convert `&Field` to `Local`, then encode `Local`. + // The conversion produces an owned `Local`, so we can't yield `&Local` into the + // vectored fast-path; instead the vectored case stamps out a loop calling `encode`. + if let Some(local) = attrs.encode_with.as_ref().filter(|_| !attrs.skip) { + let local_static = replace_lifetimes(local, "static"); + let needs_transmute = &local_static != local; + let encode_one = |field_ref: TokenStream| { + if needs_transmute { + // HACK: Since encoders don't have lifetimes we can't reference as Encode>::Encoder since 'a + // does not exist. Instead we replace this with as Encode>::Encoder and transmute it to + // T<'a>. No encoder actually encodes T<'static> any differently from T<'a> so this is sound. + let local_underscore = replace_lifetimes(local, "_"); + quote! {{ + let __local: #local_underscore = ::core::convert::From::from(#field_ref); + self.#global_field_name.encode(unsafe { + ::core::mem::transmute::<&#local_underscore, &#local_static>(&__local) + }); + }} + } else { + quote! {{ + let __local: #local_static = ::core::convert::From::from(#field_ref); + self.#global_field_name.encode(&__local); + }} + } + }; + return if matches!(self, Self::EncodeVectored) { + let encode_one = encode_one(quote! { &__me.#real_field_name }); + quote! { + for __me in i.clone() { + #encode_one + } + } + } else { + encode_one(quote! { #field_name }) + }; + } + let static_type = replace_lifetimes(field_type, "static"); let value = if attrs.skip { quote! { @@ -246,6 +285,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Encode { None } + fn with_type(&self, field_attrs: &BitcodeFieldAttrs) -> Option { + field_attrs.encode_with.clone() + } + fn derive_impl( &self, attrs: &BitcodeDeriveAttrs, diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 5444e94..de32f9c 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -130,6 +130,10 @@ pub trait Derive { /// Bound for skipped fields, e.g. `Default` fn skip_bound(&self) -> Option; + /// For `#[bitcode(with)]` fields, the local type whose generic parameters should be bounded + /// instead of the field's own type (which need not implement `Encode`/`Decode`). + fn with_type(&self, field_attrs: &BitcodeFieldAttrs) -> Option; + /// Generates the derive implementation. fn derive_impl( &self, @@ -156,7 +160,12 @@ pub trait Derive { Some(self.bound(&attrs)) }; if let Some(bound) = bound { - bounds.add_bound_type(field.clone(), &field_attrs, bound); + let mut field = field.clone(); + if let Some(with_type) = self.with_type(&field_attrs) { + // Bound the local type instead of the field's own (remote) type. + field.ty = with_type; + } + bounds.add_bound_type(field, &field_attrs, bound); } Ok(field_attrs) }) diff --git a/src/derive/mod.rs b/src/derive/mod.rs index e0eed37..766eb80 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -426,3 +426,151 @@ mod tests { } } } + +#[cfg(test)] +mod with_tests { + use crate::{decode, encode, Decode, Encode}; + use alloc::string::String; + + // "Remote" types that intentionally don't implement Encode/Decode themselves. + #[derive(Debug, PartialEq, Clone)] + struct RemoteStr(String); + + #[derive(Debug, PartialEq, Clone, Copy)] + struct Meters(f32); + + // Local proxy types used via `#[bitcode(with = ...)]`. + #[derive(Encode, Decode)] + struct LocalStr<'a>(&'a str); + + #[derive(Encode, Decode)] + struct LocalMeters(f32); + + impl<'a> From<&'a RemoteStr> for LocalStr<'a> { + fn from(v: &'a RemoteStr) -> Self { + LocalStr(v.0.as_str()) + } + } + impl From> for RemoteStr { + fn from(v: LocalStr<'_>) -> Self { + RemoteStr(String::from(v.0)) + } + } + impl From<&Meters> for LocalMeters { + fn from(v: &Meters) -> Self { + LocalMeters(v.0) + } + } + impl From for Meters { + fn from(v: LocalMeters) -> Self { + Meters(v.0) + } + } + + #[derive(Encode, Decode, Debug, PartialEq)] + struct User { + // Borrowing proxy via the `with` shorthand. + #[bitcode(with = "LocalStr<'a>")] + first_name: RemoteStr, + // Same, but with the encode/decode sides spelled separately. + #[bitcode(encode_with = "LocalStr<'a>", decode_with = "LocalStr<'a>")] + last_name: RemoteStr, + // Non-borrowing proxy (no lifetime, so no transmute on encode). + #[bitcode(with = "LocalMeters")] + height: Meters, + } + + #[derive(Encode, Decode, Debug, PartialEq)] + enum Shape { + Point, + Line(#[bitcode(with = "LocalMeters")] Meters), + Named { + #[bitcode(with = "LocalStr<'a>")] + name: RemoteStr, + }, + } + + #[test] + fn test_with_struct() { + let user = User { + first_name: RemoteStr("Ada".into()), + last_name: RemoteStr("Lovelace".into()), + height: Meters(1.7), + }; + assert_eq!(decode::(&encode(&user)).unwrap(), user); + } + + #[test] + fn test_with_enum() { + for shape in [ + Shape::Point, + Shape::Line(Meters(42.0)), + Shape::Named { + name: RemoteStr("triangle".into()), + }, + ] { + assert_eq!(decode::(&encode(&shape)).unwrap(), shape); + } + } + + /// Encoding via a `with` proxy must be wire-identical to encoding the proxy type directly. + #[test] + fn test_with_wire_compatible() { + #[derive(Encode)] + struct Direct<'a> { + first_name: LocalStr<'a>, + last_name: LocalStr<'a>, + height: LocalMeters, + } + + let user = User { + first_name: RemoteStr("Ada".into()), + last_name: RemoteStr("Lovelace".into()), + height: Meters(1.7), + }; + let direct = Direct { + first_name: LocalStr("Ada"), + last_name: LocalStr("Lovelace"), + height: LocalMeters(1.7), + }; + assert_eq!(encode(&user), encode(&direct)); + } + + // A generic "remote" type that intentionally doesn't implement Encode/Decode. + #[derive(Debug, PartialEq)] + struct RemoteCell(A); + + #[derive(Encode, Decode)] + struct LocalCell(A); + + impl From<&RemoteCell> for LocalCell { + fn from(v: &RemoteCell) -> Self { + LocalCell(v.0.clone()) + } + } + impl From> for RemoteCell { + fn from(v: LocalCell) -> Self { + RemoteCell(v.0) + } + } + + #[derive(Encode, Decode, Debug, PartialEq)] + struct GenericContainer { + // The proxy `LocalCell` uses the struct's generic parameter `A`. + #[bitcode(with = "LocalCell")] + value: RemoteCell, + plain: u8, + } + + #[test] + fn test_with_generic() { + let gc = GenericContainer { + value: RemoteCell(String::from("generic")), + plain: 9, + }; + assert_eq!( + decode::>(&encode(&gc)).unwrap(), + gc + ); + } +}