Skip to content

Commit

Permalink
feat: derive Eq and Hash trait for messages where possible
Browse files Browse the repository at this point in the history
Integer and bytes types can be compared using trait Eq. Some generated Rust structs can also have this property by deriving the Eq trait.

Automatically derive Eq and Hash for:
- messages that only have fields with integer or bytes types
- messages where all field types also implement Eq and Hash
- the Rust enum for one-of fields, where all fields implement Eq and Hash

Generated code for Protobuf enums already derives Eq and Hash.

BREAKING CHANGE: `prost-build` will automatically derive `trait Eq` and `trait Hash` for types where all field support those as well. If you manually `impl Eq` and/or `impl Hash` for generated types, then you need to remove the manual implementation. If you use `type_attribute` to `derive(Eq)` and/or `derive(Hash)`, then you need to remove those.
  • Loading branch information
caspermeijn committed Oct 23, 2024
1 parent 86f87a2 commit 3562a58
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 45 deletions.
14 changes: 12 additions & 2 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,17 @@ impl<'a> CodeGenerator<'a> {
self.append_message_attributes(&fq_message_name);
self.push_indent();
self.buf.push_str(&format!(
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
"#[derive(Clone, {}PartialEq, {}{}::Message)]\n",
if self.message_graph.can_message_derive_copy(&fq_message_name) {
"Copy, "
} else {
""
},
if self.message_graph.can_message_derive_eq(&fq_message_name) {
"Eq, Hash, "
} else {
""
},
prost_path(self.config)
));
self.append_skip_debug(&fq_message_name);
Expand Down Expand Up @@ -619,9 +624,14 @@ impl<'a> CodeGenerator<'a> {
self.message_graph
.can_field_derive_copy(fq_message_name, &field.descriptor)
});
let can_oneof_derive_eq = oneof.fields.iter().all(|field| {
self.message_graph
.can_field_derive_eq(fq_message_name, &field.descriptor)
});
self.buf.push_str(&format!(
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
"#[derive(Clone, {}PartialEq, {}{}::Oneof)]\n",
if can_oneof_derive_copy { "Copy, " } else { "" },
if can_oneof_derive_eq { "Eq, Hash, " } else { "" },
prost_path(self.config)
));
self.append_skip_debug(fq_message_name);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
// This file is @generated by prost-build.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Container {
#[prost(oneof = "container::Data", tags = "1, 2")]
pub data: ::core::option::Option<container::Data>,
}
/// Nested message and enum types in `Container`.
pub mod container {
#[derive(Clone, PartialEq, ::prost::Oneof)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
pub enum Data {
#[prost(message, tag = "1")]
Foo(::prost::alloc::boxed::Box<super::Foo>),
#[prost(message, tag = "2")]
Bar(super::Bar),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Foo {
#[prost(string, tag = "1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Qux {}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// This file is @generated by prost-build.
#[derive(derive_builder::Builder)]
#[derive(custom_proto::Input)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Message {
#[prost(string, tag = "1")]
pub say: ::prost::alloc::string::String,
}
#[derive(derive_builder::Builder)]
#[derive(custom_proto::Output)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Response {
#[prost(string, tag = "1")]
pub say: ::prost::alloc::string::String,
Expand Down
43 changes: 43 additions & 0 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,47 @@ impl MessageGraph {
)
}
}

/// Returns `true` if this message can automatically derive Eq trait.
pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);

let msg = self.messages.get(fq_message_name).unwrap();
msg.field
.iter()
.all(|field| self.can_field_derive_eq(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Eq trait.
pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.r#type() == Type::Message {
if field.label() == Label::Repeated {
false
} else if self.is_nested(field.type_name(), fq_message_name) {
false
} else {
self.can_message_derive_eq(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
| Type::String
| Type::Bytes
)
}
}
}
2 changes: 1 addition & 1 deletion prost-types/src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is @generated by prost-build.
/// The version number of protocol compiler.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Version {
#[prost(int32, optional, tag = "1")]
pub major: ::core::option::Option<i32>,
Expand Down
8 changes: 0 additions & 8 deletions prost-types/src/duration.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
use super::*;

#[cfg(feature = "std")]
impl std::hash::Hash for Duration {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.seconds.hash(state);
self.nanos.hash(state);
}
}

