diff --git a/conformance/Cargo.toml b/conformance/Cargo.toml index 9184055e6..b2b933b76 100644 --- a/conformance/Cargo.toml +++ b/conformance/Cargo.toml @@ -8,6 +8,7 @@ authors.workspace = true [dependencies] bytes = "1" env_logger = { version = "0.11", default-features = false } -prost = { path = "../prost" } +prost = { path = "../prost", features = ["serde", "serde-json"] } +prost-types = { path = "../prost-types", features = ["serde", "any-v2"] } protobuf = { path = "../protobuf" } tests = { path = "../tests" } diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index b41904761..a912150bf 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,3 +1,6 @@ # TODO(tokio-rs/prost#2): prost doesn't preserve unknown fields. Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput + +# Unsupported right now +Recommended.Proto2.JsonInput.FieldNameExtension.Validator diff --git a/conformance/src/main.rs b/conformance/src/main.rs index dd3165cb6..3af12352b 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -4,14 +4,29 @@ use bytes::{Buf, BufMut}; use prost::Message; use protobuf::conformance::{ - conformance_request, conformance_response, ConformanceRequest, ConformanceResponse, WireFormat, + conformance_request, conformance_response, ConformanceRequest, ConformanceResponse, + TestCategory, WireFormat, }; use protobuf::test_messages::proto2::TestAllTypesProto2; use protobuf::test_messages::proto3::TestAllTypesProto3; -use tests::{roundtrip, RoundtripResult}; +use tests::{roundtrip, RoundtripInput, RoundtripOutputType, RoundtripResult}; fn main() -> io::Result<()> { env_logger::init(); + + let mut registry = prost_types::any_v2::TypeRegistry::new_with_well_known_types(); + registry.insert_msg_type_for_type_url::( + "type.googleapis.com/protobuf_test_messages.proto2.TestAllTypesProto2", + ); + registry.insert_msg_type_for_type_url::( + "type.googleapis.com/protobuf_test_messages.proto3.TestAllTypesProto3", + ); + + let type_resolver = registry.into_type_resolver(); + prost_types::any_v2::with_type_resolver(Some(type_resolver), entrypoint) +} + +fn entrypoint() -> io::Result<()> { let mut bytes = vec![0; 4]; loop { @@ -49,17 +64,12 @@ fn main() -> io::Result<()> { } fn handle_request(request: ConformanceRequest) -> conformance_response::Result { - match request.requested_output_format() { + let output_ty = match request.requested_output_format() { WireFormat::Unspecified => { return conformance_response::Result::ParseError( "output format unspecified".to_string(), ); } - WireFormat::Json => { - return conformance_response::Result::Skipped( - "JSON output is not supported".to_string(), - ); - } WireFormat::Jspb => { return conformance_response::Result::Skipped( "JSPB output is not supported".to_string(), @@ -70,16 +80,13 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { "TEXT_FORMAT output is not supported".to_string(), ); } - WireFormat::Protobuf => (), + WireFormat::Protobuf => RoundtripOutputType::Protobuf, + WireFormat::Json => RoundtripOutputType::Json, }; - let buf = match request.payload { + let input = match &request.payload { None => return conformance_response::Result::ParseError("no payload".to_string()), - Some(conformance_request::Payload::JsonPayload(_)) => { - return conformance_response::Result::Skipped( - "JSON input is not supported".to_string(), - ); - } + Some(conformance_request::Payload::JspbPayload(_)) => { return conformance_response::Result::Skipped( "JSON input is not supported".to_string(), @@ -90,12 +97,20 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { "JSON input is not supported".to_string(), ); } - Some(conformance_request::Payload::ProtobufPayload(buf)) => buf, + Some(conformance_request::Payload::ProtobufPayload(buf)) => RoundtripInput::Protobuf(buf), + Some(conformance_request::Payload::JsonPayload(buf)) => RoundtripInput::Json(buf), }; - let roundtrip = match request.message_type.as_str() { - "protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip::(&buf), - "protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip::(&buf), + let ignore_unknown_fields = + request.test_category() == TestCategory::JsonIgnoreUnknownParsingTest; + + let roundtrip = match &*request.message_type { + "protobuf_test_messages.proto2.TestAllTypesProto2" => { + roundtrip::(input, output_ty, ignore_unknown_fields) + } + "protobuf_test_messages.proto3.TestAllTypesProto3" => { + roundtrip::(input, output_ty, ignore_unknown_fields) + } _ => { return conformance_response::Result::ParseError(format!( "unknown message type: {}", @@ -105,7 +120,11 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { }; match roundtrip { - RoundtripResult::Ok(buf) => conformance_response::Result::ProtobufPayload(buf), + RoundtripResult::Protobuf(buf) => conformance_response::Result::ProtobufPayload(buf), + RoundtripResult::Json(buf) => conformance_response::Result::JsonPayload(buf), + RoundtripResult::EncodeError(error) => { + conformance_response::Result::SerializeError(error.to_string()) + } RoundtripResult::DecodeError(error) => { conformance_response::Result::ParseError(error.to_string()) } diff --git a/fuzz/afl/proto3/src/main.rs b/fuzz/afl/proto3/src/main.rs index e38c843f6..33101dea7 100644 --- a/fuzz/afl/proto3/src/main.rs +++ b/fuzz/afl/proto3/src/main.rs @@ -1,10 +1,10 @@ use afl::fuzz; use protobuf::test_messages::proto3::TestAllTypesProto3; -use tests::roundtrip; +use tests::roundtrip_proto; fn main() { fuzz!(|data: &[u8]| { - let _ = roundtrip::(data).unwrap_error(); + let _ = roundtrip_proto::(data).unwrap_error(); }); } diff --git a/fuzz/afl/proto3/src/reproduce.rs b/fuzz/afl/proto3/src/reproduce.rs index ab0c0f9c5..90c4d02c8 100644 --- a/fuzz/afl/proto3/src/reproduce.rs +++ b/fuzz/afl/proto3/src/reproduce.rs @@ -1,5 +1,5 @@ use protobuf::test_messages::proto3::TestAllTypesProto3; -use tests::roundtrip; +use tests::roundtrip_proto; fn main() { let args: Vec = std::env::args().collect(); @@ -9,5 +9,5 @@ fn main() { } let data = std::fs::read(&args[1]).expect(&format!("Could not open file {}", args[1])); - let _ = roundtrip::(&data).unwrap_error(); + let _ = roundtrip_proto::(&data).unwrap_error(); } diff --git a/fuzz/fuzzers/proto2.rs b/fuzz/fuzzers/proto2.rs index 5b7cb51ef..9628a930c 100644 --- a/fuzz/fuzzers/proto2.rs +++ b/fuzz/fuzzers/proto2.rs @@ -2,8 +2,8 @@ use libfuzzer_sys::fuzz_target; use protobuf::test_messages::proto2::TestAllTypesProto2; -use tests::roundtrip; +use tests::roundtrip_proto; fuzz_target!(|data: &[u8]| { - let _ = roundtrip::(data).unwrap_error(); + let _ = roundtrip_proto::(data).unwrap_error(); }); diff --git a/fuzz/fuzzers/proto3.rs b/fuzz/fuzzers/proto3.rs index 309701636..dd33763dd 100644 --- a/fuzz/fuzzers/proto3.rs +++ b/fuzz/fuzzers/proto3.rs @@ -2,8 +2,8 @@ use libfuzzer_sys::fuzz_target; use protobuf::test_messages::proto3::TestAllTypesProto3; -use tests::roundtrip; +use tests::roundtrip_proto; fuzz_target!(|data: &[u8]| { - let _ = roundtrip::(data).unwrap_error(); + let _ = roundtrip_proto::(data).unwrap_error(); }); diff --git a/prost-build/Cargo.toml b/prost-build/Cargo.toml index 0d79b07fa..96cd55726 100644 --- a/prost-build/Cargo.toml +++ b/prost-build/Cargo.toml @@ -25,6 +25,7 @@ prost-types = { version = "0.13.3", path = "../prost-types", default-features = tempfile = "3" once_cell = "1.17.1" regex = { version = "1.8.1", default-features = false, features = ["std", "unicode-bool"] } +indexmap = "2.1.0" # feature: format prettyplease = { version = "0.2", optional = true } diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index f8d341445..afd1f6c63 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -1,8 +1,9 @@ use std::ascii; use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::iter; +use indexmap::IndexMap; use itertools::{Either, Itertools}; use log::debug; use multimap::MultiMap; @@ -16,9 +17,12 @@ use prost_types::{ use crate::ast::{Comments, Method, Service}; use crate::extern_paths::ExternPaths; -use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel}; use crate::message_graph::MessageGraph; use crate::Config; +use crate::{ + ident::{strip_enum_prefix, to_snake, to_upper_camel}, + json::{json_attr_for_enum_variant, json_attr_for_field, json_attr_for_one_of_variant}, +}; mod c_escaping; use c_escaping::unescape_c_escape_string; @@ -49,9 +53,9 @@ fn prost_path(config: &Config) -> &str { config.prost_path.as_deref().unwrap_or("::prost") } -struct Field { - descriptor: FieldDescriptorProto, - path_index: i32, +pub struct Field { + pub descriptor: FieldDescriptorProto, + pub path_index: i32, } impl Field { @@ -62,9 +66,13 @@ impl Field { } } - fn rust_name(&self) -> String { + pub fn rust_name(&self) -> String { to_snake(self.descriptor.name()) } + + pub fn rust_variant_name(&self) -> String { + to_upper_camel(self.descriptor.name()) + } } struct OneofField { @@ -238,6 +246,7 @@ impl CodeGenerator<'_> { prost_path(self.config) )); self.append_skip_debug(&fq_message_name); + self.append_serde(); self.push_indent(); self.buf.push_str("pub struct "); self.buf.push_str(&to_upper_camel(&message_name)); @@ -378,6 +387,14 @@ impl CodeGenerator<'_> { } } + fn append_serde(&mut self) { + if self.config.enable_serde { + push_indent(self.buf, self.depth); + self.buf.push_str("#[prost(serde)]"); + self.buf.push('\n'); + } + } + fn append_enum_attributes(&mut self, fq_message_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); for attribute in self.config.enum_attributes.get(fq_message_name) { @@ -462,6 +479,7 @@ impl CodeGenerator<'_> { if boxed { self.buf.push_str(", boxed"); } + self.buf.push_str(", tag=\""); self.buf.push_str(&field.descriptor.number().to_string()); @@ -495,8 +513,15 @@ impl CodeGenerator<'_> { self.buf.push_str(&default.escape_default().to_string()); } } + self.buf.push('"'); - self.buf.push_str("\")]\n"); + if self.config.enable_serde { + if let Some(json_attr) = json_attr_for_field(field) { + self.buf.push_str(&format!(", {json_attr}")); + } + } + + self.buf.push_str(")]\n"); self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str("pub "); @@ -554,12 +579,23 @@ impl CodeGenerator<'_> { let key_tag = self.field_type_tag(key); let value_tag = self.map_value_type_tag(value); + let json_attr = if self.config.enable_serde { + if let Some(json_attr) = json_attr_for_field(field) { + format!(", {json_attr}") + } else { + Default::default() + } + } else { + Default::default() + }; + self.buf.push_str(&format!( - "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", + "#[prost({}=\"{}, {}\", tag=\"{}\"{})]\n", map_type.annotation(), key_tag, value_tag, - field.descriptor.number() + field.descriptor.number(), + json_attr )); self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); @@ -619,12 +655,14 @@ impl CodeGenerator<'_> { self.message_graph .can_field_derive_copy(fq_message_name, &field.descriptor) }); + self.buf.push_str(&format!( "#[derive(Clone, {}PartialEq, {}::Oneof)]\n", if can_oneof_derive_copy { "Copy, " } else { "" }, prost_path(self.config) )); self.append_skip_debug(fq_message_name); + self.append_serde(); self.push_indent(); self.buf.push_str("pub enum "); self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); @@ -633,21 +671,36 @@ impl CodeGenerator<'_> { self.path.push(2); self.depth += 1; for field in &oneof.fields { + let proto_field_name = field.descriptor.name(); + let rust_variant_name = &field.rust_variant_name(); + self.path.push(field.path_index); - self.append_doc(fq_message_name, Some(field.descriptor.name())); + self.append_doc(fq_message_name, Some(proto_field_name)); self.path.pop(); - self.push_indent(); let ty_tag = self.field_type_tag(&field.descriptor); + let ty = self.resolve_type(&field.descriptor, fq_message_name); + + let json_attr = if self.config.enable_serde { + if let Some(json_attr) = json_attr_for_one_of_variant(field) { + format!(", {json_attr}") + } else { + Default::default() + } + } else { + Default::default() + }; + + self.push_indent(); self.buf.push_str(&format!( - "#[prost({}, tag=\"{}\")]\n", + "#[prost({}, tag=\"{}\"{})]\n", ty_tag, - field.descriptor.number() + field.descriptor.number(), + json_attr )); - self.append_field_attributes(&oneof_name, field.descriptor.name()); + self.append_field_attributes(&oneof_name, proto_field_name); self.push_indent(); - let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = self.boxed( &field.descriptor, @@ -657,23 +710,17 @@ impl CodeGenerator<'_> { debug!( " oneof: {:?}, type: {:?}, boxed: {}", - field.descriptor.name(), - ty, - boxed + proto_field_name, ty, boxed ); if boxed { self.buf.push_str(&format!( "{}(::prost::alloc::boxed::Box<{}>),\n", - to_upper_camel(field.descriptor.name()), - ty + rust_variant_name, ty )); } else { - self.buf.push_str(&format!( - "{}({}),\n", - to_upper_camel(field.descriptor.name()), - ty - )); + self.buf + .push_str(&format!("{}({}),\n", rust_variant_name, ty)); } } self.depth -= 1; @@ -741,6 +788,7 @@ impl CodeGenerator<'_> { )); self.push_indent(); self.buf.push_str("#[repr(i32)]\n"); + self.append_serde(); self.push_indent(); self.buf.push_str("pub enum "); self.buf.push_str(&enum_name); @@ -755,6 +803,14 @@ impl CodeGenerator<'_> { self.path.push(variant.path_idx as i32); self.append_doc(&fq_proto_enum_name, Some(variant.proto_name)); + + if self.config.enable_serde { + if let Some(json_attr) = json_attr_for_enum_variant(&enum_name, variant) { + self.push_indent(); + self.buf.push_str(&format!("#[prost({json_attr})]\n")); + } + }; + self.append_field_attributes(&fq_proto_enum_name, variant.proto_name); self.push_indent(); self.buf.push_str(&variant.generated_variant_name); @@ -973,8 +1029,8 @@ impl CodeGenerator<'_> { // protoc should always give fully qualified identifiers. assert_eq!(".", &pb_ident[..1]); - if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) { - return proto_ident; + if let Some(resolved) = self.extern_paths.resolve_ident(pb_ident) { + return resolved.rust_path; } let mut local_path = self @@ -1143,11 +1199,12 @@ fn can_pack(field: &FieldDescriptorProto) -> bool { ) } -struct EnumVariantMapping<'a> { - path_idx: usize, - proto_name: &'a str, - proto_number: i32, - generated_variant_name: String, +pub struct EnumVariantMapping<'a> { + pub path_idx: usize, + pub proto_name: &'a str, + pub proto_aliases: Vec<&'a str>, + pub proto_number: i32, + pub generated_variant_name: String, } fn build_enum_value_mappings<'a>( @@ -1155,35 +1212,43 @@ fn build_enum_value_mappings<'a>( do_strip_enum_prefix: bool, enum_values: &'a [EnumValueDescriptorProto], ) -> Vec> { - let mut numbers = HashSet::new(); let mut generated_names = HashMap::new(); - let mut mappings = Vec::new(); + // Use an insertion-order preserving map here because the enum ordering must be preserved. + let mut mappings = IndexMap::::new(); for (idx, value) in enum_values.iter().enumerate() { + let enum_name = value.name(); + let enum_value = value.number(); + // Skip duplicate enum values. Protobuf allows this when the // 'allow_alias' option is set. - if !numbers.insert(value.number()) { + if let Some(mapping) = mappings.get_mut(&enum_value) { + mapping.proto_aliases.push(enum_name); continue; } - let mut generated_variant_name = to_upper_camel(value.name()); + let mut generated_variant_name = to_upper_camel(enum_name); if do_strip_enum_prefix { generated_variant_name = strip_enum_prefix(generated_enum_name, &generated_variant_name); } - if let Some(old_v) = generated_names.insert(generated_variant_name.to_owned(), value.name()) - { + if let Some(old_v) = generated_names.insert(generated_variant_name.to_owned(), enum_name) { panic!("Generated enum variant names overlap: `{}` variant name to be used both by `{}` and `{}` ProtoBuf enum values", - generated_variant_name, old_v, value.name()); + generated_variant_name, old_v, enum_name); } - mappings.push(EnumVariantMapping { - path_idx: idx, - proto_name: value.name(), - proto_number: value.number(), - generated_variant_name, - }) + mappings.insert( + enum_value, + EnumVariantMapping { + path_idx: idx, + proto_name: enum_name, + proto_aliases: vec![], + proto_number: enum_value, + generated_variant_name, + }, + ); } - mappings + + mappings.into_values().collect() } diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 896726b16..0789fc90f 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -52,6 +52,7 @@ pub struct Config { pub(crate) prost_path: Option, #[cfg(feature = "format")] pub(crate) fmt: bool, + pub(crate) enable_serde: bool, } impl Config { @@ -1124,6 +1125,14 @@ impl Config { *buf = with_generated; } } + + /// Configures the code generator to also emit serde compatible serialization impls. + /// + /// Defaults to `false`. + pub fn enable_serde(&mut self) -> &mut Self { + self.enable_serde = true; + self + } } /// Write a slice as the entire contents of a file. @@ -1174,6 +1183,7 @@ impl default::Default for Config { prost_path: None, #[cfg(feature = "format")] fmt: true, + enable_serde: false, } } } diff --git a/prost-build/src/extern_paths.rs b/prost-build/src/extern_paths.rs index 8f6bee784..e6af0bb74 100644 --- a/prost-build/src/extern_paths.rs +++ b/prost-build/src/extern_paths.rs @@ -17,9 +17,22 @@ fn validate_proto_path(path: &str) -> Result<(), String> { Ok(()) } +#[derive(Debug)] +struct ExternPathEntry { + rust_path: String, + is_well_known: bool, +} + +#[derive(Debug)] +pub struct ResolvedPath { + pub rust_path: String, + #[allow(dead_code)] + pub is_well_known: bool, +} + #[derive(Debug)] pub struct ExternPaths { - extern_paths: HashMap, + extern_paths: HashMap, } impl ExternPaths { @@ -29,33 +42,39 @@ impl ExternPaths { }; for (proto_path, rust_path) in paths { - extern_paths.insert(proto_path.clone(), rust_path.clone())?; + extern_paths.insert(proto_path.clone(), rust_path.clone(), false)?; } if prost_types { - extern_paths.insert(".google.protobuf".to_string(), "::prost_types".to_string())?; - extern_paths.insert(".google.protobuf.BoolValue".to_string(), "bool".to_string())?; - extern_paths.insert( + extern_paths + .insert_well_known(".google.protobuf".to_string(), "::prost_types".to_string())?; + extern_paths + .insert_well_known(".google.protobuf.BoolValue".to_string(), "bool".to_string())?; + extern_paths.insert_well_known( ".google.protobuf.BytesValue".to_string(), "::prost::alloc::vec::Vec".to_string(), )?; - extern_paths.insert( + extern_paths.insert_well_known( ".google.protobuf.DoubleValue".to_string(), "f64".to_string(), )?; - extern_paths.insert(".google.protobuf.Empty".to_string(), "()".to_string())?; - extern_paths.insert(".google.protobuf.FloatValue".to_string(), "f32".to_string())?; - extern_paths.insert(".google.protobuf.Int32Value".to_string(), "i32".to_string())?; - extern_paths.insert(".google.protobuf.Int64Value".to_string(), "i64".to_string())?; - extern_paths.insert( + extern_paths + .insert_well_known(".google.protobuf.Empty".to_string(), "()".to_string())?; + extern_paths + .insert_well_known(".google.protobuf.FloatValue".to_string(), "f32".to_string())?; + extern_paths + .insert_well_known(".google.protobuf.Int32Value".to_string(), "i32".to_string())?; + extern_paths + .insert_well_known(".google.protobuf.Int64Value".to_string(), "i64".to_string())?; + extern_paths.insert_well_known( ".google.protobuf.StringValue".to_string(), "::prost::alloc::string::String".to_string(), )?; - extern_paths.insert( + extern_paths.insert_well_known( ".google.protobuf.UInt32Value".to_string(), "u32".to_string(), )?; - extern_paths.insert( + extern_paths.insert_well_known( ".google.protobuf.UInt64Value".to_string(), "u64".to_string(), )?; @@ -64,7 +83,16 @@ impl ExternPaths { Ok(extern_paths) } - fn insert(&mut self, proto_path: String, rust_path: String) -> Result<(), String> { + fn insert_well_known(&mut self, proto_path: String, rust_path: String) -> Result<(), String> { + self.insert(proto_path, rust_path, true) + } + + fn insert( + &mut self, + proto_path: String, + rust_path: String, + is_well_known: bool, + ) -> Result<(), String> { validate_proto_path(&proto_path)?; match self.extern_paths.entry(proto_path) { hash_map::Entry::Occupied(occupied) => { @@ -73,42 +101,56 @@ impl ExternPaths { occupied.key() )); } - hash_map::Entry::Vacant(vacant) => vacant.insert(rust_path), + hash_map::Entry::Vacant(vacant) => vacant.insert(ExternPathEntry { + rust_path, + is_well_known, + }), }; Ok(()) } - pub fn resolve_ident(&self, pb_ident: &str) -> Option { + pub fn resolve_ident(&self, pb_ident: &str) -> Option { // protoc should always give fully qualified identifiers. assert_eq!(".", &pb_ident[..1]); - if let Some(rust_path) = self.extern_paths.get(pb_ident) { - return Some(rust_path.clone()); + if let Some(ExternPathEntry { + rust_path, + is_well_known, + }) = self.extern_paths.get(pb_ident) + { + return Some(ResolvedPath { + rust_path: rust_path.clone(), + is_well_known: *is_well_known, + }); } // TODO(danburkert): there must be a more efficient way to do this, maybe a trie? for (idx, _) in pb_ident.rmatch_indices('.') { - if let Some(rust_path) = self.extern_paths.get(&pb_ident[..idx]) { + if let Some(entry) = self.extern_paths.get(&pb_ident[..idx]) { let mut segments = pb_ident[idx + 1..].split('.'); let ident_type = segments.next_back().map(to_upper_camel); - return Some( - rust_path - .split("::") - .chain(segments) - .enumerate() - .map(|(idx, segment)| { - if idx == 0 && segment == "crate" { - // If the first segment of the path is 'crate', then do not escape - // it into a raw identifier, since it's being used as the keyword. - segment.to_owned() - } else { - to_snake(segment) - } - }) - .chain(ident_type.into_iter()) - .join("::"), - ); + let rust_path = entry + .rust_path + .split("::") + .chain(segments) + .enumerate() + .map(|(idx, segment)| { + if idx == 0 && segment == "crate" { + // If the first segment of the path is 'crate', then do not escape + // it into a raw identifier, since it's being used as the keyword. + segment.to_owned() + } else { + to_snake(segment) + } + }) + .chain(ident_type.into_iter()) + .join("::"); + + return Some(ResolvedPath { + rust_path, + is_well_known: entry.is_well_known, + }); } } @@ -136,7 +178,10 @@ mod tests { .unwrap(); let case = |proto_ident: &str, resolved_ident: &str| { - assert_eq!(paths.resolve_ident(proto_ident).unwrap(), resolved_ident); + assert_eq!( + paths.resolve_ident(proto_ident).unwrap().rust_path, + resolved_ident + ); }; case(".foo", "::foo1"); @@ -160,7 +205,10 @@ mod tests { let paths = ExternPaths::new(&[], true).unwrap(); let case = |proto_ident: &str, resolved_ident: &str| { - assert_eq!(paths.resolve_ident(proto_ident).unwrap(), resolved_ident); + assert_eq!( + paths.resolve_ident(proto_ident).unwrap().rust_path, + resolved_ident + ); }; case(".google.protobuf.Value", "::prost_types::Value"); diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index bf1e8c517..6a97b426f 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -1,5 +1,6 @@ // This file is @generated by prost-build. #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Container { #[prost(oneof="container::Data", tags="1, 2")] pub data: ::core::option::Option, @@ -7,6 +8,7 @@ pub struct Container { /// Nested message and enum types in `Container`. pub mod container { #[derive(Clone, PartialEq, ::prost::Oneof)] + #[prost(serde)] pub enum Data { #[prost(message, tag="1")] Foo(::prost::alloc::boxed::Box), @@ -15,15 +17,18 @@ pub mod container { } } #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Foo { #[prost(string, tag="1")] pub foo: ::prost::alloc::string::String, } #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Qux { } diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index c130aad2e..bb986b7dd 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -1,5 +1,6 @@ // This file is @generated by prost-build. #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Container { #[prost(oneof = "container::Data", tags = "1, 2")] pub data: ::core::option::Option, @@ -7,6 +8,7 @@ pub struct Container { /// Nested message and enum types in `Container`. pub mod container { #[derive(Clone, PartialEq, ::prost::Oneof)] + #[prost(serde)] pub enum Data { #[prost(message, tag = "1")] Foo(::prost::alloc::boxed::Box), @@ -15,14 +17,17 @@ pub mod container { } } #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Foo { #[prost(string, tag = "1")] pub foo: ::prost::alloc::string::String, } #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Qux {} diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs index f39278358..9cbf7c927 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld.rs @@ -2,6 +2,7 @@ #[derive(derive_builder::Builder)] #[derive(custom_proto::Input)] #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Message { #[prost(string, tag="1")] pub say: ::prost::alloc::string::String, @@ -9,6 +10,7 @@ pub struct Message { #[derive(derive_builder::Builder)] #[derive(custom_proto::Output)] #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Response { #[prost(string, tag="1")] pub say: ::prost::alloc::string::String, @@ -16,9 +18,13 @@ pub struct Response { #[some_enum_attr(u8)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +#[prost(serde)] pub enum ServingStatus { + #[prost(json(proto_name = "UNKNOWN"))] Unknown = 0, + #[prost(json(proto_name = "SERVING"))] Serving = 1, + #[prost(json(proto_name = "NOT_SERVING"))] NotServing = 2, } impl ServingStatus { diff --git a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs index c75338e2b..b922a46ff 100644 --- a/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs +++ b/prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs @@ -2,6 +2,7 @@ #[derive(derive_builder::Builder)] #[derive(custom_proto::Input)] #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Message { #[prost(string, tag = "1")] pub say: ::prost::alloc::string::String, @@ -9,6 +10,7 @@ pub struct Message { #[derive(derive_builder::Builder)] #[derive(custom_proto::Output)] #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct Response { #[prost(string, tag = "1")] pub say: ::prost::alloc::string::String, @@ -16,9 +18,13 @@ pub struct Response { #[some_enum_attr(u8)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +#[prost(serde)] pub enum ServingStatus { + #[prost(json(proto_name = "UNKNOWN"))] Unknown = 0, + #[prost(json(proto_name = "SERVING"))] Serving = 1, + #[prost(json(proto_name = "NOT_SERVING"))] NotServing = 2, } impl ServingStatus { diff --git a/prost-build/src/json.rs b/prost-build/src/json.rs new file mode 100644 index 000000000..2b902d7a1 --- /dev/null +++ b/prost-build/src/json.rs @@ -0,0 +1,116 @@ +use std::iter; + +use heck::{ToShoutySnakeCase, ToSnakeCase}; +use itertools::Itertools; + +use crate::code_generator::{EnumVariantMapping, Field}; + +pub fn json_attr_for_field(field: &Field) -> Option { + let rust_field_name = &field.rust_name(); + let proto_field_name = field.descriptor.name(); + let inferred_json_field_name = proto_field_name.to_proto_camel_case(); + + if let Some(json_name) = field.descriptor.json_name.as_deref() { + if json_name != inferred_json_field_name { + return Some(format!( + "json(proto_name = \"{}\", json_name = \"{}\")", + proto_field_name, json_name + )); + } + } + + let field_name_is_stable_for_json = rust_field_name == proto_field_name; + if field_name_is_stable_for_json { + // We skip emitting the `json` attribute for this case because this is inferred by the + // derive macro. + None + } else { + Some(format!("json(proto_name = \"{proto_field_name}\")")) + } +} + +pub fn json_attr_for_one_of_variant(field: &Field) -> Option { + let rust_variant_name = &field.rust_variant_name(); + let proto_field_name = field.descriptor.name(); + let inferred_json_field_name = proto_field_name.to_proto_camel_case(); + + if let Some(json_name) = field.descriptor.json_name.as_deref() { + if json_name != inferred_json_field_name { + return Some(format!( + "json(proto_name = \"{}\", json_name = \"{}\")", + proto_field_name, json_name + )); + } + } + + let variant_name_is_stable_for_json = rust_variant_name.to_snake_case() == proto_field_name; + if variant_name_is_stable_for_json { + // We skip emitting the `json` attribute for this case because this is inferred by the + // derive macro. + None + } else { + Some(format!("json(proto_name = \"{proto_field_name}\")")) + } +} + +pub fn json_attr_for_enum_variant( + rust_enum_name: &str, + variant: &EnumVariantMapping<'_>, +) -> Option { + let variant_name_is_stable_for_json = { + let rust_enum_variant_name = + format!("{}_{}", rust_enum_name, variant.generated_variant_name).to_shouty_snake_case(); + rust_enum_variant_name == variant.proto_name + }; + + let emit_proto_names = !variant.proto_aliases.is_empty() || !variant_name_is_stable_for_json; + + if emit_proto_names { + let names = iter::once(variant.proto_name) + .chain(variant.proto_aliases.iter().copied()) + .map(|proto_name| format!("proto_name = \"{proto_name}\"")) + .join(", "); + + Some(format!("json({names})")) + } else { + None + } +} + +pub trait ToProtoCamelCase: ToOwned { + fn to_proto_camel_case(&self) -> Self::Owned; +} + +impl ToProtoCamelCase for str { + fn to_proto_camel_case(&self) -> Self::Owned { + // Reference: https://protobuf.com/docs/language-spec#default-json-names + // + // If no json_name pseudo-option is present, the JSON name of the field will be + // the field's name converted to camelCase. To convert to camelCase: + // + // - Discard any trailing underscores (_) + // - When a leading or interior underscore is encountered, discard the underscore and + // capitalize the next non-underscore character encountered. + // - Any other non-underscore and non-capitalized character is retained as is. + // + let mut capitalize_next = false; + let mut out = String::with_capacity(self.len()); + for chr in self.chars() { + if chr == '_' { + capitalize_next = true; + } else if capitalize_next { + out.push(chr.to_ascii_uppercase()); + capitalize_next = false; + } else { + out.push(chr); + } + } + out + } +} + +impl ToProtoCamelCase for String { + fn to_proto_camel_case(&self) -> Self::Owned { + self.as_str().to_proto_camel_case() + } +} diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 14324f9cb..b6b2aaa45 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -149,6 +149,7 @@ pub(crate) use collections::{BytesType, MapType}; mod code_generator; mod extern_paths; mod ident; +mod json; mod message_graph; mod path; @@ -384,6 +385,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); Config::new() + .enable_serde() .service_generator(Box::new(ServiceTraitGenerator)) .out_dir(tempdir.path()) .compile_protos(&["src/fixtures/smoke_test/smoke_test.proto"], &["src"]) @@ -399,6 +401,7 @@ mod tests { let gen = MockServiceGenerator::new(Rc::clone(&state)); Config::new() + .enable_serde() .service_generator(Box::new(gen)) .include_file("_protos.rs") .out_dir(tempdir.path()) @@ -423,6 +426,8 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); let mut config = Config::new(); + config.enable_serde(); + config .out_dir(tempdir.path()) // Add attributes to all messages and enums @@ -472,6 +477,7 @@ mod tests { let previously_empty_proto_path = tempdir.path().join(Path::new("google.protobuf.rs")); Config::new() + .enable_serde() .service_generator(Box::new(gen)) .include_file(include_file) .out_dir(tempdir.path()) @@ -503,6 +509,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); Config::new() + .enable_serde() .out_dir(tempdir.path()) .boxed("Container.data.foo") .boxed("Bar.qux") @@ -533,6 +540,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); Config::new() + .enable_serde() .service_generator(Box::new(gen)) .include_file(include_file) .out_dir(tempdir.path()) @@ -575,6 +583,7 @@ mod tests { let mut buf = Vec::new(); Config::new() + .enable_serde() .default_package_filename("_.default") .write_includes(modules.iter().collect(), &mut buf, None, &file_names) .unwrap(); diff --git a/prost-derive/Cargo.toml b/prost-derive/Cargo.toml index 32eecad38..2229eac45 100644 --- a/prost-derive/Cargo.toml +++ b/prost-derive/Cargo.toml @@ -14,7 +14,8 @@ proc-macro = true [dependencies] anyhow = "1.0.1" -itertools = ">=0.10.1, <=0.13" +heck = "0.4.1" +itertools = ">=0.10.3, <=0.13" proc-macro2 = "1.0.60" quote = "1" syn = { version = "2", features = ["extra-traits"] } diff --git a/prost-derive/src/field/group.rs b/prost-derive/src/field/group.rs index 485ecfc1b..e1e3f09ee 100644 --- a/prost-derive/src/field/group.rs +++ b/prost-derive/src/field/group.rs @@ -3,12 +3,13 @@ use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::Meta; -use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; +use crate::field::{set_bool, set_option, tag_attr, word_attr, Json, Label}; #[derive(Clone)] pub struct Field { pub label: Label, pub tag: u32, + pub json: Option, } impl Field { @@ -17,6 +18,7 @@ impl Field { let mut label = None; let mut tag = None; let mut boxed = false; + let mut json = None; let mut unknown_attrs = Vec::new(); @@ -25,6 +27,8 @@ impl Field { set_bool(&mut group, "duplicate group attributes")?; } else if word_attr("boxed", attr) { set_bool(&mut boxed, "duplicate boxed attributes")?; + } else if let Some(j) = Json::from_attr(attr)? { + set_option(&mut json, j, "duplicate json attributes")?; } else if let Some(t) = tag_attr(attr)? { set_option(&mut tag, t, "duplicate tag attributes")?; } else if let Some(l) = Label::from_attr(attr) { @@ -53,6 +57,7 @@ impl Field { Ok(Some(Field { label: label.unwrap_or(Label::Optional), tag, + json, })) } @@ -132,4 +137,12 @@ impl Field { Label::Repeated => quote!(#ident.clear()), } } + + pub fn to_message_field(&self) -> super::Field { + super::Field::Message(super::message::Field { + label: self.label, + tag: self.tag, + json: self.json.clone(), + }) + } } diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 4855cc5c6..278e82624 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -4,7 +4,7 @@ use quote::quote; use syn::punctuated::Punctuated; use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token}; -use crate::field::{scalar, set_option, tag_attr}; +use crate::field::{scalar, set_option, tag_attr, Json}; #[derive(Clone, Debug)] pub enum MapTy { @@ -42,6 +42,7 @@ fn fake_scalar(ty: scalar::Ty) -> scalar::Field { ty, kind, tag: 0, // Not used here + json: None, } } @@ -51,16 +52,20 @@ pub struct Field { pub key_ty: scalar::Ty, pub value_ty: ValueTy, pub tag: u32, + pub json: Option, } impl Field { pub fn new(attrs: &[Meta], inferred_tag: Option) -> Result, Error> { let mut types = None; let mut tag = None; + let mut json = None; for attr in attrs { if let Some(t) = tag_attr(attr)? { set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(j) = Json::from_attr(attr)? { + set_option(&mut json, j, "duplicate json attributes")?; } else if let Some(map_ty) = attr .path() .get_ident() @@ -114,6 +119,7 @@ impl Field { key_ty, value_ty, tag, + json, }), _ => None, }) diff --git a/prost-derive/src/field/message.rs b/prost-derive/src/field/message.rs index f6ac391e7..5d6d5a175 100644 --- a/prost-derive/src/field/message.rs +++ b/prost-derive/src/field/message.rs @@ -3,12 +3,13 @@ use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::Meta; -use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; +use crate::field::{set_bool, set_option, tag_attr, word_attr, Json, Label}; #[derive(Clone)] pub struct Field { pub label: Label, pub tag: u32, + pub json: Option, } impl Field { @@ -17,6 +18,7 @@ impl Field { let mut label = None; let mut tag = None; let mut boxed = false; + let mut json = None; let mut unknown_attrs = Vec::new(); @@ -25,6 +27,8 @@ impl Field { set_bool(&mut message, "duplicate message attribute")?; } else if word_attr("boxed", attr) { set_bool(&mut boxed, "duplicate boxed attribute")?; + } else if let Some(j) = Json::from_attr(attr)? { + set_option(&mut json, j, "duplicate json attributes")?; } else if let Some(t) = tag_attr(attr)? { set_option(&mut tag, t, "duplicate tag attributes")?; } else if let Some(l) = Label::from_attr(attr) { @@ -53,6 +57,7 @@ impl Field { Ok(Some(Field { label: label.unwrap_or(Label::Optional), tag, + json, })) } diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index 366075e45..a62863d28 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -1,8 +1,8 @@ mod group; -mod map; -mod message; +pub mod map; +pub mod message; mod oneof; -mod scalar; +pub mod scalar; use std::fmt; use std::slice; @@ -10,7 +10,7 @@ use std::slice; use anyhow::{bail, Error}; use proc_macro2::TokenStream; use quote::quote; -use syn::punctuated::Punctuated; +use syn::{punctuated::Punctuated, LitStr}; use syn::{Attribute, Expr, ExprLit, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; #[derive(Clone)] @@ -68,6 +68,7 @@ impl Field { } else if let Some(field) = message::Field::new_oneof(&attrs)? { Field::Message(field) } else if let Some(field) = map::Field::new_oneof(&attrs)? { + // FIXME: oneofs don't support repeated fields (which includes maps). Field::Map(field) } else if let Some(field) = group::Field::new_oneof(&attrs)? { Field::Group(field) @@ -172,6 +173,32 @@ impl Field { _ => None, } } + + pub fn json(&self) -> Option> { + match self { + Self::Scalar(scalar::Field { json, .. }) + | Self::Message(message::Field { json, .. }) + | Self::Group(group::Field { json, .. }) + | Self::Map(map::Field { json, .. }) => Some(json.as_ref()), + Self::Oneof(_) => None, + } + } + + pub fn is_required(&self) -> bool { + matches!( + self, + Self::Scalar(scalar::Field { + kind: scalar::Kind::Required(_), + .. + }) | Field::Message(message::Field { + label: Label::Required, + .. + }) | Field::Group(group::Field { + label: Label::Required, + .. + }) + ) + } } #[derive(Clone, Copy, PartialEq, Eq)] @@ -225,7 +252,7 @@ impl fmt::Display for Label { } /// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. -fn prost_attrs(attrs: Vec) -> Result, Error> { +pub fn prost_attrs(attrs: Vec) -> Result, Error> { let mut result = Vec::new(); for attr in attrs.iter() { if let Meta::List(meta_list) = &attr.meta { @@ -353,3 +380,51 @@ fn tags_attr(attr: &Meta) -> Result>, Error> { _ => bail!("invalid tag attribute: {:?}", attr), } } + +#[derive(Debug, Clone)] +pub struct Json { + pub proto_name: Option, + pub proto_alt_names: Vec, + pub json_name: Option, +} + +impl Json { + pub fn from_attr(attr: &Meta) -> Result, Error> { + let Meta::List(meta_list) = attr else { + return Ok(None); + }; + if !meta_list.path.is_ident("json") { + return Ok(None); + } + + let mut proto_name = None; + let mut proto_alt_names = vec![]; + let mut json_name = None; + + meta_list.parse_nested_meta(|meta| { + if meta.path.is_ident("proto_name") { + let _ = meta.input.parse::()?; + let value = meta.input.parse::()?.value(); + if proto_name.is_none() { + proto_name = Some(value); + } else { + proto_alt_names.push(value); + } + return Ok(()); + } + if meta.path.is_ident("json_name") { + let _ = meta.input.parse::()?; + json_name = Some(meta.input.parse::()?.value()); + return Ok(()); + } + + Err(meta.error("unrecognized attributes")) + })?; + + Ok(Some(Json { + proto_name, + proto_alt_names, + json_name, + })) + } +} diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index c2e870524..0859e9abe 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -5,7 +5,7 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path}; -use crate::field::{bool_attr, set_option, tag_attr, Label}; +use crate::field::{bool_attr, set_option, tag_attr, Json, Label}; /// A scalar protobuf field. #[derive(Clone)] @@ -13,6 +13,7 @@ pub struct Field { pub ty: Ty, pub kind: Kind, pub tag: u32, + pub json: Option, } impl Field { @@ -22,12 +23,15 @@ impl Field { let mut packed = None; let mut default = None; let mut tag = None; + let mut json = None; let mut unknown_attrs = Vec::new(); for attr in attrs { if let Some(t) = Ty::from_attr(attr)? { set_option(&mut ty, t, "duplicate type attributes")?; + } else if let Some(j) = Json::from_attr(attr)? { + set_option(&mut json, j, "duplicate json attributes")?; } else if let Some(p) = bool_attr("packed", attr)? { set_option(&mut packed, p, "duplicate packed attributes")?; } else if let Some(t) = tag_attr(attr)? { @@ -86,7 +90,12 @@ impl Field { (Some(Label::Repeated), _, false) => Kind::Repeated, }; - Ok(Some(Field { ty, kind, tag })) + Ok(Some(Field { + ty, + kind, + tag, + json, + })) } pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 270d25ee2..9cd7833dc 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -10,23 +10,42 @@ use itertools::Itertools; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ - punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, - FieldsUnnamed, Ident, Index, Variant, + punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, ExprLit, ExprUnary, + Fields, FieldsNamed, FieldsUnnamed, Ident, Index, Lit, UnOp, Variant, }; mod field; -use crate::field::Field; +use crate::field::{Field, Json}; + +mod serde; fn try_message(input: TokenStream) -> Result { let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; - syn::custom_keyword!(skip_debug); - let skip_debug = input + let mut skip_debug = false; + let mut emit_serde = false; + + if let Some(attr) = input .attrs - .into_iter() - .any(|a| a.path().is_ident("prost") && a.parse_args::().is_ok()); + .iter() + .find(|attr| attr.path().is_ident("prost")) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("skip_debug") { + skip_debug = true; + return Ok(()); + } + + if meta.path.is_ident("serde") { + emit_serde = true; + return Ok(()); + } + + Err(meta.error("unrecognized attributes")) + })?; + } let variant_data = match input.data { Data::Struct(variant_data) => variant_data, @@ -172,6 +191,12 @@ fn try_message(input: TokenStream) -> Result { } }; + let serde_impl = if emit_serde { + serde::impls_for_struct(&ident, generics, &fields)? + } else { + Default::default() + }; + let expanded = quote! { impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause { #[allow(unused_variables)] @@ -250,6 +275,8 @@ fn try_message(input: TokenStream) -> Result { #expanded #methods + + #serde_impl }; Ok(expanded) @@ -273,15 +300,42 @@ fn try_enumeration(input: TokenStream) -> Result { Data::Union(..) => bail!("Enumeration can not be derived for a union"), }; + let mut emit_serde = false; + + if let Some(attr) = input + .attrs + .iter() + .find(|attr| attr.path().is_ident("prost")) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("serde") { + emit_serde = true; + return Ok(()); + } + + Err(meta.error("unrecognized attributes")) + })?; + } + // Map the variants into 'fields'. - let mut variants: Vec<(Ident, Expr)> = Vec::new(); + let mut variants: Vec<(Ident, Expr, Option)> = Vec::new(); for Variant { + attrs, ident, fields, discriminant, .. } in punctuated_variants { + let mut json = None; + + let attrs = field::prost_attrs(attrs)?; + for attr in &attrs { + if let Some(j) = field::Json::from_attr(attr)? { + field::set_option(&mut json, j, "duplicate json attribute")?; + } + } + match fields { Fields::Unit => (), Fields::Named(_) | Fields::Unnamed(_) => { @@ -290,7 +344,28 @@ fn try_enumeration(input: TokenStream) -> Result { } match discriminant { - Some((_, expr)) => variants.push((ident, expr)), + Some((_, expr)) => { + // Validate the the discriminant. + let inner_expr = match &expr { + Expr::Unary(ExprUnary { + op: UnOp::Neg(_), + expr, + .. + }) => expr, + _ => &expr, + }; + if !matches!( + inner_expr, + Expr::Lit(ExprLit { + lit: Lit::Int(_), + .. + }) + ) { + bail!("Enumeration variants must have an integral discriminant"); + } + + variants.push((ident, expr, json)) + } None => bail!("Enumeration variants must have a discriminant"), } } @@ -301,14 +376,14 @@ fn try_enumeration(input: TokenStream) -> Result { let default = variants[0].0.clone(); - let is_valid = variants.iter().map(|(_, value)| quote!(#value => true)); - let from = variants - .iter() - .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant))); + let is_valid = variants.iter().map(|(_, value, _)| quote!(#value => true)); + let from = variants.iter().map( + |(variant, value, _)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), + ); let try_from = variants .iter() - .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant))); + .map(|(variant, value, _)| quote!(#value => ::core::result::Result::Ok(#ident::#variant))); let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); let from_i32_doc = format!( @@ -316,6 +391,12 @@ fn try_enumeration(input: TokenStream) -> Result { ident ); + let serde_impls = if emit_serde { + serde::impls_for_enum(&ident, generics, &variants)? + } else { + Default::default() + }; + let expanded = quote! { impl #impl_generics #ident #ty_generics #where_clause { #[doc=#is_valid_doc] @@ -358,6 +439,8 @@ fn try_enumeration(input: TokenStream) -> Result { } } } + + #serde_impls }; Ok(expanded) @@ -373,11 +456,28 @@ fn try_oneof(input: TokenStream) -> Result { let ident = input.ident; - syn::custom_keyword!(skip_debug); - let skip_debug = input + let mut skip_debug = false; + let mut emit_serde = false; + + if let Some(attr) = input .attrs - .into_iter() - .any(|a| a.path().is_ident("prost") && a.parse_args::().is_ok()); + .iter() + .find(|attr| attr.path().is_ident("prost")) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("skip_debug") { + skip_debug = true; + return Ok(()); + } + + if meta.path.is_ident("serde") { + emit_serde = true; + return Ok(()); + } + + Err(meta.error("unrecognized attributes")) + })?; + } let variants = match input.data { Data::Enum(DataEnum { variants, .. }) => variants, @@ -459,6 +559,12 @@ fn try_oneof(input: TokenStream) -> Result { quote!(#ident::#variant_ident(ref value) => #encoded_len) }); + let serde_impls = if emit_serde { + serde::impls_for_oneof(&ident, generics, &fields)? + } else { + Default::default() + }; + let expanded = quote! { impl #impl_generics #ident #ty_generics #where_clause { /// Encodes the message to a buffer. @@ -492,6 +598,7 @@ fn try_oneof(input: TokenStream) -> Result { } } + #serde_impls }; let expanded = if skip_debug { expanded diff --git a/prost-derive/src/serde/de.rs b/prost-derive/src/serde/de.rs new file mode 100644 index 000000000..eb09f225b --- /dev/null +++ b/prost-derive/src/serde/de.rs @@ -0,0 +1,522 @@ +use std::iter; + +use anyhow::{anyhow, Error}; +use heck::{ToShoutySnakeCase, ToSnakeCase}; +use itertools::Itertools; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Expr, Generics}; + +use crate::{ + field::{self, scalar, Field, Json}, + serde::utils::ToProtoCamelCase, +}; + +pub fn impl_for_message( + struct_ident: &Ident, + _generics: &Generics, + fields: &[(TokenStream, Field)], +) -> Result { + let full_struct_name = format!("struct {}", struct_ident); + + let mut field_vals = vec![]; + let mut field_assignments = vec![]; + let mut field_variants = vec![]; + let mut field_match_arms = vec![]; + let mut field_match_oneofs = vec![]; + let mut field_matches = vec![]; + let mut field_required_checks = vec![]; + + for (field_idx, (field_ident, field)) in fields.iter().enumerate() { + let field_ident_str = field_ident.to_string(); + + let field_variant_ident = format_ident!("__field{}", field_idx); + + field_vals.push(quote! { #field_variant_ident }); + field_assignments.push(quote! { + #field_ident: _private::Option::unwrap_or_default(#field_variant_ident) + }); + + if let Field::Oneof(oneof) = field { + let ty_path = &oneof.ty; + field_variants.push(quote! { + #field_variant_ident( + <#ty_path as _private::DeserializeOneOf>::FieldKey + ) + }) + } else { + field_variants.push(quote! { #field_variant_ident }) + } + + if let Some(json) = field.json() { + // Only a scalar, message, group or map field may have the json attribute. + + let proto_field_name = match json { + Some(Json { + proto_name: Some(proto_name), + .. + }) => proto_name, + _ => &field_ident_str, + }; + + let json_field_name = match json { + Some(Json { + json_name: Some(json_name), + .. + }) => json_name.to_owned(), + Some(Json { + proto_name: Some(proto_name), + .. + }) => proto_name.to_proto_camel_case(), + Some(_) | None => field_ident_str.to_proto_camel_case(), + }; + + if proto_field_name != &json_field_name { + field_match_arms.push(quote! { + #proto_field_name | #json_field_name + => _private::Ok(__Field::#field_variant_ident) + }); + } else { + field_match_arms.push(quote! { + #proto_field_name + => _private::Ok(__Field::#field_variant_ident) + }); + } + + let deserializer = deserializer_for_field(field)?; + field_matches.push(quote! { + __Field::#field_variant_ident => { + if _private::Option::is_some(&#field_variant_ident) { + return _private::Err( + <__A::Error as _serde::de::Error>::duplicate_field(#field_ident_str) + ); + } + let val =_serde::de::MapAccess::next_value_seed( + &mut __map, + _private::MaybeDesIntoWithConfig::<#deserializer, _>::new(__config) + )?; + let val = _private::MaybeDeserializedValue::unwrap_for_field( + val, + __config, + #field_ident_str + )?; + #field_variant_ident = _private::Some(val); + } + }); + + if field.is_required() { + field_required_checks.push(quote! { + if #field_variant_ident.is_none() { + return _private::Err( + <__A::Error as _serde::de::Error>::missing_field(#field_ident_str) + ); + } + }); + } + } + + if let Field::Oneof(oneof) = field { + let ty_path = &oneof.ty; + + field_match_oneofs.push(quote! { + if let _private::Some(field_key) + = <#ty_path as _private::DeserializeOneOf>::deserialize_field_key(__value) + { + return _private::Ok(__Field::#field_variant_ident(field_key)); + } + }); + + field_matches.push(quote! { + __Field::#field_variant_ident(key) => { + if _private::Option::is_some(&#field_variant_ident) { + let __val = _serde::de::MapAccess::next_value_seed( + &mut __map, + _private::DesIntoWithConfig::<_private::NullDeserializer, ()>::new( + __config + ), + ); + match __val { + _private::Ok(()) => continue, + _private::Err(_) => return _private::Err( + <__A::Error as _serde::de::Error>::duplicate_field(#field_ident_str) + ), + } + } + let __val = _serde::de::MapAccess::next_value_seed( + &mut __map, + _private::OneOfDeserializer(key, __config), + )?; + if _private::Option::is_some(&__val) { + #field_variant_ident = _private::Some(__val); + } + } + }) + } + } + + let map_field = quote! { + enum __Field { + #(#field_variants,)* + __unknown, + } + + struct __FieldVisitor<'a>(&'a _private::DeserializerConfig); + + impl<'a, 'de> _serde::de::Visitor<'de> for __FieldVisitor<'a> { + type Value = __Field; + + fn expecting( + &self, + __formatter: &mut _private::fmt::Formatter + ) -> _serde::__private::fmt::Result { + _private::fmt::Formatter::write_str(__formatter, "field identifier") + } + + fn visit_str<__E>(self, __value: &str) -> _private::Result + where + __E: _serde::de::Error + { + let __config = self.0; + + #(#field_match_oneofs)* + + match __value { + #(#field_match_arms,)* + _ => { + if __config.ignore_unknown_fields { + _private::Ok(__Field::__unknown) + } else { + _private::Err(<__E as _serde::de::Error>::unknown_field(__value, &[])) + } + }, + } + } + } + + impl<'de> _private::CustomDeserialize<'de> for __Field { + fn deserialize<__D>( + __deserializer: __D, + __config: &_private::DeserializerConfig + ) -> _private::Result + where + __D: _serde::Deserializer<'de>, + { + _serde::Deserializer::deserialize_identifier( + __deserializer, + __FieldVisitor(__config), + ) + } + } + }; + + let map_visitor = quote! { + struct __Visitor<'a>(&'a _private::DeserializerConfig); + + impl<'a, 'de> _serde::de::Visitor<'de> for __Visitor<'a> { + type Value = #struct_ident; + + fn expecting(&self, __formatter: &mut _private::fmt::Formatter) -> _private::fmt::Result { + _private::fmt::Formatter::write_str(__formatter, #full_struct_name) + } + + fn visit_map<__A>(self, mut __map: __A) -> _private::Result + where + __A: _serde::de::MapAccess<'de> + { + let __config = self.0; + + #(let mut #field_vals = _private::None;)* + + while let _private::Some(__key) + = _serde::de::MapAccess::next_key_seed( + &mut __map, + _private::DesWithConfig::<__Field>::new(__config) + )? + { + match __key { + #(#field_matches,)* + __Field::__unknown => { + _serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>( + &mut __map + )?; + } + } + } + + #(#field_required_checks)* + + _private::Ok(#struct_ident { + #(#field_assignments),* + }) + } + } + }; + + Ok(quote! { + impl<'de> _private::CustomDeserialize<'de> for #struct_ident { + fn deserialize<__D>( + __deserializer: __D, + __config: &_private::DeserializerConfig + ) -> _private::Result + where + __D: _serde::Deserializer<'de>, + { + #map_field + + #map_visitor + + _serde::Deserializer::deserialize_map( + __deserializer, + __Visitor(__config), + ) + } + } + + impl<'de> _serde::Deserialize<'de> for #struct_ident { + #[inline] + fn deserialize<__D>( + __deserializer: __D, + ) -> _private::Result + where + __D: _serde::Deserializer<'de>, + { + let __config = <_private::DeserializerConfig as _private::Default>::default(); + ::deserialize( + __deserializer, + &__config, + ) + } + } + }) +} + +pub fn impl_for_oneof( + oneof_ident: &Ident, + _generics: &Generics, + fields: &[(Ident, Field)], +) -> Result { + let mut field_keys = vec![]; + let mut match_field_key_str_arms = vec![]; + let mut match_field_key_arms = vec![]; + + let field_key_enum_ident = format_ident!("{}FieldKey", oneof_ident); + + for (field_idx, (field_ident, field)) in fields.iter().enumerate() { + let field_key_ident = format_ident!("__field{}", field_idx); + let field_ident_str = field_ident.to_string(); + + let Some(json) = field.json() else { + return Err(anyhow!("unsupported field in oneof")); + }; + + let proto_field_name = match json { + Some(Json { + proto_name: Some(proto_name), + .. + }) => proto_name.to_owned(), + _ => field_ident_str.to_snake_case(), + }; + + let json_field_name = match json { + Some(Json { + json_name: Some(json_name), + .. + }) => json_name.to_owned(), + Some(Json { + proto_name: Some(proto_name), + .. + }) => proto_name.to_proto_camel_case(), + Some(_) | None => field_ident_str.to_snake_case().to_proto_camel_case(), + }; + + if proto_field_name != json_field_name { + match_field_key_str_arms.push(quote! { + #proto_field_name | #json_field_name + => _private::Some(#field_key_enum_ident::#field_key_ident) + }); + } else { + match_field_key_str_arms.push(quote! { + #proto_field_name + => _private::Some(#field_key_enum_ident::#field_key_ident) + }); + } + + assert!(field.is_required()); + + let deserializer = deserializer_for_field(field)?; + match_field_key_arms.push(quote! { + #field_key_enum_ident::#field_key_ident => { + let __val = < + _private::OptionDeserializer<#deserializer> + as _private::DeserializeInto<_private::Option<_>> + >::deserialize_into(__deserializer, __config)?; + _private::Ok(__val.map(Self::#field_ident)) + } + }); + + field_keys.push(field_key_ident); + } + + Ok(quote! { + pub enum #field_key_enum_ident { + #(#field_keys,)* + } + + impl _private::DeserializeOneOf for #oneof_ident { + type FieldKey = #field_key_enum_ident; + + fn deserialize_field_key(__val: &str) -> _private::Option { + match __val { + #(#match_field_key_str_arms,)* + _ => _private::None, + } + } + + fn deserialize_by_field_key<'de, __D>( + __field_key: Self::FieldKey, + __deserializer: __D, + __config: &_private::DeserializerConfig, + ) -> _private::Result<_private::Option, __D::Error> + where + __D: _serde::de::Deserializer<'de> + { + match __field_key { + #(#match_field_key_arms,)* + } + } + } + }) +} + +pub fn impl_for_enum( + enum_ident: &Ident, + _generics: &Generics, + variants: &[(Ident, Expr, Option)], +) -> Result { + let (str_arms, int_arms): (Vec<_>, Vec<_>) = variants + .iter() + .map(|(variant_ident, descr, json)| { + let json_value = match json { + Some(Json { + proto_name: Some(proto_name), + proto_alt_names, + .. + }) => iter::once(proto_name.to_owned()) + .chain(proto_alt_names.iter().cloned()) + .collect::>(), + _ => vec![format!("{enum_ident}_{variant_ident}").to_shouty_snake_case()], + }; + let str_arm = quote! { + #(#json_value)|* => _private::Ok(_private::Some(Self::#variant_ident)) + }; + let int_arm = quote! { + #descr => _private::Ok(_private::Some(Self::#variant_ident)) + }; + (str_arm, int_arm) + }) + .multiunzip(); + + Ok(quote! { + impl _private::DeserializeEnum for #enum_ident { + fn deserialize_from_i32<__E>(val: i32) + -> _private::Result<_private::Option, __E> + where + __E: _serde::de::Error + { + match val { + #(#int_arms,)* + _ => _private::Ok(_private::None), + } + } + + fn deserialize_from_str<__E>(val: &str) + -> _private::Result<_private::Option, __E> + where + __E: _serde::de::Error + { + match val { + #(#str_arms,)* + _ => _private::Ok(_private::None), + } + } + } + }) +} + +fn deserializer_for_field(field: &Field) -> Result { + // Map group fields to message fields, since they deserialize the same. + let remapped_group_field; + let field = if let Field::Group(group) = field { + remapped_group_field = group.to_message_field(); + &remapped_group_field + } else { + field + }; + Ok(match field { + Field::Scalar(scalar) => { + let de = deserializer_for_ty(&scalar.ty, false); + match scalar.kind { + scalar::Kind::Required(_) => de, + scalar::Kind::Plain(_) => quote! { _private::DefaultDeserializer<#de> }, + scalar::Kind::Optional(_) => quote! { _private::OptionDeserializer<#de> }, + scalar::Kind::Repeated | scalar::Kind::Packed => { + quote! { _private::DefaultDeserializer<_private::VecDeserializer<#de>> } + } + } + } + Field::Message(message) => { + let inner = quote! { _private::MessageDeserializer }; + match message.label { + field::Label::Optional => quote! { + _private::OptionDeserializer<#inner> + }, + field::Label::Repeated => quote! { + _private::DefaultDeserializer<_private::VecDeserializer<#inner>> + }, + field::Label::Required => inner, + } + } + Field::Map(map) => { + let key_deserializer = deserializer_for_ty(&map.key_ty, true); + let val_deserializer = match &map.value_ty { + field::map::ValueTy::Scalar(ty) => deserializer_for_ty(ty, false), + field::map::ValueTy::Message => { + quote! { _private::DefaultDeserializer<_private::MessageDeserializer> } + } + }; + quote! { + _private::DefaultDeserializer< + _private::MapDeserializer<#key_deserializer, #val_deserializer> + > + } + } + Field::Group(_) | Field::Oneof(_) => unreachable!(), + }) +} + +fn deserializer_for_ty(ty: &scalar::Ty, accept_str_eq: bool) -> TokenStream { + use scalar::Ty; + match ty { + Ty::Int32 + | Ty::Int64 + | Ty::Uint32 + | Ty::Uint64 + | Ty::Sint32 + | Ty::Sint64 + | Ty::Fixed32 + | Ty::Fixed64 + | Ty::Sfixed32 + | Ty::Sfixed64 => quote! { _private::IntDeserializer }, + Ty::Float | Ty::Double => quote! { _private::FloatDeserializer }, + Ty::Bool => { + quote! { _private::BoolDeserializer<{ #accept_str_eq }> } + } + Ty::String => { + quote! { _private::ForwardDeserializer } + } + Ty::Enumeration(path) => { + quote! { _private::EnumDeserializer::<#path> } + } + Ty::Bytes(_) => quote! { _private::BytesDeserializer }, + } +} diff --git a/prost-derive/src/serde/mod.rs b/prost-derive/src/serde/mod.rs new file mode 100644 index 000000000..d0f93275f --- /dev/null +++ b/prost-derive/src/serde/mod.rs @@ -0,0 +1,97 @@ +use anyhow::Error; +use proc_macro2::{Ident, TokenStream}; +use quote::quote; +use syn::{Expr, Generics}; + +use crate::field::{Field, Json}; + +mod de; +mod ser; +mod utils; + +pub fn impls_for_enum( + enum_ident: &Ident, + generics: &Generics, + variants: &[(Ident, Expr, Option)], +) -> Result { + let serialize_impl = ser::impl_for_enum(enum_ident, generics, variants)?; + let deserialize_impl = de::impl_for_enum(enum_ident, generics, variants)?; + + let items = quote! { + extern crate prost as _prost; + + use _prost::serde::private::_serde; + use _prost::serde::private as _private; + + #serialize_impl + + #deserialize_impl + }; + + let wrapped = quote! { + #[doc(hidden)] + const _: () = { + #items + }; + }; + + Ok(wrapped) +} + +pub fn impls_for_oneof( + ident: &Ident, + generics: &Generics, + fields: &[(Ident, Field)], +) -> Result { + let serialize_impl = ser::impl_for_oneof(ident, generics, fields)?; + let deserialize_impl = de::impl_for_oneof(ident, generics, fields)?; + + let items = quote! { + extern crate prost as _prost; + + use _prost::serde::private::_serde; + use _prost::serde::private as _private; + + #serialize_impl + + #deserialize_impl + }; + + let wrapped = quote! { + #[doc(hidden)] + const _: () = { + #items + }; + }; + + Ok(wrapped) +} + +pub fn impls_for_struct( + struct_ident: &Ident, + generics: &Generics, + fields: &[(TokenStream, Field)], +) -> Result { + let serialize_impl = ser::impl_for_message(struct_ident, generics, fields)?; + let deserialize_impl = de::impl_for_message(struct_ident, generics, fields)?; + + let items = quote! { + extern crate prost as _prost; + + use _prost::serde::private::_serde; + use _prost::serde::private as _private; + + #serialize_impl + + #deserialize_impl + }; + + let wrapped = quote! { + #[doc(hidden)] + const _: () = { + #items + }; + }; + + Ok(wrapped) +} diff --git a/prost-derive/src/serde/ser.rs b/prost-derive/src/serde/ser.rs new file mode 100644 index 000000000..05f9e6d3b --- /dev/null +++ b/prost-derive/src/serde/ser.rs @@ -0,0 +1,450 @@ +use anyhow::{anyhow, Error}; +use heck::{ToShoutySnakeCase, ToSnakeCase}; +use proc_macro2::{Ident, TokenStream}; +use quote::{quote, TokenStreamExt}; +use syn::{Expr, Generics}; + +use crate::{ + field::{self, Field, Json}, + serde::utils::ToProtoCamelCase, +}; + +pub fn impl_for_message( + struct_ident: &Ident, + generics: &Generics, + fields: &[(TokenStream, Field)], +) -> Result { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let struct_ident_str = struct_ident.to_string(); + + let num_fields = fields.len(); + + let mut ser_stmts = TokenStream::new(); + + for (field_ident, field) in fields { + use field::scalar::Kind; + + if let Some(json) = field.json() { + // Only scalar, message, group and map fields may have the json attribute. + + let json_field_name = match json { + Some(Json { + json_name: Some(name), + .. + }) => name.to_owned(), + Some(Json { + proto_name: Some(name), + .. + }) => name.to_proto_camel_case(), + Some(_) | None => field_ident.to_string().to_proto_camel_case(), + }; + + // Map a group field to an equivalent message field because they share the same + // serialization impl. + let remapped_group_field; + let field = if let Field::Group(group) = field { + remapped_group_field = group.to_message_field(); + &remapped_group_field + } else { + field + }; + + match field { + Field::Scalar(scalar) => { + let wrapper = wrapper_for_ty(&scalar.ty); + + match &scalar.kind { + Kind::Plain(_) => { + ser_stmts.append_all(quote! { + if __config.emit_fields_with_default_value + || !_private::is_default_value(&__self.#field_ident) + { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig( + #wrapper(&__self.#field_ident), + __config, + ) + )?; + } + }); + } + Kind::Required(_) => { + ser_stmts.append_all(quote! { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig( + #wrapper(&__self.#field_ident), + __config, + ) + )?; + }); + } + Kind::Optional(_) => { + ser_stmts.append_all(quote! { + if let _private::Option::Some(val) = &__self.#field_ident { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig( + #wrapper(val), + __config, + ) + )?; + } else { + if __config.emit_nulled_optional_fields { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::Option::<()>::None + )?; + } + } + }); + } + Kind::Repeated | Kind::Packed => { + ser_stmts.append_all(quote! { + if __config.emit_fields_with_default_value + || !_private::is_default_value(&__self.#field_ident) + { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig( + _private::SerMappedVecItems( + &__self.#field_ident, + #wrapper + ), + __config, + ) + )?; + } + }); + } + } + } + Field::Message(message) => { + use field::Label; + + match message.label { + Label::Required => { + ser_stmts.append_all(quote! { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig(&__self.#field_ident, __config) + )?; + }); + } + Label::Optional => { + ser_stmts.append_all(quote! { + if let _private::Option::Some(__val) = &__self.#field_ident { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig(__val, __config) + )?; + } else { + if __config.emit_nulled_optional_fields { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::Option::<()>::None + )?; + } + } + }); + } + Label::Repeated => { + ser_stmts.append_all(quote! { + if __config.emit_fields_with_default_value + || !_private::is_default_value(&__self.#field_ident) + { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig(&__self.#field_ident, __config) + )?; + } + }); + } + } + } + Field::Map(map) => { + use field::map::ValueTy; + + let wrapper = match &map.value_ty { + ValueTy::Scalar(ty) => wrapper_for_ty(ty), + ValueTy::Message => quote! { _private::SerIdentity }, + }; + + ser_stmts.append_all(quote! { + if __config.emit_fields_with_default_value + || !_private::is_default_value(&__self.#field_ident) + { + _serde::ser::SerializeStruct::serialize_field( + &mut __serde_state, + #json_field_name, + &_private::SerWithConfig( + _private::SerMappedMapItems(&__self.#field_ident, #wrapper), + __config, + ) + )?; + } + }); + } + Field::Group(_) => { + // We should've replaced the group field with an equivalant message field. + unreachable!(); + } + Field::Oneof(_) => unreachable!(), + } + } else { + // Must be an oneof field. + let Field::Oneof(oneof) = field else { + unreachable!() + }; + + let oneof_ty = &oneof.ty; + ser_stmts.append_all(quote! { + if let _private::Option::Some(val) = &__self.#field_ident { + <#oneof_ty as _private::SerializeOneOf>::serialize_oneof( + val, + &mut __serde_state, + __config, + )?; + } + }); + } + } + + Ok(quote! { + impl #impl_generics _private::CustomSerialize for #struct_ident #ty_generics + #where_clause + { + fn serialize<__S>( + &self, + __serializer: __S, + __config: &_private::SerializerConfig, + ) -> _private::Result<__S::Ok, __S::Error> + where + __S: _serde::Serializer, + { + let __self = self; + + let mut __serde_state = _serde::Serializer::serialize_struct( + __serializer, + #struct_ident_str, + #num_fields, + )?; + + #ser_stmts + + _serde::ser::SerializeStruct::end(__serde_state) + } + } + + impl #impl_generics _serde::Serialize for #struct_ident #ty_generics + #where_clause + { + #[inline] + fn serialize<__S>( + &self, + __serializer: __S, + ) -> _private::Result<__S::Ok, __S::Error> + where + __S: _serde::Serializer, + { + let __config = <_private::SerializerConfig as _private::Default>::default(); + _private::CustomSerialize::serialize(self, __serializer, &__config) + } + } + }) +} + +pub fn impl_for_oneof( + oneof_ident: &Ident, + generics: &Generics, + fields: &[(Ident, Field)], +) -> Result { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let ser_match_variants = fields + .iter() + .map(|(field_ident, field)| { + let json_field_name = match &field.json().unwrap() { + Some(Json { + json_name: Some(name), + .. + }) => name.to_owned(), + Some(Json { + proto_name: Some(name), + .. + }) => name.to_proto_camel_case(), + Some(_) | None => field_ident + .to_string() + .to_snake_case() + .to_proto_camel_case(), + }; + + // Map a group field to an equivalent message field because they share the same + // serialization impl. + let remapped_group_field; + let field = if let Field::Group(group) = field { + remapped_group_field = group.to_message_field(); + &remapped_group_field + } else { + field + }; + + let arm = match field { + Field::Scalar(scalar) => { + let wrapper = wrapper_for_ty(&scalar.ty); + quote! { + Self::#field_ident(val) => __serializer.serialize_field( + #json_field_name, + &_private::SerWithConfig( + #wrapper(val), + __config, + ), + ) + } + } + Field::Message(_) => { + quote! { + Self::#field_ident(__val) => __serializer.serialize_field( + #json_field_name, + &_private::SerWithConfig( + __val, + __config, + ), + ) + } + } + Field::Group(_) => unreachable!(), + Field::Map(_) => return Err(anyhow!("unsupported map field inside oneof")), + Field::Oneof(_) => return Err(anyhow!("unsupported oneof field inside oneof")), + }; + + Ok(arm) + }) + .collect::, Error>>()?; + + Ok(quote! { + impl #impl_generics _private::SerializeOneOf for #oneof_ident #ty_generics + #where_clause + { + fn serialize_oneof<__S>( + &self, + __serializer: &mut __S, + __config: &_private::SerializerConfig, + ) -> _private::Result<(), __S::Error> + where + __S: _serde::ser::SerializeStruct, + { + match self { + #(#ser_match_variants,)* + } + } + } + }) +} + +pub fn impl_for_enum( + enum_ident: &Ident, + generics: &Generics, + variants: &[(Ident, Expr, Option)], +) -> Result { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let match_arms = variants + .iter() + .map(|(variant_ident, discr, json)| { + let json_value = match json { + Some(Json { + proto_name: Some(proto_name), + .. + }) => proto_name.to_owned(), + _ => format!("{enum_ident}_{variant_ident}").to_shouty_snake_case(), + }; + + quote! { + Self::#variant_ident => { + if __config.emit_enum_values_as_integer { + __serializer.serialize_i32(#discr) + } else { + __serializer.serialize_str(#json_value) + } + } + } + }) + .collect::>(); + + Ok(quote! { + impl #impl_generics _private::CustomSerialize for #enum_ident #ty_generics + #where_clause + { + fn serialize<__S>( + &self, + __serializer: __S, + __config: &_private::SerializerConfig, + ) -> _private::Result<__S::Ok, __S::Error> + where + __S: _serde::Serializer, + { + match self { + #(#match_arms,)* + } + } + } + + impl #impl_generics _serde::Serialize for #enum_ident #ty_generics + #where_clause + { + #[inline] + fn serialize<__S>( + &self, + __serializer: __S, + ) -> _private::Result<__S::Ok, __S::Error> + where + __S: _serde::Serializer, + { + let __config = <_private::SerializerConfig as _private::Default>::default(); + _private::CustomSerialize::serialize(self, __serializer, &__config) + } + } + }) +} + +fn wrapper_for_ty(ty: &field::scalar::Ty) -> TokenStream { + use field::scalar::Ty; + match ty { + Ty::Int32 + | Ty::Uint32 + | Ty::Sint32 + | Ty::Fixed32 + | Ty::Sfixed32 + | Ty::String + | Ty::Bool => { + quote! { _private::SerSerde } + } + Ty::Int64 | Ty::Uint64 | Ty::Sint64 | Ty::Fixed64 | Ty::Sfixed64 => { + quote! { _private::SerAsDisplay } + } + Ty::Bytes(_) => { + quote! { _private::SerBytesAsBase64 } + } + Ty::Float => { + quote! { _private::SerFloat32 } + } + Ty::Double => { + quote! { _private::SerFloat64 } + } + Ty::Enumeration(path) => { + quote! { _private::SerEnum::<#path>::new } + } + } +} diff --git a/prost-derive/src/serde/utils.rs b/prost-derive/src/serde/utils.rs new file mode 100644 index 000000000..36a893d1d --- /dev/null +++ b/prost-derive/src/serde/utils.rs @@ -0,0 +1,37 @@ +pub trait ToProtoCamelCase: ToOwned { + fn to_proto_camel_case(&self) -> Self::Owned; +} + +impl ToProtoCamelCase for str { + fn to_proto_camel_case(&self) -> Self::Owned { + // Reference: https://protobuf.com/docs/language-spec#default-json-names + // + // If no json_name pseudo-option is present, the JSON name of the field will be + // the field's name converted to camelCase. To convert to camelCase: + // + // - Discard any trailing underscores (_) + // - When a leading or interior underscore is encountered, discard the underscore and + // capitalize the next non-underscore character encountered. + // - Any other non-underscore and non-capitalized character is retained as is. + // + let mut capitalize_next = false; + let mut out = String::with_capacity(self.len()); + for chr in self.chars() { + if chr == '_' { + capitalize_next = true; + } else if capitalize_next { + out.push(chr.to_ascii_uppercase()); + capitalize_next = false; + } else { + out.push(chr); + } + } + out + } +} + +impl ToProtoCamelCase for String { + fn to_proto_camel_case(&self) -> Self::Owned { + self.as_str().to_proto_camel_case() + } +} diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index b3ee470d2..2b9b99eee 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -15,10 +15,15 @@ doctest = false [features] default = ["std"] std = ["prost/std"] +serde = ["prost/serde"] +any-v2 = ["std", "serde", "prost/serde-json", "dep:serde", "dep:erased-serde"] [dependencies] prost = { version = "0.13.3", path = "../prost", default-features = false, features = ["prost-derive"] } +erased-serde = { version = "0.4.3", optional = true } +serde = { version = "1.0.189", features = ["derive"], optional = true } + [dev-dependencies] proptest = "1" diff --git a/prost-types/src/any.rs b/prost-types/src/any.rs index af3e0e4da..e4c289f6d 100644 --- a/prost-types/src/any.rs +++ b/prost-types/src/any.rs @@ -1,7 +1,7 @@ use super::*; -impl Any { - /// Serialize the given message type `M` as [`Any`]. +impl protobuf::Any { + /// Serialize the given message type `M` as [`protobuf::Any`]. pub fn from_msg(msg: &M) -> Result where M: Name, @@ -9,10 +9,10 @@ impl Any { let type_url = M::type_url(); let mut value = Vec::new(); Message::encode(msg, &mut value)?; - Ok(Any { type_url, value }) + Ok(Self { type_url, value }) } - /// Decode the given message type `M` from [`Any`], validating that it has + /// Decode the given message type `M` from [`protobuf::Any`], validating that it has /// the expected type URL. pub fn to_msg(&self) -> Result where @@ -38,7 +38,7 @@ impl Any { } } -impl Name for Any { +impl Name for protobuf::Any { const PACKAGE: &'static str = PACKAGE; const NAME: &'static str = "Any"; @@ -54,7 +54,7 @@ mod tests { #[test] fn check_any_serialization() { let message = Timestamp::date(2000, 1, 1).unwrap(); - let any = Any::from_msg(&message).unwrap(); + let any = protobuf::Any::from_msg(&message).unwrap(); assert_eq!( &any.type_url, "type.googleapis.com/google.protobuf.Timestamp" diff --git a/prost-types/src/any_v2.rs b/prost-types/src/any_v2.rs new file mode 100644 index 000000000..7d3d8d8d3 --- /dev/null +++ b/prost-types/src/any_v2.rs @@ -0,0 +1,780 @@ +use core::{any::Any as CoreAny, cell::RefCell, fmt::Debug}; +use std::{ + collections::HashMap, + sync::{Arc, OnceLock, RwLock}, +}; + +use ::serde as _serde; +pub use prost::serde::private::JsonValue; +use prost::{ + bytes::{Buf, BufMut}, + serde::SerdeMessage, + Message, Name, +}; + +use crate::smallbox::{smallbox, SmallBox}; + +mod private { + pub trait Sealed {} +} + +pub trait AnyValue: CoreAny + Message + private::Sealed { + fn as_any(&self) -> &(dyn CoreAny + Send + Sync); + + fn as_mut_any(&mut self) -> &mut (dyn CoreAny + Send + Sync); + + fn as_message(&self) -> &(dyn Message + Send + Sync); + + fn as_mut_message(&mut self) -> &mut (dyn Message + Send + Sync); + + fn clone_value(&self) -> Box; + + fn cmp_any(&self, other: &dyn AnyValue) -> bool; + + fn as_erased_serialize<'a>( + &'a self, + config: &'a prost::serde::SerializerConfig, + ) -> SmallBox; + + fn encode_to_buf(&self, buf: &mut dyn BufMut); +} + +impl AnyValue for T { + fn as_any(&self) -> &(dyn CoreAny + Send + Sync) { + self as _ + } + + fn as_mut_any(&mut self) -> &mut (dyn CoreAny + Send + Sync) { + self as _ + } + + fn as_message(&self) -> &(dyn Message + Send + Sync) { + self as _ + } + + fn as_mut_message(&mut self) -> &mut (dyn Message + Send + Sync) { + self as _ + } + + fn clone_value(&self) -> Box { + Box::new(self.clone()) as _ + } + + fn cmp_any(&self, other: &dyn AnyValue) -> bool { + other.as_any().downcast_ref::() == Some(self) + } + + fn encode_to_buf(&self, mut buf: &mut dyn BufMut) { + self.encode_raw(&mut buf) + } + + fn as_erased_serialize<'a>( + &'a self, + config: &'a prost::serde::SerializerConfig, + ) -> SmallBox { + smallbox!(prost::serde::private::SerWithConfig(self, config)) + } +} + +impl private::Sealed for T {} + +#[derive(Debug)] +enum Inner { + Protobuf(Vec), + Json(JsonValue), + Dyn(Box), +} + +#[derive(Debug)] +pub struct ProstAny { + type_url: String, + inner: Inner, + cached: RwLock>>, +} + +#[allow(clippy::declare_interior_mutable_const)] +const CACHED_INIT: RwLock>> = RwLock::new(None); + +impl Clone for ProstAny { + fn clone(&self) -> Self { + Self { + type_url: self.type_url.clone(), + inner: match &self.inner { + Inner::Protobuf(value) => Inner::Protobuf(value.clone()), + Inner::Json(value) => Inner::Json(value.clone()), + Inner::Dyn(value) => Inner::Dyn(value.clone_value()), + }, + cached: CACHED_INIT, + } + } +} + +impl PartialEq for ProstAny { + fn eq(&self, other: &Self) -> bool { + self.type_url == other.type_url + && match (&self.inner, &other.inner) { + (Inner::Protobuf(value_a), Inner::Protobuf(value_b)) => value_a == value_b, + (Inner::Json(value_a), Inner::Json(value_b)) => value_a == value_b, + (Inner::Dyn(value_a), Inner::Dyn(value_b)) => { + AnyValue::cmp_any(&**value_a, &**value_b) + } + _ => false, + } + } +} + +impl Default for ProstAny { + fn default() -> Self { + Self { + type_url: Default::default(), + inner: Inner::Protobuf(Default::default()), + cached: CACHED_INIT, + } + } +} + +impl Name for ProstAny { + const PACKAGE: &'static str = crate::PACKAGE; + const NAME: &'static str = "Any"; + + fn type_url() -> String { + crate::type_url_for::() + } +} + +impl ProstAny { + pub fn type_url(&self) -> &str { + &self.type_url + } + + pub fn set_type_url(&mut self, type_url: String) -> &mut Self { + self.type_url = type_url; + self + } + + pub fn any_value(&self) -> &dyn AnyValue { + self.opt_any_value() + .expect("any value has not been resolved yet") + } + + pub fn mut_any_value(&mut self) -> &mut dyn AnyValue { + self.opt_mut_any_value() + .expect("any value has not been resolved yet") + } + + pub fn opt_any_value(&self) -> Option<&dyn AnyValue> { + match &self.inner { + Inner::Dyn(value) => Some(&**value), + _ => None, + } + } + + pub fn opt_mut_any_value(&mut self) -> Option<&mut dyn AnyValue> { + match &mut self.inner { + Inner::Dyn(value) => Some(&mut **value), + _ => None, + } + } + + pub fn into_any_value(self) -> Box { + self.try_into_any_value() + .expect("any value has not been resolved yet") + } + + pub fn try_into_any_value(self) -> Result, Self> { + match self.inner { + Inner::Dyn(value) => Ok(value), + _ => Err(self), + } + } + + pub fn from_msg(msg: T) -> Self + where + T: 'static + Message + SerdeMessage + Name + PartialEq + Clone, + { + Self { + type_url: T::type_url(), + inner: Inner::Dyn(Box::new(msg) as _), + cached: CACHED_INIT, + } + } + + pub fn deserialize_any( + &self, + serde_config: Option<&prost::serde::DeserializerConfig>, + ) -> Result, prost::DecodeError> { + if let Inner::Dyn(value) = &self.inner { + return Ok(value.clone_value()); + } + + let type_descriptor = self.find_type_descriptor().ok_or_else(|| { + prost::DecodeError::new(format!("unresolved type url: {}", self.type_url())) + })?; + + let default_serde_config; + let serde_config = match serde_config { + Some(config) => config, + None => { + default_serde_config = Default::default(); + &default_serde_config + } + }; + + match &self.inner { + Inner::Protobuf(value) => (type_descriptor.deserialize_protobuf)(&self.type_url, value), + Inner::Json(value) => { + (type_descriptor.deserialize_json)(&self.type_url, value, serde_config) + } + Inner::Dyn(_) => unreachable!(), + } + } + + pub fn deserialize_any_in_place<'a>( + &'a mut self, + serde_config: Option<&prost::serde::DeserializerConfig>, + ) -> Result<&'a mut dyn AnyValue, prost::DecodeError> { + // This doesn't work due to + // https://rust-lang.github.io/rfcs/2094-nll.html#problem-case-3-conditional-control-flow-across-functions. + // + // if let Inner::Dyn(value) = &mut self.inner { + // return Ok(&mut **value); + // } + // + // So have to do this weird workaround instead: + let has_inner_value = matches!(&self.inner, Inner::Dyn(_)); + if !has_inner_value { + let value = self.deserialize_any(serde_config)?; + + self.inner = Inner::Dyn(value); + self.cached = CACHED_INIT; + } + + let Inner::Dyn(value) = &mut self.inner else { + unreachable!() + }; + + Ok(&mut **value) + } + + fn find_type_descriptor(&self) -> Option { + CURRENT_TYPE_RESOLVER.with(|type_resolver| { + let type_resolver = type_resolver.borrow(); + let type_resolver = type_resolver.as_ref()?; + Some( + type_resolver + .resolve_message_type(self.type_url()) + .ok()? + .clone(), + ) + }) + } + + fn deserialize_and_cache(&self, f: F) -> Result + where + F: FnOnce(&dyn AnyValue) -> R, + { + if let Inner::Dyn(value) = &self.inner { + return Ok(f(&**value)); + } + + if let Some(value) = &*self.cached.read().unwrap() { + return Ok(f(&**value)); + } + + let value = self.deserialize_any(None)?; + let res = f(&*value); + + if let Ok(mut cached) = self.cached.try_write() { + *cached = Some(value); + } + + Ok(res) + } +} + +impl Message for ProstAny { + fn encode_raw(&self, buf: &mut impl BufMut) + where + Self: Sized, + { + if !self.type_url.is_empty() { + prost::encoding::string::encode(1u32, &self.type_url, buf); + } + + match &self.inner { + Inner::Protobuf(value) => { + if !value.is_empty() { + prost::encoding::bytes::encode(2u32, value, buf); + } + } + Inner::Dyn(value) => { + let msg_len = value.as_message().encoded_len(); + if msg_len != 0 { + prost::encoding::encode_key( + 2u32, + prost::encoding::WireType::LengthDelimited, + buf, + ); + prost::encoding::encode_varint(msg_len as u64, buf); + value.encode_to_buf(buf); + } + } + Inner::Json(_) => { + let res = self.deserialize_and_cache(|value| { + prost::encoding::encode_key( + 2u32, + prost::encoding::WireType::LengthDelimited, + buf, + ); + prost::encoding::encode_varint(value.as_message().encoded_len() as u64, buf); + value.encode_to_buf(buf); + }); + if let Err(err) = res { + panic!("unresolved any value: {}", err) + } + } + }; + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: prost::encoding::WireType, + buf: &mut impl Buf, + ctx: prost::encoding::DecodeContext, + ) -> Result<(), prost::DecodeError> + where + Self: Sized, + { + match tag { + 1u32 => { + let value = &mut self.type_url; + ::prost::encoding::string::merge(wire_type, value, buf, ctx).map_err(|mut error| { + error.push("Any", "type_url"); + error + }) + } + 2u32 => { + let value = match &mut self.inner { + Inner::Protobuf(value) => value, + inner => { + *inner = Inner::Protobuf(Default::default()); + let Inner::Protobuf(value) = inner else { + unreachable!() + }; + value + } + }; + ::prost::encoding::bytes::merge(wire_type, value, buf, ctx).map_err(|mut error| { + error.push("Any", "value"); + error + }) + } + _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx), + } + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + + if !self.type_url.is_empty() { + len += prost::encoding::string::encoded_len(1u32, &self.type_url); + } + + match &self.inner { + Inner::Protobuf(value) => { + if !value.is_empty() { + len += prost::encoding::bytes::encoded_len(2u32, value); + } + } + Inner::Dyn(value) => { + let msg = value.as_message(); + if msg.encoded_len() != 0 { + len += prost::encoding::message::encoded_len(2u32, msg); + } + } + Inner::Json(_) => { + let res = self.deserialize_and_cache(|value| { + len += prost::encoding::message::encoded_len(2u32, value.as_message()); + }); + if let Err(err) = res { + panic!("unresolved any value: {}", err) + } + } + } + + len + } + + fn clear(&mut self) { + self.type_url.clear(); + match &mut self.inner { + Inner::Protobuf(value) => value.clear(), + Inner::Dyn(value) => value.as_mut_message().clear(), + Inner::Json(_) => { + panic!("cannot clear unresolved type") + } + } + } +} + +impl prost::serde::private::CustomSerialize for ProstAny { + fn serialize( + &self, + serializer: S, + config: &prost::serde::SerializerConfig, + ) -> Result + where + S: _serde::Serializer, + { + let is_well_known_type = has_known_value_json_mapping(&self.type_url); + + #[derive(Debug, _serde::Serialize)] + struct Flattened<'a, T: ?Sized> { + #[serde(rename = "@type")] + type_url: &'a str, + #[serde(flatten)] + value: &'a T, + } + + #[derive(Debug, _serde::Serialize)] + struct Wrapped<'a, T: ?Sized> { + #[serde(rename = "@type")] + type_url: &'a str, + value: &'a T, + } + + match &self.inner { + Inner::Json(value) => { + if is_well_known_type { + _serde::Serialize::serialize( + &Wrapped { + type_url: &self.type_url, + value, + }, + serializer, + ) + } else { + _serde::Serialize::serialize( + &Flattened { + type_url: &self.type_url, + value, + }, + serializer, + ) + } + } + Inner::Dyn(value) => { + let value = &*value.as_erased_serialize(config); + if is_well_known_type { + erased_serde::serialize( + &Wrapped { + type_url: &self.type_url, + value, + }, + serializer, + ) + } else { + erased_serde::serialize( + &Flattened { + type_url: &self.type_url, + value, + }, + serializer, + ) + } + } + Inner::Protobuf(_) => match self.deserialize_any(None) { + Ok(value) => { + let value = &*value.as_erased_serialize(config); + if is_well_known_type { + erased_serde::serialize( + &Wrapped { + type_url: &self.type_url, + value, + }, + serializer, + ) + } else { + erased_serde::serialize( + &Flattened { + type_url: &self.type_url, + value, + }, + serializer, + ) + } + } + Err(err) => Err(_serde::ser::Error::custom(format!( + "failed to decode any value: {}", + err + ))), + }, + } + } +} + +impl<'de> prost::serde::private::CustomDeserialize<'de> for ProstAny { + fn deserialize( + deserializer: D, + config: &prost::serde::DeserializerConfig, + ) -> Result + where + D: _serde::Deserializer<'de>, + { + use _serde::de::{Error, Unexpected}; + + let val = ::deserialize(deserializer)?; + + let JsonValue::Object(mut obj) = val else { + return Err(D::Error::invalid_type( + Unexpected::Other("non-object value"), + &"object value", + )); + }; + let Some(JsonValue::String(type_url)) = obj.remove("@type") else { + return Err(D::Error::missing_field("@type")); + }; + + let obj = if has_known_value_json_mapping(&type_url) { + let Some(value) = obj.remove("value") else { + return Err(D::Error::missing_field("value")); + }; + + if !config.ignore_unknown_fields && !obj.is_empty() { + let unknown_key = obj + .keys() + .next() + .map(|key| key.as_str()) + .unwrap_or(""); + return Err(D::Error::unknown_field(unknown_key, &["@type", "value"])); + } + + value + } else { + JsonValue::Object(obj) + }; + + let mut res = Self { + type_url, + inner: Inner::Json(obj), + cached: CACHED_INIT, + }; + + if has_type_resolver_set() { + // Gracefully fail here and leave the `Self:;Json` variant in place. + let _ = res.deserialize_any_in_place(Some(config)); + } + + Ok(res) + } +} + +#[derive(Debug, Clone)] +pub struct TypeRegistry { + message_types: HashMap, +} + +impl TypeRegistry { + pub fn new() -> Self { + Self { + message_types: HashMap::new(), + } + } + + pub fn new_with_well_known_types() -> Self { + let mut registry = Self::new(); + registry.insert_default_well_known_types(); + registry + } + + pub fn insert_default_well_known_types(&mut self) { + self.insert_well_known_msg_type::("google.protobuf.Any"); + self.insert_well_known_msg_type::("google.protobuf.Timestamp"); + self.insert_well_known_msg_type::("google.protobuf.Duration"); + self.insert_well_known_msg_type::("google.protobuf.Struct"); + self.insert_well_known_msg_type::("google.protobuf.DoubleValue"); + self.insert_well_known_msg_type::("google.protobuf.FloatValue"); + self.insert_well_known_msg_type::("google.protobuf.Int64Value"); + self.insert_well_known_msg_type::("google.protobuf.UInt64Value"); + self.insert_well_known_msg_type::("google.protobuf.Int32Value"); + self.insert_well_known_msg_type::("google.protobuf.UInt32Value"); + self.insert_well_known_msg_type::("google.protobuf.BoolValue"); + self.insert_well_known_msg_type::("google.protobuf.StringValue"); + self.insert_well_known_msg_type::>("google.protobuf.BytesValue"); + self.insert_well_known_msg_type::>("google.protobuf.BytesValue"); + self.insert_well_known_msg_type::("google.protobuf.FieldMask"); + self.insert_well_known_msg_type::("google.protobuf.ListValue"); + self.insert_well_known_msg_type::("google.protobuf.Value"); + self.insert_well_known_msg_type::<()>("google.protobuf.Empty"); + } + + fn insert_well_known_msg_type(&mut self, type_path: &str) + where + T: 'static + Message + SerdeMessage + Default + PartialEq + Clone, + { + let _ = self.message_types.insert( + format!("type.googleapis.com/{type_path}"), + AnyTypeDescriptor::for_type::(), + ); + } + + pub fn insert_msg_type(&mut self) + where + T: 'static + Message + SerdeMessage + Name + Default + PartialEq + Clone, + { + let _ = self + .message_types + .insert(T::type_url(), AnyTypeDescriptor::for_type::()); + } + + pub fn insert_msg_type_for_type_url(&mut self, type_url: impl Into) + where + T: 'static + Message + SerdeMessage + Default + PartialEq + Clone, + { + let _ = self + .message_types + .insert(type_url.into(), AnyTypeDescriptor::for_type::()); + } + + pub fn remove_by_type_url(&mut self, type_url: &str) -> bool { + self.message_types.remove(type_url).is_some() + } + + pub fn into_type_resolver(self) -> Arc { + Arc::new(self) as _ + } +} + +impl Default for TypeRegistry { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl TypeResolver for TypeRegistry { + fn resolve_message_type<'a>( + &'a self, + type_url: &str, + ) -> Result<&'a AnyTypeDescriptor, TypeResolverError> { + self.message_types.get(type_url).ok_or(TypeResolverError) + } +} + +type DeserializeProtobufFn = fn(&str, &[u8]) -> Result, prost::DecodeError>; + +type DeserializeJsonFn = fn( + &str, + &JsonValue, + &prost::serde::DeserializerConfig, +) -> Result, prost::DecodeError>; + +#[derive(Debug, Clone)] +pub struct AnyTypeDescriptor { + deserialize_protobuf: DeserializeProtobufFn, + deserialize_json: DeserializeJsonFn, +} + +impl AnyTypeDescriptor { + pub fn for_type() -> Self + where + T: 'static + Message + SerdeMessage + Default + PartialEq + Clone, + { + fn deserialize_protobuf< + T: 'static + Message + SerdeMessage + Default + PartialEq + Clone, + >( + _type_url: &str, + data: &[u8], + ) -> Result, prost::DecodeError> { + Ok(Box::new(T::decode(data)?) as _) + } + + fn deserialize_json( + _type_url: &str, + val: &JsonValue, + config: &prost::serde::DeserializerConfig, + ) -> Result, prost::DecodeError> { + let val = config + .deserialize_from_value::(val) + .map_err(|err| prost::DecodeError::new(err.to_string()))?; + Ok(Box::new(val) as _) + } + + Self { + deserialize_protobuf: deserialize_protobuf::, + deserialize_json: deserialize_json::, + } + } +} + +pub fn default_type_resolver() -> Arc { + static DEFAULT_REGISTRY: OnceLock> = OnceLock::new(); + DEFAULT_REGISTRY + .get_or_init(|| Arc::new(TypeRegistry::new_with_well_known_types())) + .clone() +} + +#[derive(Debug)] +pub struct TypeResolverError; + +pub trait TypeResolver { + fn resolve_message_type<'a>( + &'a self, + type_url: &str, + ) -> Result<&'a AnyTypeDescriptor, TypeResolverError>; +} + +thread_local! { + static CURRENT_TYPE_RESOLVER: RefCell>> = RefCell::new(None); +} + +pub fn with_type_resolver(resolver: Option>, f: F) -> R +where + F: FnOnce() -> R, +{ + struct TypeResolverGuard(Option>); + impl Drop for TypeResolverGuard { + fn drop(&mut self) { + CURRENT_TYPE_RESOLVER.with(|current| *current.borrow_mut() = self.0.take()); + } + } + let _guard = TypeResolverGuard(CURRENT_TYPE_RESOLVER.with(|current| current.replace(resolver))); + f() +} + +pub fn with_default_type_resolver R>(f: F) -> R { + with_type_resolver(Some(default_type_resolver()), f) +} + +fn has_type_resolver_set() -> bool { + CURRENT_TYPE_RESOLVER.with(|type_resolver| type_resolver.borrow().is_some()) +} + +fn has_known_value_json_mapping(type_url: &str) -> bool { + let Some(path) = type_url.strip_prefix("type.googleapis.com/") else { + return false; + }; + + const KNOWN_PATHS: &[&str] = &[ + "google.protobuf.Any", + "google.protobuf.Timestamp", + "google.protobuf.Duration", + "google.protobuf.Struct", + "google.protobuf.DoubleValue", + "google.protobuf.FloatValue", + "google.protobuf.Int64Value", + "google.protobuf.UInt64Value", + "google.protobuf.Int32Value", + "google.protobuf.UInt32Value", + "google.protobuf.BoolValue", + "google.protobuf.StringValue", + "google.protobuf.BytesValue", + "google.protobuf.FieldMask", + "google.protobuf.ListValue", + "google.protobuf.Value", + "google.protobuf.Empty", + ]; + + KNOWN_PATHS.contains(&path) +} diff --git a/prost-types/src/datetime.rs b/prost-types/src/datetime.rs index 4c9467753..84c3e49bb 100644 --- a/prost-types/src/datetime.rs +++ b/prost-types/src/datetime.rs @@ -62,6 +62,12 @@ impl DateTime { && self.second < 60 && self.nanos < 1_000_000_000 } + + /// Returns `true` if the `DateTime` is a valid calendar date that also has a valid RFC3339 + /// representation. + pub(crate) fn is_rfc3339_valid(&self) -> bool { + self.is_valid() && self.year > 0 && self.year < 10000 + } } impl fmt::Display for DateTime { @@ -276,19 +282,29 @@ fn parse_nanos(s: &str) -> Option<(u32, &str)> { /// Parses a timezone offset in RFC 3339 format from ASCII string `s`, returning the offset hour, /// offset minute, and remaining input. -fn parse_offset(s: &str) -> Option<(i8, i8, &str)> { +fn parse_offset(s: &str, json_mode: bool) -> Option<(i8, i8, &str)> { debug_assert!(s.is_ascii()); - if s.is_empty() { - // If no timezone specified, assume UTC. - return Some((0, 0, s)); - } + let (s, z) = if json_mode { + if s.is_empty() { + return None; + } + + (s, parse_char(s, b'Z')) + } else { + if s.is_empty() { + // If no timezone specified, assume UTC. + return Some((0, 0, s)); + } - // Snowflake's timestamp format contains a space separator before the offset. - let s = parse_char(s, b' ').unwrap_or(s); + // Snowflake's timestamp format contains a space separator before the offset. + let s = parse_char(s, b' ').unwrap_or(s); - if let Some(s) = parse_char_ignore_case(s, b'Z') { - Some((0, 0, s)) + (s, parse_char_ignore_case(s, b'Z')) + }; + + if let Some(z) = z { + Some((0, 0, z)) } else { let (is_positive, s) = if let Some(s) = parse_char(s, b'+') { (true, s) @@ -487,7 +503,7 @@ pub(crate) fn year_to_seconds(year: i64) -> (i128, bool) { } /// Parses a timestamp in RFC 3339 format from `s`. -pub(crate) fn parse_timestamp(s: &str) -> Option { +pub(crate) fn parse_timestamp(s: &str, json_mode: bool) -> Option { // Check that the string is ASCII, since subsequent parsing steps use byte-level indexing. ensure!(s.is_ascii()); @@ -502,13 +518,22 @@ pub(crate) fn parse_timestamp(s: &str) -> Option { ..DateTime::default() }; + if json_mode { + ensure!(date_time.is_rfc3339_valid()); + } + return Timestamp::try_from(date_time).ok(); } - // Accept either 'T' or ' ' as delimiter between date and time. - let s = parse_char_ignore_case(s, b'T').or_else(|| parse_char(s, b' '))?; + let s = if json_mode { + // Only accept 'T' when parsing in json mode. + parse_char(s, b'T')? + } else { + // Accept either 'T' or ' ' as delimiter between date and time. + parse_char_ignore_case(s, b'T').or_else(|| parse_char(s, b' '))? + }; let (hour, minute, mut second, nanos, s) = parse_time(s)?; - let (offset_hour, offset_minute, s) = parse_offset(s)?; + let (offset_hour, offset_minute, s) = parse_offset(s, json_mode)?; ensure!(s.is_empty()); @@ -532,6 +557,10 @@ pub(crate) fn parse_timestamp(s: &str) -> Option { nanos, }; + if json_mode { + ensure!(date_time.is_rfc3339_valid()); + } + let Timestamp { seconds, nanos } = Timestamp::try_from(date_time).ok()?; let seconds = diff --git a/prost-types/src/duration.rs b/prost-types/src/duration.rs index 3ce993ee5..b36a636cf 100644 --- a/prost-types/src/duration.rs +++ b/prost-types/src/duration.rs @@ -69,6 +69,10 @@ impl Duration { result.normalize(); result } + + pub fn is_valid(&self) -> bool { + self.seconds >= -315_576_000_000 && self.seconds <= 315_576_000_000 + } } impl Name for Duration { @@ -116,7 +120,7 @@ impl TryFrom for time::Duration { impl fmt::Display for Duration { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let d = self.normalized(); - if self.seconds < 0 || self.nanos < 0 { + if self.seconds < 0 && self.nanos <= 0 || self.seconds <= 0 && self.nanos < 0 { write!(f, "-")?; } write!(f, "{}", d.seconds.abs())?; diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index a0507426b..eccbc7a84 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -9,24 +9,27 @@ //! //! ## Any //! -//! The well-known [`Any`] type contains an arbitrary serialized message along with a URL that -//! describes the type of the serialized message. Every message that also implements [`Name`] -//! can be serialized to and deserialized from [`Any`]. +//! The well-known [`Any`](protobuf::Any) type contains an arbitrary serialized message along +//! with a URL that describes the type of the serialized message. +//! Every message that also implements [`Name`] can be serialized to and deserialized +//! from [`Any`](protobuf::Any). //! //! ### Serialization //! -//! A message can be serialized using [`Any::from_msg`]. +//! A message can be serialized using [`Any::from_msg`](protobuf::Any::from_msg). //! //! ```rust +//! # use crate::protobuf::Any; //! let message = Timestamp::date(2000, 1, 1).unwrap(); //! let any = Any::from_msg(&message).unwrap(); //! ``` //! //! ### Deserialization //! -//! A message can be deserialized using [`Any::to_msg`]. +//! A message can be deserialized using [`Any::to_msg`](protobuf::Any::to_msg). //! //! ```rust +//! # use crate::protobuf::Any; //! # let message = Timestamp::date(2000, 1, 1).unwrap(); //! # let any = Any::from_msg(&message).unwrap(); //! # @@ -45,6 +48,13 @@ pub mod compiler; mod datetime; #[rustfmt::skip] mod protobuf; +#[cfg(feature = "any-v2")] +pub mod any_v2; +#[cfg(feature = "serde")] +#[doc(hidden)] +pub mod serde; +#[cfg(feature = "any-v2")] +mod smallbox; use core::convert::TryFrom; use core::fmt; @@ -58,6 +68,11 @@ use prost::{DecodeError, EncodeError, Message, Name}; pub use protobuf::*; +#[cfg(feature = "any-v2")] +pub use any_v2::ProstAny as Any; +#[cfg(feature = "any-v2")] +pub use protobuf::Any as AnyV1; + // The Protobuf `Duration` and `Timestamp` types can't delegate to the standard library equivalents // because the Protobuf versions are signed. To make them easier to work with, `From` conversions // are defined in both directions. diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs new file mode 100644 index 000000000..b4a1471b0 --- /dev/null +++ b/prost-types/src/serde.rs @@ -0,0 +1,527 @@ +use core::fmt; + +use prost::alloc::{borrow::ToOwned, collections::BTreeMap, format, string::String, vec, vec::Vec}; +use prost::serde::{ + de::{CustomDeserialize, DesWithConfig}, + private::{self, DeserializeEnum, _serde}, + ser::{CustomSerialize, SerWithConfig}, + DeserializerConfig, SerializerConfig, +}; + +use crate::{value, Duration, FieldMask, ListValue, NullValue, Struct, Timestamp, Value}; + +impl CustomSerialize for NullValue { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.serialize_none() + } +} + +impl DeserializeEnum for NullValue { + #[inline] + fn deserialize_from_i32(val: i32) -> Result, E> + where + E: _serde::de::Error, + { + Err(E::invalid_value( + _serde::de::Unexpected::Signed(val.into()), + &"a null value", + )) + } + + #[inline] + fn deserialize_from_str(val: &str) -> Result, E> + where + E: _serde::de::Error, + { + if val == "NULL_VALUE" { + Ok(Some(Self::NullValue)) + } else { + Err(E::invalid_value( + _serde::de::Unexpected::Str(val), + &"a null value", + )) + } + } + + #[inline] + fn deserialize_from_null() -> Result + where + E: _serde::de::Error, + { + Ok(Self::NullValue) + } + + #[inline] + fn can_deserialize_null() -> bool { + true + } +} + +impl CustomSerialize for Duration { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + if !self.is_valid() { + return Err(::custom(format!( + "duration is invalid: d={:?}", + self + ))); + } + private::SerAsDisplay(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for Duration { + #[inline] + fn deserialize(deserializer: D, _config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor; + + impl _serde::de::Visitor<'_> for Visitor { + type Value = Duration; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a duration string") + } + + fn visit_str(self, v: &str) -> Result + where + E: _serde::de::Error, + { + match v.parse::() { + Ok(val) if val.is_valid() => Ok(val), + Ok(_) | Err(_) => Err(E::invalid_value( + _serde::de::Unexpected::Str(v), + &"a valid duration string", + )), + } + } + } + + deserializer.deserialize_str(Visitor) + } +} + +impl CustomSerialize for Timestamp { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerAsDisplay(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for Timestamp { + #[inline] + fn deserialize(deserializer: D, _config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor; + + impl _serde::de::Visitor<'_> for Visitor { + type Value = Timestamp; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a timestamp string") + } + + fn visit_str(self, v: &str) -> Result + where + E: _serde::de::Error, + { + Timestamp::from_json_str(v).map_err(|_| { + E::invalid_value(_serde::de::Unexpected::Str(v), &"a valid timestamp string") + }) + } + } + + deserializer.deserialize_str(Visitor) + } +} + +impl CustomSerialize for FieldMask { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + let mut buf = String::with_capacity(self.paths.iter().map(|path| path.len()).sum()); + let mut paths = self.paths.iter().peekable(); + + while let Some(path) = paths.next() { + let mut path_chars = path.chars().peekable(); + + while let Some(chr) = path_chars.next() { + match chr { + 'A'..='Z' => { + return Err(::custom( + "field mask element may not have upper-case letters", + )) + } + '_' => { + let Some(next_chr) = + path_chars.next().filter(|chr| chr.is_ascii_lowercase()) + else { + return Err(::custom( + "underscore in field mask element must be followed by lower-case letter", + )); + }; + buf.push(next_chr.to_ascii_uppercase()); + } + _ => buf.push(chr), + } + } + + if paths.peek().is_some() { + buf.push(','); + } + } + + serializer.serialize_str(&buf) + } +} + +impl<'de> CustomDeserialize<'de> for FieldMask { + #[inline] + fn deserialize(deserializer: D, _config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor; + + impl _serde::de::Visitor<'_> for Visitor { + type Value = FieldMask; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a fieldmask string") + } + + fn visit_str(self, val: &str) -> Result + where + E: _serde::de::Error, + { + fn convert_path(path: &str) -> Result { + let underscores_required = + path.chars().filter(|chr| chr.is_ascii_uppercase()).count(); + + let mut buf = String::with_capacity(path.len() + underscores_required); + + for chr in path.chars() { + match chr { + 'A'..='Z' => { + buf.push('_'); + buf.push(chr.to_ascii_lowercase()); + } + '_' => return Err("field mask element may not contain underscores"), + 'a'..='z' | '0'..='9' => buf.push(chr), + _ => { + return Err( + "field mask element may not contain non ascii alphabetic letters or digits", + ) + } + } + } + + Ok(buf) + } + + let paths = val + .split(',') + .map(|path| path.trim()) + .filter(|path| !path.is_empty()) + .map(convert_path) + .collect::, _>>() + .map_err(|err| { + E::invalid_value( + _serde::de::Unexpected::Str(val), + &&*format!("a valid fieldmask string ({err})"), + ) + })?; + + Ok(FieldMask { paths }) + } + } + + deserializer.deserialize_str(Visitor) + } +} + +impl CustomSerialize for Struct { + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + use _serde::ser::SerializeMap; + + let mut map = serializer.serialize_map(Some(self.fields.len()))?; + for (key, value) in &self.fields { + map.serialize_entry(key, &SerWithConfig(value, config))?; + } + map.end() + } +} + +impl<'de> CustomDeserialize<'de> for Struct { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor<'c>(&'c DeserializerConfig); + + impl<'de> _serde::de::Visitor<'de> for Visitor<'_> { + type Value = Struct; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Struct") + } + + #[inline] + fn visit_map(self, map: A) -> Result + where + A: _serde::de::MapAccess<'de>, + { + deserialize_struct(map, self.0) + } + } + + deserializer.deserialize_map(Visitor(config)) + } +} + +impl CustomSerialize for crate::protobuf::Any { + fn serialize(&self, _serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + panic!("serializing the old prost::Any is not supported") + } +} + +impl<'de> CustomDeserialize<'de> for crate::protobuf::Any { + #[inline] + fn deserialize(_deserializer: D, _config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + panic!("deserializing the old prost::Any is not supported") + } +} + +impl CustomSerialize for Value { + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + match self.kind.as_ref() { + Some(value::Kind::NullValue(_)) | None => serializer.serialize_none(), + Some(value::Kind::NumberValue(val)) => { + if val.is_nan() || val.is_infinite() { + return Err(_serde::ser::Error::custom(format!( + "serializing a value::Kind::NumberValue, which is {val}, is not possible" + ))); + } + serializer.serialize_f64(*val) + } + Some(value::Kind::StringValue(val)) => serializer.serialize_str(val), + Some(value::Kind::BoolValue(val)) => serializer.serialize_bool(*val), + Some(value::Kind::StructValue(val)) => { + CustomSerialize::serialize(val, serializer, config) + } + Some(value::Kind::ListValue(val)) => { + CustomSerialize::serialize(val, serializer, config) + } + } + } +} + +impl<'de> CustomDeserialize<'de> for Value { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor<'c>(&'c DeserializerConfig); + + impl<'de> _serde::de::Visitor<'de> for Visitor<'_> { + type Value = Value; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Value") + } + + #[inline] + fn visit_none(self) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::NullValue(0)), + }) + } + + #[inline] + fn visit_unit(self) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::NullValue(0)), + }) + } + + #[inline] + fn visit_i64(self, v: i64) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::NumberValue(v as f64)), + }) + } + + #[inline] + fn visit_u64(self, v: u64) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::NumberValue(v as f64)), + }) + } + + #[inline] + fn visit_f64(self, v: f64) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::NumberValue(v)), + }) + } + + #[inline] + fn visit_str(self, v: &str) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::StringValue(v.to_owned())), + }) + } + + #[inline] + fn visit_bool(self, v: bool) -> Result + where + E: _serde::de::Error, + { + Ok(Value { + kind: Some(value::Kind::BoolValue(v)), + }) + } + + #[inline] + fn visit_map(self, map: A) -> Result + where + A: _serde::de::MapAccess<'de>, + { + let value = deserialize_struct(map, self.0)?; + Ok(Value { + kind: Some(value::Kind::StructValue(value)), + }) + } + + #[inline] + fn visit_seq(self, seq: A) -> Result + where + A: _serde::de::SeqAccess<'de>, + { + let value = deserialize_list_value(seq, self.0)?; + Ok(Value { + kind: Some(value::Kind::ListValue(value)), + }) + } + } + + deserializer.deserialize_any(Visitor(config)) + } + + #[inline] + fn can_deserialize_null() -> bool { + true + } +} + +impl CustomSerialize for ListValue { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.collect_seq(self.values.iter().map(|value| SerWithConfig(value, config))) + } +} + +impl<'de> CustomDeserialize<'de> for ListValue { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor<'c>(&'c DeserializerConfig); + + impl<'de> _serde::de::Visitor<'de> for Visitor<'_> { + type Value = ListValue; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a ListValue") + } + + #[inline] + fn visit_seq(self, seq: A) -> Result + where + A: _serde::de::SeqAccess<'de>, + { + deserialize_list_value(seq, self.0) + } + } + + deserializer.deserialize_seq(Visitor(config)) + } +} + +fn deserialize_list_value<'de, A>( + mut seq: A, + config: &DeserializerConfig, +) -> Result +where + A: _serde::de::SeqAccess<'de>, +{ + let mut values = vec![]; + while let Some(value) = seq.next_element_seed(DesWithConfig::::new(config))? { + values.push(value); + } + Ok(ListValue { values }) +} + +fn deserialize_struct<'de, A>(mut map: A, config: &DeserializerConfig) -> Result +where + A: _serde::de::MapAccess<'de>, +{ + let mut fields = BTreeMap::new(); + while let Some(key) = map.next_key::()? { + let value = map.next_value_seed(DesWithConfig::::new(config))?; + fields.insert(key, value); + } + Ok(Struct { fields }) +} diff --git a/prost-types/src/smallbox.rs b/prost-types/src/smallbox.rs new file mode 100644 index 000000000..9b39726b7 --- /dev/null +++ b/prost-types/src/smallbox.rs @@ -0,0 +1,139 @@ +use core::{ + alloc::Layout, + marker::PhantomData, + mem::{self, ManuallyDrop, MaybeUninit}, + ops::{Deref, DerefMut}, + ptr, +}; + +const SMALLBOX_CAP: usize = 3; + +type Storage = [usize; SMALLBOX_CAP]; + +pub struct SmallBox { + storage: MaybeUninit, + vptr: *const (), + _marker: PhantomData<*mut T>, +} + +struct Validate(PhantomData); + +impl Validate { + const IS_VALID: bool = { + assert!(mem::size_of::() <= mem::size_of::()); + assert!(mem::align_of::() == mem::align_of::()); + true + }; +} + +const fn has_same_layout() -> bool { + let lhs = Layout::new::(); + let rhs = Layout::new::(); + lhs.align() == rhs.align() && lhs.size() == rhs.size() +} + +macro_rules! smallbox { + ($val:expr) => {{ + let val = $val; + let ptr = &val as *const _; + #[allow(unsafe_code)] + unsafe { + $crate::smallbox::SmallBox::from_parts(val, ptr) + } + }}; +} + +pub(crate) use smallbox; + +impl SmallBox { + const IS_NON_DST_PTR: bool = has_same_layout::<*const T, *const ()>(); + + const IS_FAT_PTR: bool = has_same_layout::<*const T, [*const (); 2]>(); + + const IS_VALID: bool = { + assert!(Self::IS_NON_DST_PTR || Self::IS_FAT_PTR); + true + }; + + pub unsafe fn from_parts(val: S, obj: *const T) -> Self + where + S: Unpin, /* + Unsize */ + { + assert!(Validate::::IS_VALID); + assert!(Self::IS_VALID); + + let mut storage: MaybeUninit = MaybeUninit::uninit(); + ptr::write(storage.as_mut_ptr().cast::(), val); + + let mut vptr = ptr::null(); + if Self::IS_FAT_PTR { + vptr = *(&obj as *const *const T as *const *const ()).add(1); + } + + Self { + storage, + vptr, + _marker: PhantomData, + } + } + + pub fn as_mut_ptr(&mut self) -> *mut T { + let mut ptr: MaybeUninit<*mut T> = MaybeUninit::uninit(); + + unsafe { + let base = ptr.as_mut_ptr().cast::<*mut ()>(); + *base = self.storage.as_mut_ptr().cast::<()>(); + + if Self::IS_FAT_PTR { + *base.add(1).cast::<*const ()>() = self.vptr; + } + } + + unsafe { ptr.assume_init() } + } + + pub fn as_ptr(&self) -> *const T { + let mut ptr: MaybeUninit<*const T> = MaybeUninit::uninit(); + + unsafe { + let base = ptr.as_mut_ptr().cast::<*const ()>(); + *base = self.storage.as_ptr().cast::<()>(); + + if Self::IS_FAT_PTR { + *base.add(1).cast::<*const ()>() = self.vptr; + } + } + + unsafe { ptr.assume_init() } + } + + pub fn into_inner(this: Self) -> T + where + T: Sized, + { + let mut val = ManuallyDrop::new(this); + unsafe { ptr::read(val.as_mut_ptr()) } + } +} + +impl Drop for SmallBox { + fn drop(&mut self) { + unsafe { + ptr::drop_in_place(self.as_mut_ptr()); + } + } +} + +impl Deref for SmallBox { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.as_ptr() } + } +} + +impl DerefMut for SmallBox { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.as_mut_ptr() } + } +} diff --git a/prost-types/src/timestamp.rs b/prost-types/src/timestamp.rs index 1d7e609f4..ce10e55f0 100644 --- a/prost-types/src/timestamp.rs +++ b/prost-types/src/timestamp.rs @@ -112,6 +112,11 @@ impl Timestamp { Timestamp::try_from(date_time) } + + #[cfg(feature = "serde")] + pub(crate) fn from_json_str(s: &str) -> Result { + datetime::parse_timestamp(s, true).ok_or(TimestampError::ParseFailure) + } } impl Name for Timestamp { @@ -233,7 +238,7 @@ impl FromStr for Timestamp { type Err = TimestampError; fn from_str(s: &str) -> Result { - datetime::parse_timestamp(s).ok_or(TimestampError::ParseFailure) + datetime::parse_timestamp(s, false).ok_or(TimestampError::ParseFailure) } } diff --git a/prost/Cargo.toml b/prost/Cargo.toml index efd6d8cdd..288a584dd 100644 --- a/prost/Cargo.toml +++ b/prost/Cargo.toml @@ -22,10 +22,17 @@ prost-derive = ["derive"] # deprecated, please use derive feature instead no-recursion-limit = [] std = [] +serde = ["dep:serde", "dep:base64"] +serde-json = ["std", "serde", "dep:serde_json"] + [dependencies] bytes = { version = "1", default-features = false } prost-derive = { version = "0.13.3", path = "../prost-derive", optional = true } +base64 = { version = "0.21.4", default-features = false, features = ["alloc"], optional = true } +serde = { version = "1.0.189", default-features = false, features = ["alloc"], optional = true } +serde_json = { version = "1.0.107", features = ["float_roundtrip"], optional = true } + [dev-dependencies] criterion = { version = "0.5", default-features = false } proptest = "1" diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index e12455574..6abefac4d 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -844,7 +844,7 @@ pub mod message { #[inline] pub fn encoded_len(tag: u32, msg: &M) -> usize where - M: Message, + M: ?Sized + Message, { let len = msg.encoded_len(); key_len(tag) + encoded_len_varint(len as u64) + len diff --git a/prost/src/lib.rs b/prost/src/lib.rs index efdfbc5c1..1e0485834 100644 --- a/prost/src/lib.rs +++ b/prost/src/lib.rs @@ -14,6 +14,9 @@ mod message; mod name; mod types; +#[cfg(feature = "serde")] +pub mod serde; + #[doc(hidden)] pub mod encoding; diff --git a/prost/src/serde/de.rs b/prost/src/serde/de.rs new file mode 100644 index 000000000..97b5a75a9 --- /dev/null +++ b/prost/src/serde/de.rs @@ -0,0 +1,234 @@ +use alloc::{boxed::Box, format}; +use core::marker::PhantomData; +use serde::{de::DeserializeSeed, Deserializer}; + +use super::DeserializerConfig; + +mod bytes; +mod default; +mod r#enum; +mod forward; +mod map; +mod message; +mod oneof; +mod option; +mod scalar; +mod vec; + +/// This is an extended and cut-down version of serde's [serde::Deserialize]. +/// +/// The main changes are: +/// - the addition of an additional argument `config` ([DeserializerConfig]). Deserializers can +/// use that to change their deserialization behavior. +/// - the `can_deserialize_null` method. +/// +pub trait CustomDeserialize<'de>: Sized { + /// Deserialize `Self` from the given `deserializer` and `config`. + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: serde::Deserializer<'de>; + + /// By default this impl doesn't support deserializing from `null` values. + #[inline] + fn can_deserialize_null() -> bool { + false + } +} + +impl<'de, T> CustomDeserialize<'de> for Box +where + T: CustomDeserialize<'de>, +{ + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: serde::Deserializer<'de>, + { + let val = ::deserialize(deserializer, config)?; + Ok(Box::new(val)) + } +} + +// FIXME: Make `T` contravariant, not covariant, by changing the `T` in `PhantomData` to +// `fn() -> T`. +pub struct DesWithConfig<'c, T>(pub &'c DeserializerConfig, PhantomData T>); + +impl<'c, T> DesWithConfig<'c, T> { + #[inline] + pub fn new(config: &'c DeserializerConfig) -> Self { + Self(config, PhantomData) + } +} + +impl<'de, T> serde::de::DeserializeSeed<'de> for DesWithConfig<'_, T> +where + T: CustomDeserialize<'de>, +{ + type Value = T; + + #[inline] + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + >::deserialize(deserializer, self.0) + } +} + +#[derive(Debug)] +pub enum MaybeDeserializedValue { + Val(T), + UnknownEnumValue, +} + +impl MaybeDeserializedValue { + #[inline] + pub fn map(self, f: impl FnOnce(T) -> R) -> MaybeDeserializedValue { + match self { + Self::Val(val) => MaybeDeserializedValue::Val(f(val)), + Self::UnknownEnumValue => MaybeDeserializedValue::UnknownEnumValue, + } + } + + #[inline] + pub fn unwrap_for_field( + self, + config: &DeserializerConfig, + field_name: &'static str, + ) -> Result + where + E: serde::de::Error, + T: Default, + { + match self { + Self::Val(val) => Ok(val), + Self::UnknownEnumValue if config.ignore_unknown_enum_string_values => Ok(T::default()), + Self::UnknownEnumValue => Err(E::custom(format!( + "found an unknown enum value at field `{field_name}`" + ))), + } + } + + #[inline] + pub fn unwrap_for_omittable( + self, + config: &DeserializerConfig, + location: &'static str, + ) -> Result, E> + where + E: serde::de::Error, + { + match self { + Self::Val(val) => Ok(Some(val)), + Self::UnknownEnumValue if config.ignore_unknown_enum_string_values => Ok(None), + Self::UnknownEnumValue => Err(E::custom(format!( + "found an unknown enum value `{location}`" + ))), + } + } +} + +impl From for MaybeDeserializedValue { + #[inline] + fn from(val: T) -> Self { + Self::Val(val) + } +} + +pub trait DeserializeInto { + fn deserialize_into<'de, D: Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result; + + fn maybe_deserialize_into<'de, D: Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + Self::deserialize_into(deserializer, config).map(MaybeDeserializedValue::Val) + } + + #[inline] + fn can_deserialize_null() -> bool { + false + } +} + +pub struct DesIntoWithConfig<'c, W, T>(pub &'c DeserializerConfig, PhantomData<(W, T)>); + +impl<'c, W, T> DesIntoWithConfig<'c, W, T> { + #[inline] + pub fn new(config: &'c DeserializerConfig) -> Self { + Self(config, PhantomData) + } +} + +impl<'de, W, T> DeserializeSeed<'de> for DesIntoWithConfig<'_, W, T> +where + W: DeserializeInto, +{ + type Value = T; + + #[inline] + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + >::deserialize_into(deserializer, self.0) + } +} + +pub struct MaybeDesIntoWithConfig<'c, W, T>(pub &'c DeserializerConfig, PhantomData<(W, T)>); + +impl<'c, W, T> MaybeDesIntoWithConfig<'c, W, T> { + #[inline] + pub fn new(config: &'c DeserializerConfig) -> Self { + Self(config, PhantomData) + } +} + +impl<'de, W, T> DeserializeSeed<'de> for MaybeDesIntoWithConfig<'_, W, T> +where + W: DeserializeInto, +{ + type Value = MaybeDeserializedValue; + + #[inline] + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + >::maybe_deserialize_into(deserializer, self.0) + } +} + +// Re-export all deserializers. +// FIXME: Remove the `self::` when we've bumped the MSRV to 1.72. +pub use self::bytes::BytesDeserializer; +pub use default::DefaultDeserializer; +pub use forward::ForwardDeserializer; +pub use map::MapDeserializer; +pub use message::MessageDeserializer; +pub use oneof::{DeserializeOneOf, OneOfDeserializer}; +pub use option::{NullDeserializer, OptionDeserializer}; +pub use r#enum::{DeserializeEnum, EnumDeserializer}; +pub use scalar::{BoolDeserializer, FloatDeserializer, IntDeserializer}; +pub use vec::VecDeserializer; + +mod size_hint { + use core::{cmp, mem}; + + #[inline] + pub fn cautious(hint: Option) -> usize { + const MAX_PREALLOC_BYTES: usize = 1024 * 1024; + + if mem::size_of::() == 0 { + 0 + } else { + cmp::min( + hint.unwrap_or(0), + MAX_PREALLOC_BYTES / mem::size_of::(), + ) + } + } +} diff --git a/prost/src/serde/de/bytes.rs b/prost/src/serde/de/bytes.rs new file mode 100644 index 000000000..de9e75b00 --- /dev/null +++ b/prost/src/serde/de/bytes.rs @@ -0,0 +1,63 @@ +use alloc::vec::Vec; +use core::{fmt, marker::PhantomData}; + +use super::{DeserializeInto, DeserializerConfig}; + +pub struct BytesDeserializer; + +impl DeserializeInto for BytesDeserializer +where + T: From>, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor(PhantomData); + + impl serde::de::Visitor<'_> for Visitor + where + T: From>, + { + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a base64 encoded string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + use base64::Engine; + + let err = match base64::prelude::BASE64_STANDARD.decode(v) { + Ok(val) => return Ok(T::from(val)), + Err(err) => err, + }; + if let base64::DecodeError::InvalidByte(_, b'-' | b'_') = err { + static ENGINE: base64::engine::GeneralPurpose = + base64::engine::GeneralPurpose::new( + &base64::alphabet::URL_SAFE, + base64::engine::GeneralPurposeConfig::new() + .with_decode_allow_trailing_bits(true) + .with_decode_padding_mode( + base64::engine::DecodePaddingMode::RequireNone, + ), + ); + if let Ok(val) = ENGINE.decode(v) { + return Ok(T::from(val)); + } + } + + Err(E::invalid_value( + serde::de::Unexpected::Str(v), + &"a valid base64 encoded string", + )) + } + } + + deserializer.deserialize_any(Visitor(PhantomData)) + } +} diff --git a/prost/src/serde/de/default.rs b/prost/src/serde/de/default.rs new file mode 100644 index 000000000..b7b4a8047 --- /dev/null +++ b/prost/src/serde/de/default.rs @@ -0,0 +1,29 @@ +use core::marker::PhantomData; + +use super::{DeserializeInto, DeserializerConfig, MaybeDeserializedValue, OptionDeserializer}; + +pub struct DefaultDeserializer(PhantomData); + +impl DeserializeInto for DefaultDeserializer +where + W: DeserializeInto, + T: Default, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result { + let val: Option = OptionDeserializer::::deserialize_into(deserializer, config)?; + Ok(val.unwrap_or_default()) + } + + fn maybe_deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + let val: MaybeDeserializedValue> = + OptionDeserializer::::maybe_deserialize_into(deserializer, config)?; + Ok(val.map(|val| val.unwrap_or_default())) + } +} diff --git a/prost/src/serde/de/enum.rs b/prost/src/serde/de/enum.rs new file mode 100644 index 000000000..45eb26c1f --- /dev/null +++ b/prost/src/serde/de/enum.rs @@ -0,0 +1,174 @@ +use alloc::borrow::{Cow, ToOwned}; +use core::{fmt, marker::PhantomData}; + +use super::{DeserializeInto, DeserializerConfig, MaybeDeserializedValue}; + +pub trait DeserializeEnum: Sized + Into { + fn deserialize_from_i32(val: i32) -> Result, E> + where + E: serde::de::Error; + + fn deserialize_from_str(val: &str) -> Result, E> + where + E: serde::de::Error; + + #[inline] + fn deserialize_from_null() -> Result + where + E: serde::de::Error, + { + Err(E::invalid_value( + serde::de::Unexpected::Option, + &"a valid enum value", + )) + } + + #[inline] + fn can_deserialize_null() -> bool { + false + } +} + +pub struct EnumDeserializer(PhantomData); + +impl DeserializeInto for EnumDeserializer +where + T: DeserializeEnum, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result { + match deserializer.deserialize_any(EnumVisitor::(config, PhantomData))? { + Ok(val) => Ok(val.into()), + Err(UnknownEnumValue::Int(val)) => Ok(val), + Err(UnknownEnumValue::Str(val)) => Err(::invalid_value( + serde::de::Unexpected::Str(&val), + &"a valid enum value", + )), + } + } + + #[inline] + fn maybe_deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + match deserializer.deserialize_any(EnumVisitor::(config, PhantomData))? { + Ok(val) => Ok(MaybeDeserializedValue::Val(val.into())), + Err(UnknownEnumValue::Int(val)) => Ok(MaybeDeserializedValue::Val(val)), + Err(UnknownEnumValue::Str(_)) => Ok(MaybeDeserializedValue::UnknownEnumValue), + } + } + + #[inline] + fn can_deserialize_null() -> bool { + T::can_deserialize_null() + } +} + +#[derive(Debug)] +enum UnknownEnumValue<'de> { + Int(i32), + Str(Cow<'de, str>), +} + +struct EnumVisitor<'c, E>(&'c DeserializerConfig, PhantomData); + +impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'_, T> +where + T: DeserializeEnum, +{ + type Value = Result>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an enum") + } + + #[inline] + fn visit_i32(self, v: i32) -> Result + where + E: serde::de::Error, + { + let val = T::deserialize_from_i32(v)?; + match val { + Some(val) => Ok(Ok(val)), + None if self.0.deny_unknown_enum_values => Err(E::invalid_value( + serde::de::Unexpected::Signed(v.into()), + &"a valid enum value", + )), + None => Ok(Err(UnknownEnumValue::Int(v))), + } + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + self.visit_i32(v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Signed(v), &"a valid enum value") + })?) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + self.visit_i32(v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Unsigned(v), &"a valid enum value") + })?) + } + + #[inline] + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + let val = T::deserialize_from_str(v)?; + match val { + Some(val) => Ok(Ok(val)), + None if self.0.ignore_unknown_enum_string_values => { + Ok(Err(UnknownEnumValue::Str(Cow::Owned(v.to_owned())))) + } + None => Err(E::invalid_value( + serde::de::Unexpected::Str(v), + &"a valid enum value", + )), + } + } + + #[inline] + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: serde::de::Error, + { + let val = T::deserialize_from_str(v)?; + match val { + Some(val) => Ok(Ok(val)), + None if self.0.ignore_unknown_enum_string_values => { + Ok(Err(UnknownEnumValue::Str(Cow::Borrowed(v)))) + } + None => Err(E::invalid_value( + serde::de::Unexpected::Str(v), + &"a valid enum value", + )), + } + } + + #[inline] + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + T::deserialize_from_null().map(Ok) + } + + #[inline] + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + T::deserialize_from_null().map(Ok) + } +} diff --git a/prost/src/serde/de/forward.rs b/prost/src/serde/de/forward.rs new file mode 100644 index 000000000..80d6c8029 --- /dev/null +++ b/prost/src/serde/de/forward.rs @@ -0,0 +1,16 @@ +use super::{DeserializeInto, DeserializerConfig}; + +pub struct ForwardDeserializer; + +impl DeserializeInto for ForwardDeserializer +where + T: for<'de> serde::Deserialize<'de>, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + ::deserialize(deserializer) + } +} diff --git a/prost/src/serde/de/map.rs b/prost/src/serde/de/map.rs new file mode 100644 index 000000000..099494407 --- /dev/null +++ b/prost/src/serde/de/map.rs @@ -0,0 +1,101 @@ +use core::{fmt, marker::PhantomData}; + +use super::{DesIntoWithConfig, DeserializeInto, DeserializerConfig, MaybeDesIntoWithConfig}; + +pub struct MapDeserializer(PhantomData<(KD, VD)>); + +#[cfg(feature = "std")] +impl DeserializeInto> for MapDeserializer +where + K: Eq + core::hash::Hash, + KD: DeserializeInto, + VD: DeserializeInto, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + struct Visitor<'c, K, V, KD, VD>(&'c DeserializerConfig, PhantomData<(K, V, KD, VD)>); + + impl<'de, K, V, KD, VD> serde::de::Visitor<'de> for Visitor<'_, K, V, KD, VD> + where + K: Eq + core::hash::Hash, + KD: DeserializeInto, + VD: DeserializeInto, + { + type Value = std::collections::HashMap; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let capacity = super::size_hint::cautious::<(K, V)>(map.size_hint()); + let mut inner = std::collections::HashMap::with_capacity(capacity); + + while let Some(key) = map.next_key_seed(DesIntoWithConfig::::new(self.0))? { + let val = map.next_value_seed(MaybeDesIntoWithConfig::::new(self.0))?; + let Some(val) = val.unwrap_for_omittable::(self.0, "in map")? else { + continue; + }; + inner.insert(key, val); + } + + Ok(inner) + } + } + + deserializer.deserialize_map(Visitor::(config, PhantomData)) + } +} + +impl DeserializeInto> for MapDeserializer +where + K: Ord, + KD: DeserializeInto, + VD: DeserializeInto, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + struct Visitor<'c, K, V, KD, VD>(&'c DeserializerConfig, PhantomData<(K, V, KD, VD)>); + + impl<'de, K, V, KD, VD> serde::de::Visitor<'de> for Visitor<'_, K, V, KD, VD> + where + K: Ord, + KD: DeserializeInto, + VD: DeserializeInto, + { + type Value = alloc::collections::BTreeMap; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut inner = alloc::collections::BTreeMap::new(); + + while let Some(key) = map.next_key_seed(DesIntoWithConfig::::new(self.0))? { + let val = map.next_value_seed(MaybeDesIntoWithConfig::::new(self.0))?; + let Some(val) = val.unwrap_for_omittable::(self.0, "in map")? else { + continue; + }; + inner.insert(key, val); + } + + Ok(inner) + } + } + + deserializer.deserialize_map(Visitor::(config, PhantomData)) + } +} diff --git a/prost/src/serde/de/message.rs b/prost/src/serde/de/message.rs new file mode 100644 index 000000000..28b2e4b6c --- /dev/null +++ b/prost/src/serde/de/message.rs @@ -0,0 +1,21 @@ +use super::{CustomDeserialize, DeserializeInto, DeserializerConfig}; + +pub struct MessageDeserializer; + +impl DeserializeInto for MessageDeserializer +where + T: for<'de> CustomDeserialize<'de>, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result { + CustomDeserialize::deserialize(deserializer, config) + } + + #[inline] + fn can_deserialize_null() -> bool { + T::can_deserialize_null() + } +} diff --git a/prost/src/serde/de/oneof.rs b/prost/src/serde/de/oneof.rs new file mode 100644 index 000000000..ca839b7e3 --- /dev/null +++ b/prost/src/serde/de/oneof.rs @@ -0,0 +1,34 @@ +use super::DeserializerConfig; + +pub trait DeserializeOneOf: Sized { + type FieldKey; + + fn deserialize_field_key(val: &str) -> Option; + + fn deserialize_by_field_key<'de, D>( + field_key: Self::FieldKey, + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> + where + D: serde::Deserializer<'de>; +} + +pub struct OneOfDeserializer<'c, T>(pub T::FieldKey, pub &'c DeserializerConfig) +where + T: DeserializeOneOf; + +impl<'de, T> serde::de::DeserializeSeed<'de> for OneOfDeserializer<'_, T> +where + T: DeserializeOneOf, +{ + type Value = Option; + + #[inline] + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + T::deserialize_by_field_key(self.0, deserializer, self.1) + } +} diff --git a/prost/src/serde/de/option.rs b/prost/src/serde/de/option.rs new file mode 100644 index 000000000..23f17962e --- /dev/null +++ b/prost/src/serde/de/option.rs @@ -0,0 +1,155 @@ +use core::{fmt, marker::PhantomData}; + +use super::{DeserializeInto, DeserializerConfig, MaybeDeserializedValue}; + +pub struct OptionDeserializer(PhantomData); + +impl DeserializeInto> for OptionDeserializer +where + I: DeserializeInto, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + struct Visitor<'c, T, I>(&'c DeserializerConfig, PhantomData<(T, I)>); + + impl<'de, T, I> serde::de::Visitor<'de> for Visitor<'_, T, I> + where + I: DeserializeInto, + { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "an option") + } + + #[inline] + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + #[inline] + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + #[inline] + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + I::deserialize_into(deserializer, self.0).map(Some) + } + } + + if I::can_deserialize_null() { + Ok(Some(I::deserialize_into(deserializer, config)?)) + } else { + deserializer.deserialize_option(Visitor::(config, PhantomData)) + } + } + + fn maybe_deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result>, D::Error> { + struct Visitor<'c, T, I>(&'c DeserializerConfig, PhantomData<(T, I)>); + + impl<'de, T, I> serde::de::Visitor<'de> for Visitor<'_, T, I> + where + I: DeserializeInto, + { + type Value = MaybeDeserializedValue>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "an option") + } + + #[inline] + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(MaybeDeserializedValue::Val(None)) + } + + #[inline] + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(MaybeDeserializedValue::Val(None)) + } + + #[inline] + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(I::maybe_deserialize_into(deserializer, self.0)?.map(Some)) + } + } + + if I::can_deserialize_null() { + Ok(I::maybe_deserialize_into(deserializer, config)?.map(Some)) + } else { + deserializer.deserialize_option(Visitor::(config, PhantomData)) + } + } + + #[inline] + fn can_deserialize_null() -> bool { + true + } +} + +pub struct NullDeserializer; + +impl DeserializeInto<()> for NullDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result<(), D::Error> { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a null value") + } + + #[inline] + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(()) + } + + #[inline] + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(()) + } + } + + deserializer.deserialize_option(Visitor) + } + + #[inline] + fn can_deserialize_null() -> bool { + true + } +} diff --git a/prost/src/serde/de/scalar.rs b/prost/src/serde/de/scalar.rs new file mode 100644 index 000000000..5c851db53 --- /dev/null +++ b/prost/src/serde/de/scalar.rs @@ -0,0 +1,571 @@ +use core::fmt; + +use super::{DeserializeInto, DeserializerConfig}; + +pub struct BoolDeserializer; + +impl DeserializeInto for BoolDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = bool; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a boolean value") + } + + #[inline] + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + #[inline] + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + if PARSE_STR { + match v { + "true" => return Ok(true), + "false" => return Ok(false), + _ => (), + } + } + Err(E::invalid_type( + serde::de::Unexpected::Str(v), + &"a valid boolean value", + )) + } + } + + deserializer.deserialize_any(Visitor::) + } +} + +pub struct IntDeserializer; + +impl DeserializeInto for IntDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = i32; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a numeric value (i32)") + } + + #[inline] + fn visit_i8(self, v: i8) -> Result + where + E: serde::de::Error, + { + Ok(v as i32) + } + + #[inline] + fn visit_i16(self, v: i16) -> Result + where + E: serde::de::Error, + { + Ok(v as i32) + } + + #[inline] + fn visit_i32(self, v: i32) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Signed(v), &"a valid integer (i32)") + }) + } + + #[inline] + fn visit_u8(self, v: u8) -> Result + where + E: serde::de::Error, + { + Ok(v as i32) + } + + #[inline] + fn visit_u16(self, v: u16) -> Result + where + E: serde::de::Error, + { + Ok(v as i32) + } + + #[inline] + fn visit_u32(self, v: u32) -> Result + where + E: serde::de::Error, + { + self.visit_u64(v as u64) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Unsigned(v), &"a valid integer (i32)") + }) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + let conv = v as i32; + if conv as f64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Float(v), + &"a valid integer (i32)", + )) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse::().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Str(v), &"a valid integer (i32)") + }) + } + } + + deserializer.deserialize_any(Visitor) + } +} + +impl DeserializeInto for IntDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = i64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a numeric value (i64)") + } + + #[inline] + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + #[inline] + fn visit_u8(self, v: u8) -> Result + where + E: serde::de::Error, + { + Ok(v as i64) + } + + #[inline] + fn visit_u16(self, v: u16) -> Result + where + E: serde::de::Error, + { + Ok(v as i64) + } + + #[inline] + fn visit_u32(self, v: u32) -> Result + where + E: serde::de::Error, + { + Ok(v as i64) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Unsigned(v), &"a valid integer (i64)") + }) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + let conv = v as i64; + if conv as f64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Float(v), + &"a valid integer (i64)", + )) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse::().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Str(v), &"a valid integer (i64)") + }) + } + } + + deserializer.deserialize_any(Visitor) + } +} + +impl DeserializeInto for IntDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = u32; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a numeric value (u32)") + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Signed(v), &"a valid integer (u32)") + }) + } + + #[inline] + fn visit_u8(self, v: u8) -> Result + where + E: serde::de::Error, + { + Ok(v as u32) + } + + #[inline] + fn visit_u16(self, v: u16) -> Result + where + E: serde::de::Error, + { + Ok(v as u32) + } + + #[inline] + fn visit_u32(self, v: u32) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Unsigned(v), &"a valid integer (u32)") + }) + } + + #[inline] + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + let conv = v as u32; + if conv as f64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Float(v), + &"a valid integer (u32)", + )) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse::().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Str(v), &"a valid integer (u32)") + }) + } + } + + deserializer.deserialize_any(Visitor) + } +} + +impl DeserializeInto for IntDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = u64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a numeric value (u64)") + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + v.try_into().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Signed(v), &"a valid integer (u64)") + }) + } + + #[inline] + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + #[inline] + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + let conv = v as u64; + if conv as f64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Float(v), + &"a valid integer (u64)", + )) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse::().map_err(|_| { + E::invalid_value(serde::de::Unexpected::Str(v), &"a valid integer (u32)") + }) + } + } + + deserializer.deserialize_any(Visitor) + } +} + +pub struct FloatDeserializer; + +impl DeserializeInto for FloatDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = f32; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a float (f32)") + } + + #[inline] + fn visit_f32(self, v: f32) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + #[inline] + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + let conv = v as f32; + if conv.is_finite() { + Ok(v as f32) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Float(v), + &"a floating point number (f32)", + )) + } + } + + #[inline] + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + let conv = v as f32; + if conv as i64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Signed(v), + &"a floating point number (f32)", + )) + } + } + + #[inline] + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + let conv = v as f32; + if conv as u64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Unsigned(v), + &"a floating point number (f32)", + )) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match v { + "NaN" => Ok(f32::NAN), + "Infinity" => Ok(f32::INFINITY), + "-Infinity" => Ok(f32::NEG_INFINITY), + v => match v.parse::() { + Ok(v) if !v.is_infinite() => Ok(v), + _ => Err(E::invalid_value( + serde::de::Unexpected::Str(v), + &"a floating point number (f32)", + )), + }, + } + } + } + + deserializer.deserialize_any(Visitor) + } +} + +impl DeserializeInto for FloatDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result { + struct Visitor; + + impl serde::de::Visitor<'_> for Visitor { + type Value = f64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a float (f64)") + } + + #[inline] + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + Ok(v) + } + + #[inline] + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + let conv = v as f64; + if conv as i64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Signed(v), + &"a floating point number (f64)", + )) + } + } + + #[inline] + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + let conv = v as f64; + if conv as u64 == v { + Ok(conv) + } else { + Err(E::invalid_value( + serde::de::Unexpected::Unsigned(v), + &"a floating point number (f64)", + )) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match v { + "NaN" => Ok(f64::NAN), + "Infinity" => Ok(f64::INFINITY), + "-Infinity" => Ok(f64::NEG_INFINITY), + v => match v.parse::() { + Ok(v) if !v.is_infinite() => Ok(v), + _ => Err(E::invalid_value( + serde::de::Unexpected::Str(v), + &"a floating point number (f64)", + )), + }, + } + } + } + + deserializer.deserialize_any(Visitor) + } +} diff --git a/prost/src/serde/de/vec.rs b/prost/src/serde/de/vec.rs new file mode 100644 index 000000000..e426ef291 --- /dev/null +++ b/prost/src/serde/de/vec.rs @@ -0,0 +1,51 @@ +use alloc::vec::Vec; +use core::{fmt, marker::PhantomData}; + +use super::{DeserializeInto, DeserializerConfig, MaybeDesIntoWithConfig}; + +pub struct VecDeserializer(PhantomData); + +impl DeserializeInto> for VecDeserializer +where + W: DeserializeInto, +{ + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + config: &DeserializerConfig, + ) -> Result, D::Error> { + struct Visitor<'c, W, T>(&'c DeserializerConfig, PhantomData<(W, T)>); + + impl<'de, W, T> serde::de::Visitor<'de> for Visitor<'_, W, T> + where + W: DeserializeInto, + { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let capacity = super::size_hint::cautious::(seq.size_hint()); + let mut values = Vec::::with_capacity(capacity); + + while let Some(val) = + seq.next_element_seed(MaybeDesIntoWithConfig::::new(self.0))? + { + let Some(val) = val.unwrap_for_omittable(self.0, "in repeated field")? else { + continue; + }; + values.push(val); + } + + Ok(values) + } + } + + deserializer.deserialize_seq(Visitor::(config, PhantomData)) + } +} diff --git a/prost/src/serde/mod.rs b/prost/src/serde/mod.rs new file mode 100644 index 000000000..703caf0e5 --- /dev/null +++ b/prost/src/serde/mod.rs @@ -0,0 +1,299 @@ +use core::marker::PhantomData; + +use serde::{de::DeserializeSeed, Serialize}; + +use private::{CustomDeserialize, CustomSerialize}; + +#[doc(hidden)] +pub mod private; + +#[doc(hidden)] +pub mod ser; + +#[doc(hidden)] +pub mod de; + +#[doc(hidden)] +pub mod types; + +pub trait SerdeMessage: CustomSerialize + for<'de> CustomDeserialize<'de> {} + +impl SerdeMessage for T where T: CustomSerialize + for<'de> CustomDeserialize<'de> {} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[non_exhaustive] +pub struct SerializerConfig { + pub emit_fields_with_default_value: bool, + pub emit_nulled_optional_fields: bool, + pub emit_enum_values_as_integer: bool, + pub use_proto_name: bool, +} + +impl SerializerConfig { + #[inline] + pub fn with<'a, T>(&'a self, val: &'a T) -> WithSerializerConfig<'a, T> { + WithSerializerConfig { + inner: val, + config: self, + } + } +} + +#[derive(Debug)] +pub struct WithSerializerConfig<'a, T> { + inner: &'a T, + config: &'a SerializerConfig, +} + +impl WithSerializerConfig<'_, T> +where + T: private::CustomSerialize, +{ + #[inline] + pub fn config(&self) -> &SerializerConfig { + self.config + } + + #[cfg(feature = "serde-json")] + #[inline] + pub fn to_string(self) -> Result { + serde_json::to_string(&self) + } + + #[cfg(feature = "serde-json")] + #[inline] + pub fn to_string_pretty(self) -> Result { + serde_json::to_string_pretty(&self) + } + + #[cfg(feature = "serde-json")] + #[inline] + pub fn to_vec(self) -> Result, serde_json::Error> { + serde_json::to_vec(&self) + } + + #[cfg(feature = "serde-json")] + #[inline] + pub fn to_vec_pretty(self) -> Result, serde_json::Error> { + serde_json::to_vec_pretty(&self) + } + + #[cfg(feature = "serde-json")] + #[inline] + pub fn to_writer(self, writer: W) -> Result<(), serde_json::Error> { + serde_json::to_writer_pretty(writer, &self) + } + + #[cfg(feature = "serde-json")] + #[inline] + pub fn to_writer_pretty(self, writer: W) -> Result<(), serde_json::Error> { + serde_json::to_writer_pretty(writer, &self) + } +} + +impl Serialize for WithSerializerConfig<'_, T> +where + T: private::CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + private::CustomSerialize::serialize(self.inner, serializer, self.config) + } +} + +pub trait WithSerializerConfigExt: Sized { + #[inline] + fn with_config<'a>(&'a self, config: &'a SerializerConfig) -> WithSerializerConfig<'a, Self> { + WithSerializerConfig { + inner: self, + config, + } + } +} + +impl WithSerializerConfigExt for T where T: private::CustomSerialize {} + +#[derive(Debug, Clone)] +pub struct SerializerConfigBuilder { + config: SerializerConfig, +} + +impl SerializerConfigBuilder { + #[inline] + pub fn new() -> Self { + Self { + config: Default::default(), + } + } + + #[inline] + pub fn emit_fields_with_default_value(mut self, emit: bool) -> Self { + self.config.emit_fields_with_default_value = emit; + self + } + + #[inline] + pub fn emit_nulled_optional_fields(mut self, emit: bool) -> Self { + self.config.emit_nulled_optional_fields = emit; + self + } + + #[inline] + pub fn emit_enum_values_as_integer(mut self, emit: bool) -> Self { + self.config.emit_enum_values_as_integer = emit; + self + } + + #[inline] + pub fn use_proto_name(mut self, enabled: bool) -> Self { + self.config.use_proto_name = enabled; + self + } + + #[inline] + pub fn build(self) -> SerializerConfig { + self.config + } +} + +impl Default for SerializerConfigBuilder { + #[inline] + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[non_exhaustive] +pub struct DeserializerConfig { + pub ignore_unknown_fields: bool, + pub ignore_unknown_enum_string_values: bool, + pub deny_unknown_enum_values: bool, +} + +impl DeserializerConfig { + #[cfg(feature = "serde-json")] + pub fn deserialize_from_str(&self, val: &str) -> Result + where + T: for<'de> private::CustomDeserialize<'de>, + { + let mut deserializer = serde_json::Deserializer::from_str(val); + let val = ::deserialize(&mut deserializer, self)?; + deserializer.end()?; + Ok(val) + } + + #[cfg(feature = "serde-json")] + pub fn deserialize_from_slice(&self, val: &[u8]) -> Result + where + T: for<'de> private::CustomDeserialize<'de>, + { + let mut deserializer = serde_json::Deserializer::from_slice(val); + let val = ::deserialize(&mut deserializer, self)?; + deserializer.end()?; + Ok(val) + } + + #[cfg(feature = "serde-json")] + pub fn deserialize_from_reader(&self, val: R) -> Result + where + R: std::io::Read, + T: for<'de> private::CustomDeserialize<'de>, + { + let mut deserializer = serde_json::Deserializer::from_reader(val); + let val = ::deserialize(&mut deserializer, self)?; + deserializer.end()?; + Ok(val) + } + + #[cfg(feature = "serde-json")] + pub fn deserialize_from_value( + &self, + value: &serde_json::Value, + ) -> Result + where + T: for<'de> private::CustomDeserialize<'de>, + { + ::deserialize(value, self) + } + + #[inline] + pub fn with(self) -> WithDeserializerConfig + where + T: for<'de> CustomDeserialize<'de>, + { + WithDeserializerConfig { + config: self, + _for: PhantomData, + } + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[non_exhaustive] +pub struct WithDeserializerConfig { + config: DeserializerConfig, + _for: PhantomData, +} + +impl<'de, T> DeserializeSeed<'de> for WithDeserializerConfig +where + T: CustomDeserialize<'de>, +{ + type Value = T; + + #[inline] + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + ::deserialize(deserializer, &self.config) + } +} + +#[derive(Debug, Clone)] +pub struct DeserializerConfigBuilder { + config: DeserializerConfig, +} + +impl DeserializerConfigBuilder { + #[inline] + pub fn new() -> Self { + Self { + config: Default::default(), + } + } + + #[inline] + pub fn ignore_unknown_fields(mut self, ignore: bool) -> Self { + self.config.ignore_unknown_fields = ignore; + self + } + + #[inline] + pub fn deny_unknown_enum_values(mut self, deny: bool) -> Self { + self.config.deny_unknown_enum_values = deny; + self + } + + #[inline] + pub fn ignore_unknown_enum_string_values(mut self, ignore: bool) -> Self { + self.config.ignore_unknown_enum_string_values = ignore; + self + } + + #[inline] + pub fn build(self) -> DeserializerConfig { + self.config + } +} + +impl Default for DeserializerConfigBuilder { + #[inline] + fn default() -> Self { + Self::new() + } +} diff --git a/prost/src/serde/private.rs b/prost/src/serde/private.rs new file mode 100644 index 000000000..8b1a2c8ab --- /dev/null +++ b/prost/src/serde/private.rs @@ -0,0 +1,38 @@ +pub use core::convert::TryFrom; +pub use core::default::Default; +pub use core::fmt; +pub use core::marker::PhantomData; +pub use core::option::Option; +pub use core::result::Result; + +pub use Option::{None, Some}; +pub use Result::{Err, Ok}; + +pub use ::serde as _serde; + +#[cfg(feature = "serde-json")] +pub use serde_json::Value as JsonValue; + +pub use super::{DeserializerConfig, SerializerConfig}; + +#[inline] +pub fn is_default_value(val: &T) -> bool { + *val == T::default() +} + +// Serialization utilities. + +pub use super::ser::{ + CustomSerialize, SerAsDisplay, SerBytesAsBase64, SerEnum, SerFloat32, SerFloat64, SerIdentity, + SerMappedMapItems, SerMappedVecItems, SerSerde, SerWithConfig, SerializeOneOf, +}; + +// Deserialization utilities. + +pub use super::de::{ + BoolDeserializer, BytesDeserializer, CustomDeserialize, DefaultDeserializer, DesIntoWithConfig, + DesWithConfig, DeserializeEnum, DeserializeInto, DeserializeOneOf, EnumDeserializer, + FloatDeserializer, ForwardDeserializer, IntDeserializer, MapDeserializer, + MaybeDesIntoWithConfig, MaybeDeserializedValue, MessageDeserializer, NullDeserializer, + OneOfDeserializer, OptionDeserializer, VecDeserializer, +}; diff --git a/prost/src/serde/ser.rs b/prost/src/serde/ser.rs new file mode 100644 index 000000000..240841e81 --- /dev/null +++ b/prost/src/serde/ser.rs @@ -0,0 +1,257 @@ +use alloc::{boxed::Box, vec::Vec}; + +use core::{fmt::Display, marker::PhantomData, ops::Deref}; +use serde::{ser::SerializeStruct, Serialize, Serializer}; + +use super::SerializerConfig; + +pub trait CustomSerialize { + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: serde::Serializer; +} + +impl CustomSerialize for &T +where + T: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: serde::Serializer, + { + CustomSerialize::serialize(*self, serializer, config) + } +} + +impl CustomSerialize for [T] +where + T: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: serde::Serializer, + { + serializer.collect_seq(self.iter().map(|item| SerWithConfig(item, config))) + } +} + +impl CustomSerialize for Vec +where + T: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: serde::Serializer, + { + CustomSerialize::serialize(self.as_slice(), serializer, config) + } +} + +impl CustomSerialize for Box +where + T: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: serde::Serializer, + { + CustomSerialize::serialize(&**self, serializer, config) + } +} + +pub struct SerWithConfig<'c, T>(pub T, pub &'c SerializerConfig); + +impl serde::Serialize for SerWithConfig<'_, T> +where + T: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + CustomSerialize::serialize(&self.0, serializer, self.1) + } +} + +pub struct SerIdentity<'a, T>(pub &'a T); + +impl CustomSerialize for SerIdentity<'_, T> +where + T: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: serde::Serializer, + { + CustomSerialize::serialize(self.0, serializer, config) + } +} + +pub struct SerSerde<'a, T>(pub &'a T); + +impl CustomSerialize for SerSerde<'_, T> +where + T: Serialize, +{ + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: serde::Serializer, + { + T::serialize(self.0, serializer) + } +} + +pub struct SerAsDisplay<'a, T>(pub &'a T); + +impl CustomSerialize for SerAsDisplay<'_, T> +where + T: Display, +{ + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: Serializer, + { + serializer.collect_str(self.0) + } +} + +pub struct SerBytesAsBase64<'a, T>(pub &'a T); + +impl CustomSerialize for SerBytesAsBase64<'_, T> +where + T: Deref, +{ + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: Serializer, + { + serializer.collect_str(&base64::display::Base64Display::new( + self.0, + &base64::prelude::BASE64_STANDARD, + )) + } +} + +pub struct SerFloat32<'a>(pub &'a f32); + +impl CustomSerialize for SerFloat32<'_> { + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: Serializer, + { + if self.0.is_nan() { + serializer.serialize_str("NaN") + } else if self.0.is_infinite() { + if self.0.is_sign_positive() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_str("-Infinity") + } + } else { + serializer.serialize_f32(*self.0) + } + } +} + +pub struct SerFloat64<'a>(pub &'a f64); + +impl CustomSerialize for SerFloat64<'_> { + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: Serializer, + { + if self.0.is_nan() { + serializer.serialize_str("NaN") + } else if self.0.is_infinite() { + if self.0.is_sign_positive() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_str("-Infinity") + } + } else { + serializer.serialize_f64(*self.0) + } + } +} + +pub struct SerMappedVecItems<'a, I, M>(pub &'a Vec, pub fn(&'a I) -> M); + +impl CustomSerialize for SerMappedVecItems<'_, I, M> +where + M: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: Serializer, + { + serializer.collect_seq(self.0.iter().map(|x| SerWithConfig(self.1(x), config))) + } +} + +pub struct SerEnum(pub i32, PhantomData); + +impl SerEnum { + #[inline] + pub fn new(val: &i32) -> Self { + Self(*val, PhantomData) + } +} + +impl CustomSerialize for SerEnum +where + E: TryFrom + CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: Serializer, + { + if let Ok(enum_val) = E::try_from(self.0) { + CustomSerialize::serialize(&enum_val, serializer, config) + } else { + serializer.serialize_i32(self.0) + } + } +} + +pub struct SerMappedMapItems<'a, C, V, M>(pub &'a C, pub fn(&'a V) -> M); + +impl<'a, C, K, V, M> CustomSerialize for SerMappedMapItems<'a, C, V, M> +where + &'a C: IntoIterator, + K: Display + 'a, + M: CustomSerialize, +{ + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: Serializer, + { + serializer.collect_map(self.0.into_iter().map(|(key, val)| { + ( + SerWithConfig(SerAsDisplay(key), config), + SerWithConfig(self.1(val), config), + ) + })) + } +} + +pub trait SerializeOneOf { + fn serialize_oneof( + &self, + serializer: &mut S, + config: &SerializerConfig, + ) -> Result<(), S::Error> + where + S: SerializeStruct; +} diff --git a/prost/src/serde/types.rs b/prost/src/serde/types.rs new file mode 100644 index 000000000..e087b6025 --- /dev/null +++ b/prost/src/serde/types.rs @@ -0,0 +1,264 @@ +use alloc::{string::String, vec::Vec}; +use core::fmt; + +use super::{ + de::CustomDeserialize, + private::{self, DeserializeInto, _serde}, + ser::CustomSerialize, + DeserializerConfig, SerializerConfig, +}; + +impl CustomSerialize for () { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + use _serde::ser::SerializeMap; + serializer.serialize_map(None)?.end() + } +} + +impl<'de> CustomDeserialize<'de> for () { + #[inline] + fn deserialize(deserializer: D, _config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + struct Visitor; + + impl<'de> _serde::de::Visitor<'de> for Visitor { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an empty message") + } + + #[inline] + fn visit_map(self, mut map: A) -> Result + where + A: _serde::de::MapAccess<'de>, + { + if map.next_key::<_serde::de::IgnoredAny>()?.is_some() { + return Err(::invalid_length( + 1, + &"an empty map", + )); + } + Ok(()) + } + } + + deserializer.deserialize_map(Visitor) + } +} + +impl CustomSerialize for bool { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.serialize_bool(*self) + } +} + +impl<'de> CustomDeserialize<'de> for bool { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::ForwardDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for i32 { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.serialize_i32(*self) + } +} + +impl<'de> CustomDeserialize<'de> for i32 { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::IntDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for u32 { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.serialize_u32(*self) + } +} + +impl<'de> CustomDeserialize<'de> for u32 { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::IntDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for i64 { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerAsDisplay(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for i64 { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::IntDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for u64 { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerAsDisplay(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for u64 { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::IntDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for str { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.serialize_str(self) + } +} + +impl CustomSerialize for String { + #[inline] + fn serialize(&self, serializer: S, _config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + serializer.serialize_str(self) + } +} + +impl<'de> CustomDeserialize<'de> for String { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::ForwardDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for Vec { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerBytesAsBase64(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for Vec { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::BytesDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for bytes::Bytes { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerBytesAsBase64(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for bytes::Bytes { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::BytesDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for f32 { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerFloat32(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for f32 { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::FloatDeserializer::deserialize_into(deserializer, config) + } +} + +impl CustomSerialize for f64 { + #[inline] + fn serialize(&self, serializer: S, config: &SerializerConfig) -> Result + where + S: _serde::Serializer, + { + private::SerFloat64(self).serialize(serializer, config) + } +} + +impl<'de> CustomDeserialize<'de> for f64 { + #[inline] + fn deserialize(deserializer: D, config: &DeserializerConfig) -> Result + where + D: _serde::Deserializer<'de>, + { + private::FloatDeserializer::deserialize_into(deserializer, config) + } +} diff --git a/protobuf/Cargo.toml b/protobuf/Cargo.toml index fdf297aba..fdaf6ea05 100644 --- a/protobuf/Cargo.toml +++ b/protobuf/Cargo.toml @@ -6,8 +6,8 @@ edition.workspace = true authors.workspace = true [dependencies] -prost = { path = "../prost" } -prost-types = { path = "../prost-types" } +prost = { path = "../prost", features = ["serde"] } +prost-types = { path = "../prost-types", features = ["serde", "any-v2"] } [build-dependencies] anyhow = "1.0.1" diff --git a/protobuf/build.rs b/protobuf/build.rs index 0f329fe3b..bc837c237 100644 --- a/protobuf/build.rs +++ b/protobuf/build.rs @@ -39,6 +39,7 @@ fn main() -> Result<()> { let conformance_proto_dir = src_dir.join("conformance"); prost_build::Config::new() .protoc_executable(&protoc_executable) + .enable_serde() .compile_protos( &[conformance_proto_dir.join("conformance.proto")], &[conformance_proto_dir], @@ -54,6 +55,7 @@ fn main() -> Result<()> { prost_build::Config::new() .protoc_executable(&protoc_executable) .btree_map(["."]) + .enable_serde() .compile_protos( &[ proto_dir.join("google/protobuf/test_messages_proto2.proto"), diff --git a/tests-2015/Cargo.toml b/tests-2015/Cargo.toml index 8a622dd27..70602b2ed 100644 --- a/tests-2015/Cargo.toml +++ b/tests-2015/Cargo.toml @@ -12,15 +12,16 @@ doctest = false path = "../tests/src/lib.rs" [features] -default = ["edition-2015", "std"] +default = ["edition-2015", "std", "json"] edition-2015 = [] std = [] +json = ["prost/serde-json"] [dependencies] anyhow = "1.0.1" cfg-if = "1" -prost = { path = "../prost" } -prost-types = { path = "../prost-types" } +prost = { path = "../prost", features = ["serde"] } +prost-types = { path = "../prost-types", features = ["serde"] } protobuf = { path = "../protobuf" } [dev-dependencies] diff --git a/tests-2018/Cargo.toml b/tests-2018/Cargo.toml index e9deeef7c..abaf75a90 100644 --- a/tests-2018/Cargo.toml +++ b/tests-2018/Cargo.toml @@ -15,8 +15,9 @@ doctest = false path = "../tests/src/lib.rs" [features] -default = ["std"] +default = ["std", "json"] std = [] +json = ["prost/serde-json"] [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2015"))'] } @@ -24,8 +25,8 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2 [dependencies] anyhow = "1.0.1" cfg-if = "1" -prost = { path = "../prost" } -prost-types = { path = "../prost-types" } +prost = { path = "../prost", features = ["serde"] } +prost-types = { path = "../prost-types", features = ["serde"] } protobuf = { path = "../protobuf" } [dev-dependencies] diff --git a/tests-no-std/Cargo.toml b/tests-no-std/Cargo.toml index 1fbe1d468..243d112c6 100644 --- a/tests-no-std/Cargo.toml +++ b/tests-no-std/Cargo.toml @@ -12,7 +12,7 @@ doctest = false path = "../tests/src/lib.rs" [lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2015", "std"))'] } +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2015", "std", "json"))'] } # Compile the `tests` crate *without* the std feature, which is implicitly # omitted from the default crate features. It would be easier to do something @@ -23,8 +23,8 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2 [dependencies] anyhow = { version = "1.0.45", default-features = false } cfg-if = "1" -prost = { path = "../prost", default-features = false, features = ["derive"] } -prost-types = { path = "../prost-types", default-features = false } +prost = { path = "../prost", default-features = false, features = ["derive", "serde"] } +prost-types = { path = "../prost-types", default-features = false, features = ["serde"] } [dev-dependencies] prost-build = { path = "../prost-build" } diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 066162eaf..3af2e15f5 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -6,8 +6,9 @@ edition.workspace = true authors.workspace = true [features] -default = ["std"] +default = ["std", "json"] std = [] +json = ["prost/serde-json"] [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2015"))'] } @@ -15,8 +16,8 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("edition-2 [dependencies] anyhow = "1.0.1" cfg-if = "1" -prost = { path = "../prost" } -prost-types = { path = "../prost-types" } +prost = { path = "../prost", features = ["serde"] } +prost-types = { path = "../prost-types", features = ["serde"] } [dev-dependencies] prost-build = { path = "../prost-build", features = ["cleanup-markdown"] } diff --git a/tests/build.rs b/tests/build.rs index b78374d61..59fede7a6 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -25,6 +25,7 @@ fn main() { // compare based on the Rust PartialEq implementations is difficult, due to presence of NaN // values. let mut config = prost_build::Config::new(); + config.enable_serde(); config.btree_map(["."]); // Tests for custom attributes config.type_attribute("Foo.Bar_Baz.Foo_barBaz", "#[derive(Eq, PartialOrd, Ord)]"); @@ -128,11 +129,13 @@ fn main() { .unwrap(); prost_build::Config::new() + .enable_serde() .protoc_arg("--experimental_allow_proto3_optional") .compile_protos(&[src.join("proto3_presence.proto")], includes) .unwrap(); prost_build::Config::new() + .enable_serde() .disable_comments(["."]) .compile_protos(&[src.join("disable_comments.proto")], includes) .unwrap(); @@ -148,6 +151,7 @@ fn main() { std::fs::create_dir_all(&out_path).unwrap(); prost_build::Config::new() + .enable_serde() .bytes(["."]) .out_dir(out_path) .include_file("wellknown_include.rs") @@ -162,6 +166,7 @@ fn main() { .unwrap(); prost_build::Config::new() + .enable_serde() .enable_type_names() .type_name_domain([".type_names.Foo"], "tests") .compile_protos(&[src.join("type_names.proto")], includes) @@ -185,6 +190,7 @@ fn main() { fs::create_dir_all(&no_root_packages).expect("failed to create prefix directory"); let mut no_root_packages_config = prost_build::Config::new(); no_root_packages_config + .enable_serde() .out_dir(&no_root_packages) .default_package_filename("__.default") .include_file("__.include.rs") @@ -200,6 +206,7 @@ fn main() { fs::create_dir_all(&no_root_packages_with_default).expect("failed to create prefix directory"); let mut no_root_packages_config = prost_build::Config::new(); no_root_packages_config + .enable_serde() .out_dir(&no_root_packages_with_default) .compile_protos( &[src.join("no_root_packages/widget_factory.proto")], diff --git a/tests/single-include/Cargo.toml b/tests/single-include/Cargo.toml index 54408a0f1..665dbeb59 100644 --- a/tests/single-include/Cargo.toml +++ b/tests/single-include/Cargo.toml @@ -7,7 +7,7 @@ publish = false license = "MIT" [dependencies] -prost = { path = "../../prost" } +prost = { path = "../../prost", features = ["serde"] } [build-dependencies] prost-build = { path = "../../prost-build" } diff --git a/tests/single-include/build.rs b/tests/single-include/build.rs index 5432d8b5c..ad6b335ac 100644 --- a/tests/single-include/build.rs +++ b/tests/single-include/build.rs @@ -2,11 +2,13 @@ use prost_build::Config; fn main() { Config::new() + .enable_serde() .include_file("lib.rs") .compile_protos(&["protos/search.proto"], &["protos"]) .unwrap(); Config::new() + .enable_serde() .out_dir("src/outdir") .include_file("mod.rs") .compile_protos(&["protos/outdir.proto"], &["protos"]) diff --git a/tests/single-include/src/outdir/outdir.rs b/tests/single-include/src/outdir/outdir.rs index 233028a04..6bb94ad6e 100644 --- a/tests/single-include/src/outdir/outdir.rs +++ b/tests/single-include/src/outdir/outdir.rs @@ -1,5 +1,6 @@ // This file is @generated by prost-build. #[derive(Clone, PartialEq, ::prost::Message)] +#[prost(serde)] pub struct OutdirRequest { #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, diff --git a/tests/src/decode_error.rs b/tests/src/decode_error.rs index ff74240e7..804a31705 100644 --- a/tests/src/decode_error.rs +++ b/tests/src/decode_error.rs @@ -136,7 +136,7 @@ fn test_decode_error_invalid_string() { #[test] fn test_decode_error_any() { - use prost_types::{Any, Timestamp}; + use prost_types::{AnyV1 as Any, Timestamp}; let msg = Any { type_url: "non-existing-url".to_string(), diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 4c2f60d6e..9b19506f2 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -128,106 +128,240 @@ pub mod default_string_escape { } #[cfg(not(feature = "std"))] -use alloc::vec::Vec; +use alloc::{string::String, vec::Vec}; use anyhow::anyhow; -use prost::bytes::Buf; +use prost::{bytes::Buf, serde::SerdeMessage, Message}; -use prost::Message; +#[derive(Debug, Clone, Copy)] +pub enum RoundtripInput<'a> { + Protobuf(&'a [u8]), + Json(&'a str), +} + +#[derive(Debug, Clone, Copy)] +pub enum RoundtripOutputType { + Protobuf, + Json, +} +#[derive(Debug)] pub enum RoundtripResult { - /// The roundtrip succeeded. - Ok(Vec), + /// The roundtrip to protobuf succeeded. + Protobuf(Vec), + /// The roundtrip to json succeeded. + Json(String), + /// The data could not be encoded. This could indicate a bug in prost, + /// or it could indicate that the data was invalid (eg. violating message invariants). + EncodeError(anyhow::Error), /// The data could not be decoded. This could indicate a bug in prost, /// or it could indicate that the input was bogus. - DecodeError(prost::DecodeError), + DecodeError(anyhow::Error), /// Re-encoding or validating the data failed. This indicates a bug in `prost`. Error(anyhow::Error), } impl RoundtripResult { - /// Unwrap the roundtrip result. - pub fn unwrap(self) -> Vec { + pub fn unwrap(self) -> RoundtripOutput { match self { - RoundtripResult::Ok(buf) => buf, - RoundtripResult::DecodeError(error) => { + Self::Json(json) => RoundtripOutput::Json(json), + Self::Protobuf(data) => RoundtripOutput::Protobuf(data), + Self::EncodeError(error) => { + panic!("failed to encode the roundtrip data: {}", error) + } + Self::DecodeError(error) => { panic!("failed to decode the roundtrip data: {}", error) } - RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), + Self::Error(error) => panic!("failed roundtrip: {}", error), } } /// Unwrap the roundtrip result. Panics if the result was a validation or re-encoding error. - pub fn unwrap_error(self) -> Result, prost::DecodeError> { + pub fn unwrap_error(self) -> RoundtripResult { match self { - RoundtripResult::Ok(buf) => Ok(buf), - RoundtripResult::DecodeError(error) => Err(error), - RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), + Self::Error(error) => panic!("failed roundtrip: {}", error), + result => result, } } } +#[derive(Debug)] +pub enum RoundtripOutput { + /// The roundtrip to protobuf succeeded. + Protobuf(Vec), + /// The roundtrip to json succeeded. + Json(String), +} + /// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, /// otherwise the comparison may fail due to inconsistent `HashMap` entry encoding ordering. -pub fn roundtrip(data: &[u8]) -> RoundtripResult +pub fn roundtrip( + input: RoundtripInput<'_>, + output_ty: RoundtripOutputType, + #[cfg_attr(not(feature = "json"), allow(unused_variables))] ignore_unknown_fields: bool, +) -> RoundtripResult where - M: Message + Default, + M: Message + SerdeMessage + Default, { + #[cfg(feature = "json")] + let serializer_config = prost::serde::SerializerConfig::default(); + #[cfg(feature = "json")] + let deserializer_config = prost::serde::DeserializerConfigBuilder::default() + .ignore_unknown_fields(ignore_unknown_fields) + .ignore_unknown_enum_string_values(ignore_unknown_fields) + .build(); + // Try to decode a message from the data. If decoding fails, continue. - let all_types = match M::decode(data) { - Ok(all_types) => all_types, - Err(error) => return RoundtripResult::DecodeError(error), + let all_types = match input { + RoundtripInput::Protobuf(data) => match M::decode(data) { + Ok(all_types) => all_types, + Err(err) => return RoundtripResult::DecodeError(anyhow::Error::msg(err)), + }, + #[cfg(feature = "json")] + RoundtripInput::Json(data) => match deserializer_config.deserialize_from_str::(data) { + Ok(all_types) => all_types, + Err(err) => return RoundtripResult::DecodeError(err.into()), + }, + #[cfg(not(feature = "json"))] + RoundtripInput::Json(_) => unreachable!("enable the `json` feature for json rountrips"), }; - let encoded_len = all_types.encoded_len(); - - // TODO: Reenable this once sign-extension in negative int32s is figured out. - // assert!(encoded_len <= data.len(), "encoded_len: {}, len: {}, all_types: {:?}", - // encoded_len, data.len(), all_types); + let mid_protobuf; + #[cfg(feature = "json")] + let mid_json; + let mid_input = match output_ty { + RoundtripOutputType::Protobuf => { + mid_protobuf = all_types.encode_to_vec(); + + let encoded_len = all_types.encoded_len(); + + // TODO: Reenable this once sign-extension in negative int32s is figured out. + // if let RoundtripInput::Protobuf(data) = input { + // assert!( + // encoded_len <= data.len(), + // "encoded_len: {}, len: {}, all_types: {:?}", + // encoded_len, + // data.len(), + // all_types + // ); + // } + + if encoded_len != mid_protobuf.len() { + return RoundtripResult::Error(anyhow!( + "expected encoded len ({}) did not match actual encoded len ({})", + encoded_len, + mid_protobuf.len() + )); + } - let mut buf1 = Vec::new(); - if let Err(error) = all_types.encode(&mut buf1) { - return RoundtripResult::Error(anyhow!(error)); - } - let buf1 = buf1; - if encoded_len != buf1.len() { - return RoundtripResult::Error(anyhow!( - "expected encoded len ({}) did not match actual encoded len ({})", - encoded_len, - buf1.len() - )); - } + RoundtripInput::Protobuf(&mid_protobuf) + } + #[cfg(feature = "json")] + RoundtripOutputType::Json => { + mid_json = match serializer_config.with(&all_types).to_string() { + Ok(val) => val, + Err(err) => { + return if err.is_data() { + RoundtripResult::EncodeError(err.into()) + } else { + RoundtripResult::Error(err.into()) + } + } + }; + RoundtripInput::Json(&mid_json) + } + #[cfg(not(feature = "json"))] + RoundtripOutputType::Json => unreachable!("enable the `json` feature for json rountrips"), + }; - let roundtrip = match M::decode(buf1.as_slice()) { - Ok(roundtrip) => roundtrip, - Err(error) => return RoundtripResult::Error(anyhow!(error)), + let final_all_types = match mid_input { + RoundtripInput::Protobuf(data) => match M::decode(data) { + Ok(all_types) => all_types, + Err(err) => return RoundtripResult::DecodeError(anyhow::Error::msg(err)), + }, + #[cfg(feature = "json")] + RoundtripInput::Json(data) => match deserializer_config.deserialize_from_str::(data) { + Ok(all_types) => all_types, + Err(err) => { + return if err.is_data() { + RoundtripResult::EncodeError(err.into()) + } else { + RoundtripResult::Error(err.into()) + } + } + }, + #[cfg(not(feature = "json"))] + RoundtripInput::Json(_) => unreachable!("enable the `json` feature for json rountrips"), }; - let mut buf2 = Vec::new(); - if let Err(error) = roundtrip.encode(&mut buf2) { - return RoundtripResult::Error(anyhow!(error)); - } - let buf2 = buf2; - let buf3 = roundtrip.encode_to_vec(); - - /* - // Useful for debugging: - eprintln!(" data: {:?}", data.iter().map(|x| format!("0x{:x}", x)).collect::>()); - eprintln!(" buf1: {:?}", buf1.iter().map(|x| format!("0x{:x}", x)).collect::>()); - eprintln!("a: {:?}\nb: {:?}", all_types, roundtrip); - */ - - if buf1 != buf2 { - return RoundtripResult::Error(anyhow!("roundtripped encoded buffers do not match")); - } + match output_ty { + RoundtripOutputType::Protobuf => { + let encoded_len = final_all_types.encoded_len(); + + let encoded_1 = final_all_types.encode_to_vec(); + if encoded_1.len() != encoded_len { + return RoundtripResult::Error(anyhow!( + "expected encoded len ({}) did not match actual encoded len ({})", + encoded_len, + encoded_1.len() + )); + } + + let mut encoded_2 = alloc::vec![]; + if let Err(error) = final_all_types.encode(&mut encoded_2) { + return RoundtripResult::Error(anyhow::Error::msg(error)); + } + if encoded_2.len() != encoded_len { + return RoundtripResult::Error(anyhow!( + "expected encoded len ({}) did not match actual encoded len ({})", + encoded_len, + encoded_2.len() + )); + } - if buf1 != buf3 { - return RoundtripResult::Error(anyhow!( - "roundtripped encoded buffers do not match with `encode_to_vec`" - )); + if let RoundtripInput::Protobuf(mid_input) = mid_input { + if encoded_1 != mid_input { + return RoundtripResult::Error(anyhow!( + "roundtripped encoded buffers (1) do not match" + )); + } + if encoded_2 != mid_input { + return RoundtripResult::Error(anyhow!( + "roundtripped encoded buffers (2) do not match" + )); + } + } + + RoundtripResult::Protobuf(encoded_1) + } + #[cfg(feature = "json")] + RoundtripOutputType::Json => { + let json = match serializer_config.with(&final_all_types).to_string() { + Ok(val) => val, + Err(err) => { + return if err.is_data() { + RoundtripResult::EncodeError(err.into()) + } else { + RoundtripResult::Error(err.into()) + } + } + }; + RoundtripResult::Json(json) + } + #[cfg(not(feature = "json"))] + RoundtripOutputType::Json => unreachable!("enable the `json` feature for json rountrips"), } +} - RoundtripResult::Ok(buf1) +pub fn roundtrip_proto(input: &[u8]) -> RoundtripResult +where + M: Message + SerdeMessage + Default, +{ + roundtrip::( + RoundtripInput::Protobuf(input), + RoundtripOutputType::Protobuf, + false, + ) } /// Generic roundtrip serialization check for messages. @@ -302,7 +436,7 @@ mod tests { ]; for msg in msgs { - roundtrip::(msg).unwrap(); + roundtrip_proto::(msg).unwrap(); } } @@ -378,7 +512,7 @@ mod tests { let mut buf = Vec::new(); msg.encode(&mut buf).expect("encode"); - roundtrip::(&buf).unwrap(); + roundtrip_proto::(&buf).unwrap(); } #[test] diff --git a/tests/src/message_encoding.rs b/tests/src/message_encoding.rs index 002c5ca48..3fea233f0 100644 --- a/tests/src/message_encoding.rs +++ b/tests/src/message_encoding.rs @@ -316,7 +316,7 @@ fn check_default_values() { assert_eq!(&default.bytes_buf.as_ref(), b"foo\0bar"); assert_eq!(default.enumeration, BasicEnumeration::ONE as i32); assert_eq!(default.optional_enumeration, None); - assert_eq!(&default.repeated_enumeration, &[]); + assert_eq!(default.repeated_enumeration, &[] as &[i32]); assert_eq!(0, default.encoded_len()); }