From e7e6f6acf20ac441839ecf6db48906f074950f37 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 12 Feb 2025 06:37:31 -0800 Subject: [PATCH] [XLA:GPU] Add AllGather support to collective generation tool. PiperOrigin-RevId: 726033258 --- xla/tools/collective_perf_table_gen.cc | 25 ++++++++++++++++++--- xla/tools/collective_perf_table_gen.h | 1 + xla/tools/collective_perf_table_gen_main.cc | 7 ++++-- xla/tools/collective_perf_table_gen_test.cc | 14 ++++++++---- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/xla/tools/collective_perf_table_gen.cc b/xla/tools/collective_perf_table_gen.cc index 15ce12a060ec6..0448e35eb54fb 100644 --- a/xla/tools/collective_perf_table_gen.cc +++ b/xla/tools/collective_perf_table_gen.cc @@ -78,6 +78,10 @@ int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type, case CollectivePerfTableGen::CollectiveType::ALL_REDUCE: dim_size = tensor_size_bytes / kBytesPerElem; break; + case CollectivePerfTableGen::CollectiveType::ALL_GATHER: + dim_size = tensor_size_bytes / + (kBytesPerElem * replica_groups.num_devices_per_group()); + break; default: LOG(FATAL) << "Unsupported collective type."; } @@ -91,6 +95,7 @@ int64_t GetOutputDim(CollectivePerfTableGen::CollectiveType type, CHECK_EQ(tensor_size_bytes % kBytesPerElem, 0); switch (type) { case CollectivePerfTableGen::CollectiveType::ALL_REDUCE: + case CollectivePerfTableGen::CollectiveType::ALL_GATHER: dim_size = tensor_size_bytes / kBytesPerElem; break; default: @@ -102,10 +107,11 @@ int64_t GetOutputDim(CollectivePerfTableGen::CollectiveType type, std::string GetHlo(CollectivePerfTableGen::CollectiveType type, int64_t input_dim, int64_t output_dim, const IotaReplicaGroupList& replica_groups) { + CHECK_EQ(kBytesPerElem, 4); + std::string hlo; switch (type) { case CollectivePerfTableGen::CollectiveType::ALL_REDUCE: - CHECK_EQ(kBytesPerElem, 4); hlo = absl::Substitute(R"( HloModule m @@ -117,11 +123,24 @@ std::string GetHlo(CollectivePerfTableGen::CollectiveType type, ENTRY e { p0 = $0[$1] parameter(0) - ROOT _ = $0[$2] $3(p0), replica_groups=$4, + ROOT _ = $0[$2] all-reduce(p0), replica_groups=$3, to_apply=add, use_global_device_ids=true, channel_id=1 } )", - "f32", input_dim, output_dim, "all-reduce", + "f32", input_dim, output_dim, + replica_groups.ToString()); + break; + case CollectivePerfTableGen::CollectiveType::ALL_GATHER: + hlo = absl::Substitute(R"( + HloModule m + + ENTRY e { + p0 = $0[$1] parameter(0) + ROOT _ = $0[$2] all-gather(p0), replica_groups=$3, + use_global_device_ids=true, channel_id=1, dimensions={0} + } + )", + "f32", input_dim, output_dim, replica_groups.ToString()); break; default: diff --git a/xla/tools/collective_perf_table_gen.h b/xla/tools/collective_perf_table_gen.h index 354ad45bd0c6e..f1f2a4603ddc1 100644 --- a/xla/tools/collective_perf_table_gen.h +++ b/xla/tools/collective_perf_table_gen.h @@ -52,6 +52,7 @@ class CollectivePerfTableGen { enum class CollectiveType { UNSPECIFIED, ALL_REDUCE, + ALL_GATHER, }; struct Config { diff --git a/xla/tools/collective_perf_table_gen_main.cc b/xla/tools/collective_perf_table_gen_main.cc index c4fc08f090e2e..0af0688fdd9d2 100644 --- a/xla/tools/collective_perf_table_gen_main.cc +++ b/xla/tools/collective_perf_table_gen_main.cc @@ -121,6 +121,10 @@ std::vector ParseCollectives( types.push_back(CollectivePerfTableGen::CollectiveType::ALL_REDUCE); continue; } + if (token == "ALL_GATHER") { + types.push_back(CollectivePerfTableGen::CollectiveType::ALL_GATHER); + continue; + } } CHECK_GT(types.size(), 0); return types; @@ -147,7 +151,6 @@ CollectivePerfTableGen::StepSpec ParseStepSpec(absl::string_view unparsed) { } // namespace -// TODO(b/390097558): Add support for other collectives: AG, RS. // TODO(b/390097558): Add an option to generate perf table for collective which // gets overlap to model resource contention. int main(int argc, char* argv[]) { @@ -167,7 +170,7 @@ int main(int argc, char* argv[]) { "across the distributed system you run it on."), tsl::Flag("collectives", &collectives_unparsed, "Comma separated list of collectives to generate perf table " - "for. Allowed values: ALL_REDUCE."), + "for. Allowed values: ALL_REDUCE, ALL_GATHER."), tsl::Flag("tensor_size_bytes_spec", &tensor_size_bytes_spec_unparsed, "Spec for a search sweep over transfer sizes. Format example: " "start=1,stop=8,factor=2 generates {1,2,4,8}."), diff --git a/xla/tools/collective_perf_table_gen_test.cc b/xla/tools/collective_perf_table_gen_test.cc index 9bbc8151e9ff5..3f17bc0b3ce0a 100644 --- a/xla/tools/collective_perf_table_gen_test.cc +++ b/xla/tools/collective_perf_table_gen_test.cc @@ -54,7 +54,10 @@ TEST_F(CollectivePerfTableGenTest, EmptyConfigReturnsEmptyProto) { } TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) { - cfg_.collective_types = {CollectivePerfTableGen::CollectiveType::ALL_REDUCE}; + cfg_.collective_types = { + CollectivePerfTableGen::CollectiveType::ALL_REDUCE, + CollectivePerfTableGen::CollectiveType::ALL_GATHER, + }; IotaReplicaGroupList iota(1, 1); cfg_.replica_groups_list = {iota}; CollectivePerfTableGen::StepSpec spec{ @@ -70,11 +73,14 @@ TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) { DeviceHloInstructionProfiles profiles = gen->ComputeTable(); EXPECT_EQ(profiles.entries_size(), 1); - EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 5); + EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 10); } TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) { - cfg_.collective_types = {CollectivePerfTableGen::CollectiveType::ALL_REDUCE}; + cfg_.collective_types = { + CollectivePerfTableGen::CollectiveType::ALL_REDUCE, + CollectivePerfTableGen::CollectiveType::ALL_GATHER, + }; IotaReplicaGroupList iota(1, 1); cfg_.replica_groups_list = {iota}; CollectivePerfTableGen::StepSpec spec{ @@ -90,7 +96,7 @@ TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) { DeviceHloInstructionProfiles profiles = gen->ComputeTable(); EXPECT_EQ(profiles.entries_size(), 1); - EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 4); + EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 8); } } // namespace