diff --git a/gcc/rust/Make-lang.in b/gcc/rust/Make-lang.in index 862e04eb6dc..f77d16e6008 100644 --- a/gcc/rust/Make-lang.in +++ b/gcc/rust/Make-lang.in @@ -98,6 +98,7 @@ GRS_OBJS = \ rust/rust-derive-copy.o \ rust/rust-derive-debug.o \ rust/rust-derive-default.o \ + rust/rust-derive-partial-eq.o \ rust/rust-derive-eq.o \ rust/rust-proc-macro.o \ rust/rust-macro-invoc-lexer.o \ diff --git a/gcc/rust/expand/rust-derive-partial-eq.cc b/gcc/rust/expand/rust-derive-partial-eq.cc new file mode 100644 index 00000000000..6f7ef7d8780 --- /dev/null +++ b/gcc/rust/expand/rust-derive-partial-eq.cc @@ -0,0 +1,308 @@ +// Copyright (C) 2020-2024 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-derive-partial-eq.h" +#include "rust-ast.h" +#include "rust-expr.h" +#include "rust-item.h" +#include "rust-operators.h" +#include "rust-path.h" +#include "rust-pattern.h" +#include "rust-system.h" + +namespace Rust { +namespace AST { +DerivePartialEq::DerivePartialEq (location_t loc) + : DeriveVisitor (loc), expanded (nullptr) +{} + +std::unique_ptr +DerivePartialEq::go (Item &item) +{ + item.accept_vis (*this); + + rust_assert (expanded); + + return std::move (expanded); +} + +std::unique_ptr +DerivePartialEq::partial_eq_impl ( + std::unique_ptr &&eq_fn, std::string name, + const std::vector> &type_generics) +{ + auto eq = builder.type_path (LangItem::Kind::EQ); + + auto trait_items = vec (std::move (eq_fn)); + + auto generics + = setup_impl_generics (name, type_generics, builder.trait_bound (eq)); + + return builder.trait_impl (eq, std::move (generics.self_type), + std::move (trait_items), + std::move (generics.impl)); +} + +std::unique_ptr +DerivePartialEq::eq_fn (std::unique_ptr &&cmp_expression, + std::string type_name) +{ + auto block = builder.block (tl::nullopt, std::move (cmp_expression)); + + auto self_type + = std::unique_ptr (new TypePath (builder.type_path ("Self"))); + + auto params + = vec (builder.self_ref_param (), + builder.function_param (builder.identifier_pattern ("other"), + builder.reference_type ( + std::move (self_type)))); + + return builder.function ("eq", std::move (params), + builder.single_type_path ("bool"), + std::move (block)); +} + +DerivePartialEq::SelfOther +DerivePartialEq::tuple_indexes (int idx) +{ + return SelfOther{ + builder.tuple_idx ("self", idx), + builder.tuple_idx ("other", idx), + }; +} + +DerivePartialEq::SelfOther +DerivePartialEq::field_acccesses (const std::string &field_name) +{ + return SelfOther{ + builder.field_access (builder.identifier ("self"), field_name), + builder.field_access (builder.identifier ("other"), field_name), + }; +} + +std::unique_ptr +DerivePartialEq::build_eq_expression ( + std::vector &&field_expressions) +{ + // for unit structs or empty tuples, this is always true + if (field_expressions.empty ()) + return builder.literal_bool (true); + + auto cmp_expression + = builder.comparison_expr (std::move (field_expressions.at (0).self_expr), + std::move (field_expressions.at (0).other_expr), + ComparisonOperator::EQUAL); + + for (size_t i = 1; i < field_expressions.size (); i++) + { + auto tmp = builder.comparison_expr ( + std::move (field_expressions.at (i).self_expr), + std::move (field_expressions.at (i).other_expr), + ComparisonOperator::EQUAL); + + cmp_expression + = builder.boolean_operation (std::move (cmp_expression), + std::move (tmp), + LazyBooleanOperator::LOGICAL_AND); + } + + return cmp_expression; +} + +void +DerivePartialEq::visit_tuple (TupleStruct &item) +{ + auto type_name = item.get_struct_name ().as_string (); + auto fields = std::vector (); + + for (size_t idx = 0; idx < item.get_fields ().size (); idx++) + fields.emplace_back (tuple_indexes (idx)); + + auto fn = eq_fn (build_eq_expression (std::move (fields)), type_name); + + expanded + = partial_eq_impl (std::move (fn), type_name, item.get_generic_params ()); +} + +void +DerivePartialEq::visit_struct (StructStruct &item) +{ + auto type_name = item.get_struct_name ().as_string (); + auto fields = std::vector (); + + for (auto &field : item.get_fields ()) + fields.emplace_back ( + field_acccesses (field.get_field_name ().as_string ())); + + auto fn = eq_fn (build_eq_expression (std::move (fields)), type_name); + + expanded + = partial_eq_impl (std::move (fn), type_name, item.get_generic_params ()); +} + +MatchCase +DerivePartialEq::match_enum_identifier ( + PathInExpression variant_path, const std::unique_ptr &variant) +{ + auto inner_ref_patterns + = vec (builder.ref_pattern ( + std::unique_ptr (new PathInExpression (variant_path))), + builder.ref_pattern ( + std::unique_ptr (new PathInExpression (variant_path)))); + + auto tuple_items = std::make_unique ( + std::move (inner_ref_patterns)); + + auto pattern = std::make_unique (std::move (tuple_items), loc); + + return builder.match_case (std::move (pattern), builder.literal_bool (true)); +} + +MatchCase +DerivePartialEq::match_enum_tuple (PathInExpression variant_path, + const EnumItemTuple &variant) +{ + auto self_patterns = std::vector> (); + auto other_patterns = std::vector> (); + + auto self_other_exprs = std::vector (); + + for (size_t i = 0; i < variant.get_tuple_fields ().size (); i++) + { + // The patterns we're creating for each field are `self_` and + // `other_` where `i` is the index of the field. It doesn't actually + // matter what we use, as long as it's ordered, unique, and that we can + // reuse it in the match case's return expression to check that they are + // equal. + + auto self_pattern_str = "__self_" + std::to_string (i); + auto other_pattern_str = "__other_" + std::to_string (i); + + rust_debug ("]ARTHUR[ %s", self_pattern_str.c_str ()); + + self_patterns.emplace_back ( + builder.identifier_pattern (self_pattern_str)); + other_patterns.emplace_back ( + builder.identifier_pattern (other_pattern_str)); + + self_other_exprs.emplace_back (SelfOther{ + builder.identifier (self_pattern_str), + builder.identifier (other_pattern_str), + }); + } + + auto self_pattern_items = std::unique_ptr ( + new TupleStructItemsNoRange (std::move (self_patterns))); + auto other_pattern_items = std::unique_ptr ( + new TupleStructItemsNoRange (std::move (other_patterns))); + + auto self_pattern = std::unique_ptr ( + new ReferencePattern (std::unique_ptr (new TupleStructPattern ( + variant_path, std::move (self_pattern_items))), + false, false, loc)); + auto other_pattern = std::unique_ptr ( + new ReferencePattern (std::unique_ptr (new TupleStructPattern ( + variant_path, std::move (other_pattern_items))), + false, false, loc)); + + auto tuple_items = std::make_unique ( + vec (std::move (self_pattern), std::move (other_pattern))); + + auto pattern = std::make_unique (std::move (tuple_items), loc); + + auto expr = build_eq_expression (std::move (self_other_exprs)); + + return builder.match_case (std::move (pattern), std::move (expr)); +} + +MatchCase +DerivePartialEq::match_enum_struct (PathInExpression variant_path, + const EnumItemStruct &variant) +{ + // NOTE: We currently do not support compiling struct patterns where an + // identifier is assigned a new pattern, e.g. Bloop { f0: x } + // This is what we should be using to compile PartialEq for enum struct + // variants, as we need to be comparing the field of each instance meaning we + // need to give two different names to two different instances of the same + // field. We cannot just use the field's name like we do when deriving + // `Clone`. + + rust_unreachable (); +} + +void +DerivePartialEq::visit_enum (Enum &item) +{ + auto cases = std::vector (); + + for (auto &variant : item.get_variants ()) + { + auto variant_path + = builder.variant_path (item.get_identifier ().as_string (), + variant->get_identifier ().as_string ()); + + switch (variant->get_enum_item_kind ()) + { + case EnumItem::Kind::Identifier: + case EnumItem::Kind::Discriminant: + cases.emplace_back (match_enum_identifier (variant_path, variant)); + break; + case EnumItem::Kind::Tuple: + cases.emplace_back ( + match_enum_tuple (variant_path, + static_cast (*variant))); + break; + case EnumItem::Kind::Struct: + rust_sorry_at ( + item.get_locus (), + "cannot derive(PartialEq) for enum struct variants yet"); + break; + } + } + + // NOTE: Mention using discriminant_value and skipping that last case, and + // instead skipping all identifiers/discriminant enum items and returning + // `true` in the wildcard case + + // In case the two instances of `Self` don't have the same discriminant, + // automatically return false. + cases.emplace_back ( + builder.match_case (builder.wildcard (), builder.literal_bool (false))); + + auto match + = builder.match (builder.tuple (vec (builder.identifier ("self"), + builder.identifier ("other"))), + std::move (cases)); + + auto fn = eq_fn (std::move (match), item.get_identifier ().as_string ()); + + expanded + = partial_eq_impl (std::move (fn), item.get_identifier ().as_string (), + item.get_generic_params ()); +} + +void +DerivePartialEq::visit_union (Union &item) +{ + rust_error_at (item.get_locus (), + "derive(PartialEq) cannot be used on unions"); +} + +} // namespace AST +} // namespace Rust diff --git a/gcc/rust/expand/rust-derive-partial-eq.h b/gcc/rust/expand/rust-derive-partial-eq.h new file mode 100644 index 00000000000..2bc18d2b98a --- /dev/null +++ b/gcc/rust/expand/rust-derive-partial-eq.h @@ -0,0 +1,81 @@ +// Copyright (C) 2025 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_DERIVE_PARTIAL_EQ_H +#define RUST_DERIVE_PARTIAL_EQ_H + +#include "rust-derive.h" +#include "rust-path.h" + +namespace Rust { +namespace AST { + +class DerivePartialEq : DeriveVisitor +{ +public: + DerivePartialEq (location_t loc); + + std::unique_ptr go (Item &item); + +private: + std::unique_ptr expanded; + + std::unique_ptr partial_eq_impl ( + std::unique_ptr &&eq_fn, std::string name, + const std::vector> &type_generics); + + std::unique_ptr eq_fn (std::unique_ptr &&cmp_expression, + std::string type_name); + + /** + * A pair of two expressions from each instance being compared. E.g. this + * could be `self.0` and `other.0`, or `self.field` and `other.field` + */ + struct SelfOther + { + std::unique_ptr self_expr; + std::unique_ptr other_expr; + }; + + SelfOther tuple_indexes (int idx); + SelfOther field_acccesses (const std::string &field_name); + + /** + * Build a suite of equality arithmetic expressions chained together by a + * boolean AND operator + */ + std::unique_ptr + build_eq_expression (std::vector &&field_expressions); + + MatchCase match_enum_identifier (PathInExpression variant_path, + const std::unique_ptr &variant); + MatchCase match_enum_tuple (PathInExpression variant_path, + const EnumItemTuple &variant); + MatchCase match_enum_struct (PathInExpression variant_path, + const EnumItemStruct &variant); + + virtual void visit_struct (StructStruct &item); + virtual void visit_tuple (TupleStruct &item); + virtual void visit_enum (Enum &item); + virtual void visit_union (Union &item); +}; + +} // namespace AST +} // namespace Rust + +#endif // ! RUST_DERIVE_PARTIAL_EQ_H diff --git a/gcc/rust/expand/rust-derive.cc b/gcc/rust/expand/rust-derive.cc index 8226a61a787..6f026f3da64 100644 --- a/gcc/rust/expand/rust-derive.cc +++ b/gcc/rust/expand/rust-derive.cc @@ -22,6 +22,7 @@ #include "rust-derive-debug.h" #include "rust-derive-default.h" #include "rust-derive-eq.h" +#include "rust-derive-partial-eq.h" namespace Rust { namespace AST { @@ -51,6 +52,7 @@ DeriveVisitor::derive (Item &item, const Attribute &attr, case BuiltinMacro::Eq: return DeriveEq (attr.get_locus ()).go (item); case BuiltinMacro::PartialEq: + return DerivePartialEq (attr.get_locus ()).go (item); case BuiltinMacro::Ord: case BuiltinMacro::PartialOrd: case BuiltinMacro::Hash: diff --git a/gcc/testsuite/rust/compile/derive-eq-invalid.rs b/gcc/testsuite/rust/compile/derive-eq-invalid.rs index 017241db86d..0c4d48ef6ea 100644 --- a/gcc/testsuite/rust/compile/derive-eq-invalid.rs +++ b/gcc/testsuite/rust/compile/derive-eq-invalid.rs @@ -1,5 +1,6 @@ mod core { mod cmp { + #[lang = "eq"] pub trait PartialEq { fn eq(&self, other: &Rhs) -> bool; @@ -14,11 +15,11 @@ mod core { } } -// #[lang = "phantom_data"] -// struct PhantomData; +#[lang = "phantom_data"] +struct PhantomData; -// #[lang = "sized"] -// trait Sized {} +#[lang = "sized"] +trait Sized {} #[derive(PartialEq)] struct NotEq; diff --git a/gcc/testsuite/rust/compile/derive-partialeq1.rs b/gcc/testsuite/rust/compile/derive-partialeq1.rs new file mode 100644 index 00000000000..71513241929 --- /dev/null +++ b/gcc/testsuite/rust/compile/derive-partialeq1.rs @@ -0,0 +1,59 @@ +#![feature(intrinsics)] + +#[lang = "sized"] +trait Sized {} + +#[lang = "copy"] +trait Copy {} + +#[lang = "eq"] +pub trait PartialEq { + /// This method tests for `self` and `other` values to be equal, and is used + /// by `==`. + #[must_use] + #[stable(feature = "rust1", since = "1.0.0")] + fn eq(&self, other: &Rhs) -> bool; + + /// This method tests for `!=`. + #[inline] + #[must_use] + #[stable(feature = "rust1", since = "1.0.0")] + fn ne(&self, other: &Rhs) -> bool { + !self.eq(other) + } +} + +#[derive(PartialEq, Copy)] // { dg-warning "unused name" } +struct Foo; + +#[derive(PartialEq)] +struct Bar(Foo); + +#[derive(PartialEq)] +struct Baz { _inner: Foo } + +extern "C" { + fn puts(s: *const i8); +} + +fn print(b: bool) { + if b { + unsafe { puts("true" as *const str as *const i8) } + } else { + unsafe { puts("false" as *const str as *const i8) } + } +} + +fn main() -> i32 { + let x = Foo; + + let b1 = x == Foo; + let b2 = Bar(x) != Bar(Foo); + let b3 = Baz { _inner: Foo } != Baz { _inner: x }; + + print(b1); + print(b2); + print(b3); + + 0 +} diff --git a/gcc/testsuite/rust/compile/nr2/exclude b/gcc/testsuite/rust/compile/nr2/exclude index 1b34e9fe20e..222aed53111 100644 --- a/gcc/testsuite/rust/compile/nr2/exclude +++ b/gcc/testsuite/rust/compile/nr2/exclude @@ -124,6 +124,8 @@ traits12.rs try-trait.rs derive-debug1.rs issue-3382.rs -derive-default1.rs issue-3402-1.rs +derive-default1.rs +derive-eq-invalid.rs +derive-partialeq1.rs # please don't delete the trailing newline diff --git a/gcc/testsuite/rust/execute/torture/derive-partialeq1.rs b/gcc/testsuite/rust/execute/torture/derive-partialeq1.rs new file mode 100644 index 00000000000..4d5124e85cf --- /dev/null +++ b/gcc/testsuite/rust/execute/torture/derive-partialeq1.rs @@ -0,0 +1,61 @@ +// { dg-output "true\r*\nfalse\r*\nfalse\r*\n" } + +#![feature(intrinsics)] + +#[lang = "sized"] +trait Sized {} + +#[lang = "copy"] +trait Copy {} + +#[lang = "eq"] +pub trait PartialEq { + /// This method tests for `self` and `other` values to be equal, and is used + /// by `==`. + #[must_use] + #[stable(feature = "rust1", since = "1.0.0")] + fn eq(&self, other: &Rhs) -> bool; + + /// This method tests for `!=`. + #[inline] + #[must_use] + #[stable(feature = "rust1", since = "1.0.0")] + fn ne(&self, other: &Rhs) -> bool { + !self.eq(other) + } +} + +#[derive(PartialEq, Copy)] // { dg-warning "unused name" } +struct Foo; + +#[derive(PartialEq)] +struct Bar(Foo); + +#[derive(PartialEq)] +struct Baz { _inner: Foo } + +extern "C" { + fn puts(s: *const i8); +} + +fn print(b: bool) { + if b { + unsafe { puts("true" as *const str as *const i8) } + } else { + unsafe { puts("false" as *const str as *const i8) } + } +} + +fn main() -> i32 { + let x = Foo; + + let b1 = x == Foo; + let b2 = Bar(x) != Bar(Foo); + let b3 = Baz { _inner: Foo } != Baz { _inner: x }; + + print(b1); + print(b2); + print(b3); + + 0 +}