Skip to content

Commit

Permalink
derive(Eq): Also derive StructuralEq
Browse files Browse the repository at this point in the history
gcc/rust/ChangeLog:

	* expand/rust-derive-eq.cc: Adapt functions to return two generated impls.
	* expand/rust-derive-eq.h: Likewise.
	* expand/rust-derive.cc (DeriveVisitor::derive): Likewise.
  • Loading branch information
CohenArthur committed Feb 6, 2025
1 parent bff0ff4 commit b9c5bdf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 35 deletions.
64 changes: 35 additions & 29 deletions gcc/rust/expand/rust-derive-eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
namespace Rust {
namespace AST {

DeriveEq::DeriveEq (location_t loc) : DeriveVisitor (loc) {}

std::vector<std::unique_ptr<AST::Item>>
DeriveEq::go (Item &item)
{
item.accept_vis (*this);

return std::move (expanded);
}

std::unique_ptr<AssociatedItem>
DeriveEq::assert_receiver_is_total_eq_fn (
std::vector<std::unique_ptr<Type>> &&types)
Expand Down Expand Up @@ -98,33 +108,29 @@ DeriveEq::assert_type_is_eq (std::unique_ptr<Type> &&type)
return builder.let (builder.wildcard (), std::move (full_path));
}

std::unique_ptr<Item>
DeriveEq::eq_impl (
std::vector<std::unique_ptr<Item>>
DeriveEq::eq_impls (
std::unique_ptr<AssociatedItem> &&fn, std::string name,
const std::vector<std::unique_ptr<GenericParam>> &type_generics)
{
auto eq = builder.type_path ({"core", "cmp", "Eq"}, true);
auto steq = builder.type_path (LangItem::Kind::STRUCTURAL_TEQ);

auto trait_items = vec (std::move (fn));

auto generics
auto eq_generics
= setup_impl_generics (name, type_generics, builder.trait_bound (eq));
auto steq_generics = setup_impl_generics (name, type_generics);

return builder.trait_impl (eq, std::move (generics.self_type),
std::move (trait_items),
std::move (generics.impl));
}

DeriveEq::DeriveEq (location_t loc) : DeriveVisitor (loc), expanded (nullptr) {}

std::unique_ptr<AST::Item>
DeriveEq::go (Item &item)
{
item.accept_vis (*this);
auto eq_impl = builder.trait_impl (eq, std::move (eq_generics.self_type),
std::move (trait_items),
std::move (eq_generics.impl));
auto steq_impl
= builder.trait_impl (steq, std::move (steq_generics.self_type),
std::move (trait_items),
std::move (steq_generics.impl));

rust_assert (expanded);

return std::move (expanded);
return vec (std::move (eq_impl), std::move (steq_impl));
}

void
Expand All @@ -135,9 +141,9 @@ DeriveEq::visit_tuple (TupleStruct &item)
for (auto &field : item.get_fields ())
types.emplace_back (field.get_field_type ().clone_type ());

expanded
= eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (), item.get_generic_params ());
expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (),
item.get_generic_params ());
}

void
Expand All @@ -148,9 +154,9 @@ DeriveEq::visit_struct (StructStruct &item)
for (auto &field : item.get_fields ())
types.emplace_back (field.get_field_type ().clone_type ());

expanded
= eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (), item.get_generic_params ());
expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (),
item.get_generic_params ());
}

void
Expand Down Expand Up @@ -185,9 +191,9 @@ DeriveEq::visit_enum (Enum &item)
}
}

expanded
= eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (), item.get_generic_params ());
expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (),
item.get_generic_params ());
}

void
Expand All @@ -198,9 +204,9 @@ DeriveEq::visit_union (Union &item)
for (auto &field : item.get_variants ())
types.emplace_back (field.get_field_type ().clone_type ());

expanded
= eq_impl (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (), item.get_generic_params ());
expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
item.get_identifier ().as_string (),
item.get_generic_params ());
}

} // namespace AST
Expand Down
10 changes: 5 additions & 5 deletions gcc/rust/expand/rust-derive-eq.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class DeriveEq : DeriveVisitor
public:
DeriveEq (location_t loc);

std::unique_ptr<AST::Item> go (Item &item);
std::vector<std::unique_ptr<AST::Item>> go (Item &item);

private:
std::unique_ptr<Item> expanded;
std::vector<std::unique_ptr<Item>> expanded;

/**
* Create the actual `assert_receiver_is_total_eq` function of the
Expand All @@ -52,9 +52,9 @@ class DeriveEq : DeriveVisitor
* }
*
*/
std::unique_ptr<Item>
eq_impl (std::unique_ptr<AssociatedItem> &&fn, std::string name,
const std::vector<std::unique_ptr<GenericParam>> &type_generics);
std::vector<std::unique_ptr<Item>>
eq_impls (std::unique_ptr<AssociatedItem> &&fn, std::string name,
const std::vector<std::unique_ptr<GenericParam>> &type_generics);

/**
* Generate the following structure definition
Expand Down
2 changes: 1 addition & 1 deletion gcc/rust/expand/rust-derive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ DeriveVisitor::derive (Item &item, const Attribute &attr,
case BuiltinMacro::Default:
return vec (DeriveDefault (attr.get_locus ()).go (item));
case BuiltinMacro::Eq:
return vec (DeriveEq (attr.get_locus ()).go (item));
return DeriveEq (attr.get_locus ()).go (item);
case BuiltinMacro::PartialEq:
return DerivePartialEq (attr.get_locus ()).go (item);
case BuiltinMacro::Ord:
Expand Down

0 comments on commit b9c5bdf

Please sign in to comment.