Do not combine all-reduces with control dependencies.
Combining them would transfer these control dependencies to a get-tuple-element op, which triggers an assertion when the original all-reduce is replaced (sketched below).

PiperOrigin-RevId: 726624437
frgossen authored and Google-ML-Automation committed Feb 13, 2025
1 parent bd67ea2 commit 4b352b7
Showing 4 changed files with 93 additions and 0 deletions.
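For context, the combiner rewrites a group of all-reduces into a single tuple-shaped all-reduce and replaces each original with a get-tuple-element of that result. Below is a rough sketch of what combining ar0 and ar1 from the new test would have required; the names combined, gte0, and gte1 are hypothetical, attributes such as channel ids are elided, and this is not actual output of the pass:

  // Originals, each carrying a control dependency on lead_ar:
  ar0 = f32[128] all-reduce(lead_ar), replica_groups={{0}}, to_apply=add, control-predecessors={lead_ar}
  ar1 = f32[128] all-reduce(param1), replica_groups={{0}}, to_apply=add, control-predecessors={lead_ar}

  // After combining, ar0 and ar1 would be replaced by:
  combined = (f32[128], f32[128]) all-reduce(lead_ar, param1), replica_groups={{0}}, to_apply=add
  gte0 = f32[128] get-tuple-element(combined), index=0
  gte1 = f32[128] get-tuple-element(combined), index=1

The control-predecessors of ar0 and ar1 would have to move onto gte0 and gte1, which the replacement path asserts against, so all-reduces with control dependencies are now excluded from combining.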
1 change: 1 addition & 0 deletions xla/hlo/transforms/collectives/BUILD
@@ -309,6 +309,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/utils:hlo_matchers",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings:string_view",
41 changes: 41 additions & 0 deletions xla/hlo/transforms/collectives/all_reduce_combiner_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"

@@ -41,7 +42,9 @@ namespace {

using std::nullopt;
using ::testing::AllOf;
using tsl::testing::IsOkAndHolds;
namespace op = xla::testing::opcode_matchers;

int64_t kMaxCombineCount = 256;

int64_t AllReduceCount(const HloModule& module) {
@@ -377,6 +380,44 @@ ENTRY entry {
EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DoNotCombineWithControlDependencies) {
const char* const hlo_string = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0)
param1 = f32[128] parameter(1)
// This all-reduce must happen first; the control dependencies below enforce
// this ordering and must be respected.
lead_ar = f32[128] all-reduce(param0), replica_groups={{0}}, to_apply=add,
channel_id=1
// These all-reduces have control dependencies and must not be combined.
ar0 = f32[128] all-reduce(lead_ar),
replica_groups={{0}}, to_apply=add, channel_id=2,
control-predecessors={lead_ar}
ar1 = f32[128] all-reduce(param1),
replica_groups={{0}}, to_apply=add, channel_id=3,
control-predecessors={lead_ar}
ROOT tuple = (f32[128], f32[128]) tuple(ar0, ar1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 3);
ASSERT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 3);
}

TEST_F(AllReduceCombinerTest, CrossCoreAllReduce) {
const char* const hlo_string = R"(
HloModule Module
8 changes: 8 additions & 0 deletions xla/service/all_reduce_key.cc
@@ -34,6 +34,14 @@ namespace xla {
std::optional<AllReduceKey> GetAllReduceKey(const HloInstruction* instruction,
const HloDomainMap* domain_map,
bool ignore_replica_groups) {
// TODO(b/396147741): Support all-reduce combining with control dependencies.
// Currently, this would crash when replacing the original all-reduces. Such
// control dependencies would need to be transferred to the new combined
// all-reduce.
if (instruction->HasControlDependencies()) {
return std::nullopt;
}

if (instruction->opcode() != HloOpcode::kAllReduce &&
instruction->opcode() != HloOpcode::kReduceScatter) {
return std::nullopt;
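The practical effect of the early return is that such instructions never participate in combining: callers of GetAllReduceKey only treat two all-reduces as combinable when both produce a key and the keys compare equal. The following minimal sketch illustrates that pattern; the helper name CanCombine and the include paths are assumptions, it relies on AllReduceKey (a tuple alias) being equality-comparable, and it is not the actual combiner code:

  #include <optional>

  #include "xla/hlo/ir/hlo_instruction.h"
  #include "xla/service/all_reduce_key.h"

  namespace xla {

  // Illustrative helper, not part of the real pass: two all-reduces are
  // combinable only if both yield an AllReduceKey and the keys match. With the
  // change above, an instruction with control dependencies yields std::nullopt
  // and therefore never matches anything.
  bool CanCombine(const HloInstruction* a, const HloInstruction* b,
                  const HloDomainMap* domain_map) {
    std::optional<AllReduceKey> key_a =
        GetAllReduceKey(a, domain_map, /*ignore_replica_groups=*/false);
    std::optional<AllReduceKey> key_b =
        GetAllReduceKey(b, domain_map, /*ignore_replica_groups=*/false);
    return key_a.has_value() && key_b.has_value() && *key_a == *key_b;
  }

  }  // namespace xla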
43 changes: 43 additions & 0 deletions xla/service/gpu/transforms/collectives/all_reduce_combiner_test.cc
@@ -388,6 +388,49 @@ TEST_F(GpuAllReduceCombinerTest, CombinesSynchronousCollectivesMaximally) {
op::GetTupleElement(combined_all_reduce, 1)));
}

TEST_F(GpuAllReduceCombinerTest,
DoNotCombineCollectivesWithControlDependencies) {
absl::string_view kHloText = R"(
HloModule m
add {
p0 = f16[] parameter(0)
p1 = f16[] parameter(1)
ROOT add = f16[] add(p0, p1)
}
ENTRY main {
p0 = f16[10000000]{0} parameter(0)
p1 = f16[10000000]{0} parameter(1)
// This all-reduce must happen first; the control dependencies below enforce
// this ordering and must be respected.
lead_ar = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add
// These all-reduces have control dependencies and must not be combined.
ar0 = f16[10000000]{0} all-reduce(p0), replica_groups={}, to_apply=add,
control-predecessors={lead_ar}
ar1 = f16[10000000]{0} all-reduce(p1), replica_groups={}, to_apply=add,
control-predecessors={lead_ar}
ROOT result = tuple(ar0, ar1)
}
)";
DeviceDescription device_info;
device_info.set_device_memory_size(10000000000); // 10GB

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
GpuAllReduceCombiner combiner(
device_info, /*default_combine_threshold_in_bytes=*/
kDefaultAllReduceCombineThreshold,
/*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold,
/*combine_threshold_count=*/256, /*pointer_size=*/4);

module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_experimental_enable_sync_collective_combining(true);
EXPECT_THAT(combiner.Run(module.get()), IsOkAndHolds(false));
}

} // namespace

} // namespace xla::gpu
