Skip to content

Commit

Permalink
Use IsOkAndHolds in all-reduce combiner test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726896504
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 14, 2025
1 parent a16d96f commit 8ca669b
Showing 1 changed file with 22 additions and 28 deletions.
50 changes: 22 additions & 28 deletions xla/hlo/transforms/collectives/all_reduce_combiner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ limitations under the License.

#include "xla/hlo/transforms/collectives/all_reduce_combiner.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include <gmock/gmock.h>
Expand Down Expand Up @@ -118,9 +120,8 @@ TEST_F(AllReduceCombinerTest, CombineAllReduces) {
// Run the AllReduce combiner optimization pass.
AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
ASSERT_EQ(AllReduceCount(*module), 1);
EXPECT_TRUE(changed);

ASSERT_EQ(root, computation->root_instruction());
ASSERT_EQ(inputs.size(), root->operands().size());
Expand Down Expand Up @@ -166,10 +167,9 @@ TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) {
// Run the AllReduce combiner optimization pass.
AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
ASSERT_EQ(AllReduceCount(*module), 3)
<< "expects 3 groups for 3 reduction types.";
EXPECT_TRUE(changed);
}

// Tests that the combination threshold is respected.
Expand All @@ -188,19 +188,17 @@ TEST_F(AllReduceCombinerTest, RespectThreshold) {
{
AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), inputs.size());
EXPECT_FALSE(changed);
}

// Run the AllReduce combiner optimization pass again with a slightly
// higher threshold so that the combination can occur.
{
AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 1);
EXPECT_TRUE(changed);
}
}

Expand All @@ -226,9 +224,8 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

// Tests that AllReduce ops with different groups are not combined.
Expand All @@ -255,9 +252,8 @@ TEST_F(AllReduceCombinerTest, GroupAllReduce) {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DomainPreventsCombining) {
Expand All @@ -278,9 +274,11 @@ ENTRY entry {
crs1 = f32[128] all-reduce(param1),
replica_groups={}, to_apply=summit, sharding={maximal device=1}
domain0 = f32[128] domain(crs0),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}}
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}},
exit={maximal device=0}}
domain1 = f32[128] domain(crs1),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}}
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}},
exit={maximal device=1}}
ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1),
sharding={{maximal device=0}, {maximal device=1}}
}
Expand All @@ -291,9 +289,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

// This test checks that two CRS instructions that are in separate domains
Expand Down Expand Up @@ -336,9 +333,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 3);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_TRUE(changed);

// Verify that the sharding is combined correctly.
const HloInstruction* param0 =
Expand Down Expand Up @@ -375,9 +371,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DoNotCombineWithControlDependencies) {
Expand Down Expand Up @@ -453,9 +448,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 4);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_TRUE(changed);

EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Add(op::Domain(op::GetTupleElement(AllOf(
Expand Down Expand Up @@ -501,9 +495,8 @@ ENTRY %comp {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 6);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 4);
EXPECT_TRUE(changed);

auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1)));
auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)),
Expand All @@ -527,16 +520,17 @@ TEST_F(AllReduceCombinerTest, PreservesMetadata) {
ENTRY entry {
%param.0 = f32[32] parameter(0)
%param.1 = f32[32] parameter(1)
%all-reduce.0 = f32[32] all-reduce(%param.0), replica_groups={}, to_apply=%add, metadata={op_type="test_type0" op_name="test_name0"}
%all-reduce.1 = f32[32] all-reduce(%param.1), replica_groups={}, to_apply=%add, metadata={op_type="test_type1" op_name="test_name1"}
%all-reduce.0 = f32[32] all-reduce(%param.0), replica_groups={},
to_apply=%add, metadata={op_type="test_type0" op_name="test_name0"}
%all-reduce.1 = f32[32] all-reduce(%param.1), replica_groups={},
to_apply=%add, metadata={op_type="test_type1" op_name="test_name1"}
ROOT tuple = (f32[32], f32[32]) tuple(%all-reduce.0, %all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
OpMetadata metadata;
metadata.set_op_type("test_type0");
metadata.set_op_name("test_name0");
Expand Down

0 comments on commit 8ca669b

Please sign in to comment.