Skip to content

Commit

Permalink
[XLA:GPU] Add ReduceScatter support to collective generation tool.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726407193
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Feb 13, 2025
1 parent b523ee6 commit 6b470af
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
27 changes: 27 additions & 0 deletions xla/tools/collective_perf_table_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type,
dim_size = tensor_size_bytes /
(kBytesPerElem * replica_groups.num_devices_per_group());
break;
case CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER:
dim_size = tensor_size_bytes / kBytesPerElem;
break;
default:
LOG(FATAL) << "Unsupported collective type.";
}
Expand All @@ -98,6 +101,10 @@ int64_t GetOutputDim(CollectivePerfTableGen::CollectiveType type,
case CollectivePerfTableGen::CollectiveType::ALL_GATHER:
dim_size = tensor_size_bytes / kBytesPerElem;
break;
case CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER:
dim_size = tensor_size_bytes /
(kBytesPerElem * replica_groups.num_devices_per_group());
break;
default:
LOG(FATAL) << "Unsupported collective type.";
}
Expand Down Expand Up @@ -143,6 +150,26 @@ std::string GetHlo(CollectivePerfTableGen::CollectiveType type,
"f32", input_dim, output_dim,
replica_groups.ToString());
break;
case CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER:
hlo = absl::Substitute(R"(
HloModule m
add {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT res = add(a, b)
}
ENTRY e {
p0 = $0[$1] parameter(0)
ROOT _ = $0[$2] reduce-scatter(p0), replica_groups=$3,
to_apply=add, use_global_device_ids=true, channel_id=1,
dimensions={0}
}
)",
"f32", input_dim, output_dim,
replica_groups.ToString());
break;
default:
LOG(FATAL) << "Unsupported collective type.";
}
Expand Down
1 change: 1 addition & 0 deletions xla/tools/collective_perf_table_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class CollectivePerfTableGen {
UNSPECIFIED,
ALL_REDUCE,
ALL_GATHER,
REDUCE_SCATTER,
};

struct Config {
Expand Down
6 changes: 5 additions & 1 deletion xla/tools/collective_perf_table_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ std::vector<CollectivePerfTableGen::CollectiveType> ParseCollectives(
types.push_back(CollectivePerfTableGen::CollectiveType::ALL_GATHER);
continue;
}
if (token == "REDUCE_SCATTER") {
types.push_back(CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER);
continue;
}
}
CHECK_GT(types.size(), 0);
return types;
Expand Down Expand Up @@ -170,7 +174,7 @@ int main(int argc, char* argv[]) {
"across the distributed system you run it on."),
tsl::Flag("collectives", &collectives_unparsed,
"Comma separated list of collectives to generate perf table "
- "for. Allowed values: ALL_REDUCE, ALL_GATHER."),
+ "for. Allowed values: ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER."),
tsl::Flag("tensor_size_bytes_spec", &tensor_size_bytes_spec_unparsed,
"Spec for a search sweep over transfer sizes. Format example: "
"start=1,stop=8,factor=2 generates {1,2,4,8}."),
Expand Down
6 changes: 4 additions & 2 deletions xla/tools/collective_perf_table_gen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) {
cfg_.collective_types = {
CollectivePerfTableGen::CollectiveType::ALL_REDUCE,
CollectivePerfTableGen::CollectiveType::ALL_GATHER,
CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER,
};
IotaReplicaGroupList iota(1, 1);
cfg_.replica_groups_list = {iota};
Expand All @@ -73,13 +74,14 @@ TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) {

DeviceHloInstructionProfiles profiles = gen->ComputeTable();
EXPECT_EQ(profiles.entries_size(), 1);
- EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 10);
+ EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 15);
}

TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) {
cfg_.collective_types = {
CollectivePerfTableGen::CollectiveType::ALL_REDUCE,
CollectivePerfTableGen::CollectiveType::ALL_GATHER,
CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER,
};
IotaReplicaGroupList iota(1, 1);
cfg_.replica_groups_list = {iota};
Expand All @@ -96,7 +98,7 @@ TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) {

DeviceHloInstructionProfiles profiles = gen->ComputeTable();
EXPECT_EQ(profiles.entries_size(), 1);
- EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 8);
+ EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 12);
}

} // namespace
Expand Down

0 comments on commit 6b470af

Please sign in to comment.