Skip to content

Commit

Permalink
[XLA:GPU] Add AllGather support to collective generation tool.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726033258
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Feb 12, 2025
1 parent 9b04b86 commit e7e6f6a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
25 changes: 22 additions & 3 deletions xla/tools/collective_perf_table_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type,
case CollectivePerfTableGen::CollectiveType::ALL_REDUCE:
dim_size = tensor_size_bytes / kBytesPerElem;
break;
case CollectivePerfTableGen::CollectiveType::ALL_GATHER:
dim_size = tensor_size_bytes /
(kBytesPerElem * replica_groups.num_devices_per_group());
break;
default:
LOG(FATAL) << "Unsupported collective type.";
}
Expand All @@ -91,6 +95,7 @@ int64_t GetOutputDim(CollectivePerfTableGen::CollectiveType type,
CHECK_EQ(tensor_size_bytes % kBytesPerElem, 0);
switch (type) {
case CollectivePerfTableGen::CollectiveType::ALL_REDUCE:
case CollectivePerfTableGen::CollectiveType::ALL_GATHER:
dim_size = tensor_size_bytes / kBytesPerElem;
break;
default:
Expand All @@ -102,10 +107,11 @@ int64_t GetOutputDim(CollectivePerfTableGen::CollectiveType type,
std::string GetHlo(CollectivePerfTableGen::CollectiveType type,
int64_t input_dim, int64_t output_dim,
const IotaReplicaGroupList& replica_groups) {
CHECK_EQ(kBytesPerElem, 4);

std::string hlo;
switch (type) {
case CollectivePerfTableGen::CollectiveType::ALL_REDUCE:
CHECK_EQ(kBytesPerElem, 4);
hlo = absl::Substitute(R"(
HloModule m
Expand All @@ -117,11 +123,24 @@ std::string GetHlo(CollectivePerfTableGen::CollectiveType type,
ENTRY e {
p0 = $0[$1] parameter(0)
ROOT _ = $0[$2] $3(p0), replica_groups=$4,
ROOT _ = $0[$2] all-reduce(p0), replica_groups=$3,
to_apply=add, use_global_device_ids=true, channel_id=1
}
)",
"f32", input_dim, output_dim, "all-reduce",
"f32", input_dim, output_dim,
replica_groups.ToString());
break;
case CollectivePerfTableGen::CollectiveType::ALL_GATHER:
hlo = absl::Substitute(R"(
HloModule m
ENTRY e {
p0 = $0[$1] parameter(0)
ROOT _ = $0[$2] all-gather(p0), replica_groups=$3,
use_global_device_ids=true, channel_id=1, dimensions={0}
}
)",
"f32", input_dim, output_dim,
replica_groups.ToString());
break;
default:
Expand Down
1 change: 1 addition & 0 deletions xla/tools/collective_perf_table_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class CollectivePerfTableGen {
// Kinds of collective operations a perf table can be requested for.
// The enumerator values are spelled out explicitly; they are the same values
// the compiler would assign implicitly in declaration order.
enum class CollectiveType {
  // No collective selected.
  // NOTE(review): presumably rejected as an unsupported type by the
  // generator -- confirm against the switch statements that consume it.
  UNSPECIFIED = 0,
  // All-reduce collective.
  ALL_REDUCE = 1,
  // All-gather collective.
  ALL_GATHER = 2,
};

struct Config {
Expand Down
7 changes: 5 additions & 2 deletions xla/tools/collective_perf_table_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ std::vector<CollectivePerfTableGen::CollectiveType> ParseCollectives(
types.push_back(CollectivePerfTableGen::CollectiveType::ALL_REDUCE);
continue;
}
if (token == "ALL_GATHER") {
types.push_back(CollectivePerfTableGen::CollectiveType::ALL_GATHER);
continue;
}
}
CHECK_GT(types.size(), 0);
return types;
Expand All @@ -147,7 +151,6 @@ CollectivePerfTableGen::StepSpec ParseStepSpec(absl::string_view unparsed) {

} // namespace

// TODO(b/390097558): Add support for other collectives: AG, RS.
// TODO(b/390097558): Add an option to generate a perf table for collectives
// that are overlapped, in order to model resource contention.
int main(int argc, char* argv[]) {
Expand All @@ -167,7 +170,7 @@ int main(int argc, char* argv[]) {
"across the distributed system you run it on."),
tsl::Flag("collectives", &collectives_unparsed,
"Comma separated list of collectives to generate perf table "
"for. Allowed values: ALL_REDUCE."),
"for. Allowed values: ALL_REDUCE, ALL_GATHER."),
tsl::Flag("tensor_size_bytes_spec", &tensor_size_bytes_spec_unparsed,
"Spec for a search sweep over transfer sizes. Format example: "
"start=1,stop=8,factor=2 generates {1,2,4,8}."),
Expand Down
14 changes: 10 additions & 4 deletions xla/tools/collective_perf_table_gen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ TEST_F(CollectivePerfTableGenTest, EmptyConfigReturnsEmptyProto) {
}

TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) {
cfg_.collective_types = {CollectivePerfTableGen::CollectiveType::ALL_REDUCE};
cfg_.collective_types = {
CollectivePerfTableGen::CollectiveType::ALL_REDUCE,
CollectivePerfTableGen::CollectiveType::ALL_GATHER,
};
IotaReplicaGroupList iota(1, 1);
cfg_.replica_groups_list = {iota};
CollectivePerfTableGen::StepSpec spec{
Expand All @@ -70,11 +73,14 @@ TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) {

DeviceHloInstructionProfiles profiles = gen->ComputeTable();
EXPECT_EQ(profiles.entries_size(), 1);
EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 5);
EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 10);
}

TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) {
cfg_.collective_types = {CollectivePerfTableGen::CollectiveType::ALL_REDUCE};
cfg_.collective_types = {
CollectivePerfTableGen::CollectiveType::ALL_REDUCE,
CollectivePerfTableGen::CollectiveType::ALL_GATHER,
};
IotaReplicaGroupList iota(1, 1);
cfg_.replica_groups_list = {iota};
CollectivePerfTableGen::StepSpec spec{
Expand All @@ -90,7 +96,7 @@ TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) {

DeviceHloInstructionProfiles profiles = gen->ComputeTable();
EXPECT_EQ(profiles.entries_size(), 1);
EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 4);
EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 8);
}

} // namespace
Expand Down

0 comments on commit e7e6f6a

Please sign in to comment.