impl Duration {
/// Normalizes the duration to a canonical format.
///
Expand Down
24 changes: 12 additions & 12 deletions prost-types/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub mod descriptor_proto {
/// Range of reserved tag numbers. Reserved tag numbers may not be used by
/// fields or extension ranges in the same message. Reserved ranges may
/// not overlap.
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct ReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -350,7 +350,7 @@ pub mod enum_descriptor_proto {
/// Note that this is distinct from DescriptorProto.ReservedRange in that it
/// is inclusive such that it can appropriately represent the entire int32
/// domain.
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct EnumReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -961,7 +961,7 @@ pub mod uninterpreted_option {
/// extension (denoted with parentheses in options specs in .proto files).
/// E.g.,{ \["foo", false\], \["bar.baz", true\], \["qux", false\] } represents
/// "foo.(bar.baz).qux".
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct NamePart {
#[prost(string, required, tag = "1")]
pub name_part: ::prost::alloc::string::String,
Expand Down Expand Up @@ -1022,7 +1022,7 @@ pub struct SourceCodeInfo {
}
/// Nested message and enum types in `SourceCodeInfo`.
pub mod source_code_info {
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Location {
/// Identifies which part of the FileDescriptorProto was defined at this
/// location.
Expand Down Expand Up @@ -1125,7 +1125,7 @@ pub struct GeneratedCodeInfo {
}
/// Nested message and enum types in `GeneratedCodeInfo`.
pub mod generated_code_info {
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Annotation {
/// Identifies the element in the original source .proto file. This field
/// is formatted the same as SourceCodeInfo.Location.path.
Expand Down Expand Up @@ -1238,7 +1238,7 @@ pub mod generated_code_info {
/// "value": "1.212s"
/// }
/// ```
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Any {
/// A URL/resource name that uniquely identifies the type of the serialized
/// protocol buffer message. This string must contain at least
Expand Down Expand Up @@ -1275,7 +1275,7 @@ pub struct Any {
}
/// `SourceContext` represents information about the source of a
/// protobuf element, like the file in which it is defined.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct SourceContext {
/// The path-qualified name of the .proto file that contained the associated
/// protobuf element. For example: `"google/protobuf/source_context.proto"`.
Expand Down Expand Up @@ -1531,7 +1531,7 @@ pub struct EnumValue {
}
/// A protocol buffer option, which can be attached to a message, field,
/// enumeration, etc.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Option {
/// The option's name. For protobuf built-in options (options defined in
/// descriptor.proto), this is the short name. For example, `"map_entry"`.
Expand Down Expand Up @@ -1741,7 +1741,7 @@ pub struct Method {
/// ...
/// }
/// ```
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Mixin {
/// The fully qualified name of the interface which is included.
#[prost(string, tag = "1")]
Expand Down Expand Up @@ -1815,7 +1815,7 @@ pub struct Mixin {
/// encoded in JSON format as "3s", while 3 seconds and 1 nanosecond should
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
/// microsecond should be expressed in JSON format as "3.000001s".
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Duration {
/// Signed seconds of the span of time. Must be from -315,576,000,000
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
Expand Down Expand Up @@ -2053,7 +2053,7 @@ pub struct Duration {
/// The implementation of any API method which has a FieldMask type field in the
/// request should verify the included field paths, and return an
/// `INVALID_ARGUMENT` error if any path is unmappable.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct FieldMask {
/// The set of field mask paths.
#[prost(string, repeated, tag = "1")]
Expand Down Expand Up @@ -2249,7 +2249,7 @@ impl NullValue {
/// [`strftime`](<https://docs.python.org/2/library/time.html#time.strftime>) with
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Timestamp {
/// Represents seconds of UTC time since Unix epoch
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
Expand Down
13 changes: 0 additions & 13 deletions prost-types/src/timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,6 @@ impl Name for Timestamp {
}
}

/// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`.
/// This implies that `normalized_ts != non_normalized_ts` even if `normalized_ts == non_normalized_ts.normalized()`.
#[cfg(feature = "std")]
impl Eq for Timestamp {}

#[cfg(feature = "std")]
impl std::hash::Hash for Timestamp {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.seconds.hash(state);
self.nanos.hash(state);
}
}

#[cfg(feature = "std")]
impl From<std::time::SystemTime> for Timestamp {
fn from(system_time: std::time::SystemTime) -> Timestamp {
Expand Down
2 changes: 1 addition & 1 deletion tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn main() {
config.type_attribute("Foo.Custom.Attrs.AnotherEnum", "/// Oneof docs");
config.type_attribute(
"Foo.Custom.OneOfAttrs.Msg.field",
"#[derive(Eq, PartialOrd, Ord)]",
"#[derive(PartialOrd, Ord)]",
);
config.field_attribute("Foo.Custom.Attrs.AnotherEnum.C", "/// The C docs");
config.field_attribute("Foo.Custom.Attrs.AnotherEnum.D", "/// The D docs");
Expand Down
2 changes: 1 addition & 1 deletion tests/single-include/src/outdir/outdir.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// This file is @generated by prost-build.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct OutdirRequest {
#[prost(string, tag = "1")]
pub query: ::prost::alloc::string::String,
Expand Down

0 comments on commit 3562a58

Please sign in to comment.