Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Add python bindings to collective perf table generator. #22692

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions xla/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ load(
"if_google",
"tsl_gpu_library",
)
load("//xla/tsl:tsl.default.bzl", "filegroup")
load("//xla/tsl:tsl.default.bzl", "filegroup", "tsl_pybind_extension")
load(
"//xla/tsl/platform:build_config.bzl",
"tf_proto_library",
Expand Down Expand Up @@ -716,6 +716,7 @@ cc_library(
"//xla/pjrt:pjrt_client",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:backend",
"//xla/service:hlo_module_config",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/service/gpu/model:hlo_op_profiler_lib",
"//xla/service/gpu/model:hlo_op_profiles",
Expand Down Expand Up @@ -759,14 +760,23 @@ cc_library(
alwayslink = True,
)

tsl_pybind_extension(
name = "collective_perf_table_gen_bindings",
srcs = ["collective_perf_table_gen_bindings.cc"],
deps = [
":collective_perf_table_gen",
"@com_google_absl//absl/log:check",
"@nanobind",
],
)

xla_test(
name = "collective_perf_table_gen_test",
srcs = ["collective_perf_table_gen_test.cc"],
backends = ["gpu"],
local_defines = if_cuda(["GOOGLE_CUDA"]),
deps = [
":collective_perf_table_gen",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/tests:hlo_test_base",
Expand Down
53 changes: 43 additions & 10 deletions xla/tools/collective_perf_table_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/collective_device_list.h"
Expand All @@ -38,6 +39,7 @@ limitations under the License.
#include "xla/service/gpu/model/hlo_op_profile.pb.h"
#include "xla/service/gpu/model/hlo_op_profiler.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tools/multihost_hlo_runner/create_client.h"
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"
#include "xla/tsl/platform/env.h"
Expand Down Expand Up @@ -186,7 +188,10 @@ std::unique_ptr<HloModule> CreateCollectiveModule(const StaticSpec& spec) {
std::string hlo =
GetHlo(spec.collective_type, input_dim, output_dim, spec.replica_groups);

auto parsed = ParseAndReturnUnverifiedModule(hlo);
HloModuleConfig config;
config.set_num_partitions(spec.replica_groups.num_devices_per_group() *
spec.replica_groups.num_replica_groups());
auto parsed = ParseAndReturnUnverifiedModule(hlo, config);
CHECK_OK(parsed.status());
return std::move(*parsed);
}
Expand All @@ -202,6 +207,16 @@ uint64_t GetNetworkThroughputBytesPerSec(absl::Duration runtime,
return tensor_size_bytes * 1e9 / absl::ToInt64Nanoseconds(runtime);
}

// Parses an unparsed collective device list string (iota form, e.g.
// "[1,8]<=[8]") into an IotaReplicaGroupList. CHECK-fails if the string does
// not parse or if the parsed list has no iota replica group representation.
IotaReplicaGroupList GetCollectiveDeviceList(
    absl::string_view collective_device_list_unparsed) {
  auto collective_device_list =
      xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed);
  CHECK_OK(collective_device_list);
  CHECK(collective_device_list->iota_replica_group_list().has_value());

  return *collective_device_list->iota_replica_group_list();
}

} // namespace

/*static*/ std::unique_ptr<CollectivePerfTableGen>
Expand All @@ -224,7 +239,7 @@ std::unique_ptr<PjRtLoadedExecutable> CollectivePerfTableGen::Compile(
std::unique_ptr<HloModule> module) {
DebugOptions debug_opts;
FunctionalHloRunner::RawCompileOptions opts;
opts.num_partitions = 8;
opts.num_partitions = module->config().num_partitions();
opts.spmd_mode = FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning;
auto compile_opts = FunctionalHloRunner::CreateCompileOptions(
*pjrt_env_.client, opts, config_.task_id, config_.num_nodes);
Expand All @@ -250,18 +265,26 @@ CollectivePerfTableGen::ProfilingData CollectivePerfTableGen::Profile(
VLOG(1) << "Compiled module: "
<< executable->GetHloModules().value()[0]->ToString();

// We do not profile dry runs or on more than one task.
if (config_.dry_run) {
return {};
}

std::unique_ptr<HloOpProfiler::KernelTracer> tracer =
HloOpProfiler::GetKernelTracer();
if (config_.task_id == 0) {
std::unique_ptr<HloOpProfiler::KernelTracer> tracer =
HloOpProfiler::GetKernelTracer();
for (int i = 0; i < kNumProfilingRuns; ++i) {
Run(*executable);
}
return {
/*runtime=*/absl::Nanoseconds(
std::move(*tracer).getMedianKernelTimeNs()),
};
}
for (int i = 0; i < kNumProfilingRuns; ++i) {
Run(*executable);
}
return {
/*runtime=*/absl::Nanoseconds(std::move(*tracer).getMedianKernelTimeNs()),
};
return {};
}

DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
Expand All @@ -282,11 +305,14 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
for (int64_t tensor_size = tsize_spec.start; tensor_size <= tsize_spec.stop;
tensor_size = inc(tensor_size, tsize_spec)) {
for (CollectiveType collective_type : config_.collective_types) {
for (const IotaReplicaGroupList& replica_groups :
config_.replica_groups_list) {
for (absl::string_view replica_groups_raw : config_.replica_groups_list) {
CHECK(collective_type != CollectiveType::UNSPECIFIED);

StaticSpec spec{collective_type, replica_groups, tensor_size};
StaticSpec spec{
collective_type,
GetCollectiveDeviceList(replica_groups_raw),
tensor_size,
};
static_specs.push_back(spec);
}
}
Expand All @@ -313,6 +339,10 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
if (!config_.dry_run) {
profiled_data = Profile(std::move(spec.module));
}
if (profiled_data.runtime == absl::ZeroDuration()) {
VLOG(1) << "Size: " << static_spec.tensor_size_bytes << " too small.";
continue;
}
entry.set_network_throughput_bytes_per_sec(GetNetworkThroughputBytesPerSec(
profiled_data.runtime, static_spec.tensor_size_bytes));

Expand All @@ -332,6 +362,9 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {

absl::Status CollectivePerfTableGen::Dump(
const DeviceHloInstructionProfiles& table) {
if (config_.task_id != 0) {
return absl::OkStatus();
}
if (config_.output == CollectivePerfTableGen::Config::kStdout) {
LOG(INFO) << table.DebugString();
return absl::OkStatus();
Expand Down
17 changes: 10 additions & 7 deletions xla/tools/collective_perf_table_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/service/backend.h"
Expand All @@ -43,10 +42,10 @@ namespace xla::gpu {
class CollectivePerfTableGen {
public:
struct StepSpec {
int64_t start = 0;
int64_t stop = -1;
int64_t start = 1024;
int64_t stop = 2ll * 1024 * 1024 * 1024;
int64_t step = 0;
int64_t factor = 0;
int64_t factor = 2;
};

enum class CollectiveType {
Expand All @@ -61,14 +60,18 @@ class CollectivePerfTableGen {

// Search space.
StepSpec tensor_size_bytes_spec;
std::vector<CollectiveType> collective_types;
std::vector<IotaReplicaGroupList> replica_groups_list;
std::vector<CollectiveType> collective_types = {
CollectiveType::ALL_REDUCE,
CollectiveType::ALL_GATHER,
CollectiveType::REDUCE_SCATTER,
};
std::vector<std::string> replica_groups_list;

// Execution opts.
bool dry_run = false;
std::string output = std::string(kStdout);
std::string coordinator_address = "";
absl::Duration connection_timeout = absl::Seconds(60);
absl::Duration connection_timeout = absl::Seconds(600);
uint16_t num_nodes = 1;
uint16_t task_id = 0;
};
Expand Down
74 changes: 74 additions & 0 deletions xla/tools/collective_perf_table_gen_bindings.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "absl/log/check.h"
#include "nanobind/nanobind.h"
#include "nanobind/nb_defs.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/tools/collective_perf_table_gen.h"

namespace nb = nanobind;

// Python bindings for CollectivePerfTableGen. Exposes the Config and StepSpec
// structs, the CollectiveType enum, and a `run` entry point that builds the
// collective perf table and dumps it to the configured output.
NB_MODULE(collective_perf_table_gen_bindings, m) {
  using xla::gpu::CollectivePerfTableGen;

  // Bind the Config struct.
  nb::class_<CollectivePerfTableGen::Config>(m, "Config")
      .def(nb::init<>())
      .def_rw("tensor_size_bytes_spec",
              &CollectivePerfTableGen::Config::tensor_size_bytes_spec)
      .def_rw("collective_types",
              &CollectivePerfTableGen::Config::collective_types)
      .def_rw("replica_groups_list",
              &CollectivePerfTableGen::Config::replica_groups_list)
      .def_rw("dry_run", &CollectivePerfTableGen::Config::dry_run)
      .def_rw("output", &CollectivePerfTableGen::Config::output)
      .def_rw("coordinator_address",
              &CollectivePerfTableGen::Config::coordinator_address)
      .def_rw("connection_timeout",
              &CollectivePerfTableGen::Config::connection_timeout)
      .def_rw("num_nodes", &CollectivePerfTableGen::Config::num_nodes)
      .def_rw("task_id", &CollectivePerfTableGen::Config::task_id);

  // Bind the StepSpec struct.
  nb::class_<CollectivePerfTableGen::StepSpec>(m, "StepSpec")
      .def(nb::init<>())
      .def_rw("start", &CollectivePerfTableGen::StepSpec::start)
      .def_rw("stop", &CollectivePerfTableGen::StepSpec::stop)
      .def_rw("step", &CollectivePerfTableGen::StepSpec::step)
      .def_rw("factor", &CollectivePerfTableGen::StepSpec::factor);

  // Bind the CollectiveType enum.
  nb::enum_<CollectivePerfTableGen::CollectiveType>(m, "CollectiveType")
      .value("UNSPECIFIED", CollectivePerfTableGen::CollectiveType::UNSPECIFIED)
      .value("ALL_REDUCE", CollectivePerfTableGen::CollectiveType::ALL_REDUCE)
      .value("REDUCE_SCATTER",
             CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER)
      .value("ALL_GATHER", CollectivePerfTableGen::CollectiveType::ALL_GATHER)
      .export_values();

  // Runs table generation end-to-end: compute the table, then dump it.
  // Config is taken by const reference to avoid copying its vectors/strings
  // on every call (performance-unnecessary-value-param).
  m.def("run",
        [](const CollectivePerfTableGen::Config& config) -> void {
          std::unique_ptr<CollectivePerfTableGen> gen =
              CollectivePerfTableGen::Create(config);
          auto table = gen->ComputeTable();
          // CHECK-fail on dump errors: this is a tooling entry point.
          CHECK_OK(gen->Dump(table));
        });
}
60 changes: 33 additions & 27 deletions xla/tools/collective_perf_table_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "absl/strings/substitute.h"
#include "xla/service/gpu/model/hlo_op_profile.pb.h"
#include "xla/tools/collective_perf_table_gen.h"
#include "xla/tsl/util/command_line_flags.h"
Expand Down Expand Up @@ -76,7 +75,6 @@ to 4 GPUs.

constexpr absl::string_view kDefaultCoordinatorAddress = "127.0.0.1:1234";

using ::xla::IotaReplicaGroupList;
using ::xla::gpu::CollectivePerfTableGen;
using ::xla::gpu::DeviceHloInstructionProfiles;

Expand All @@ -91,27 +89,6 @@ std::pair<std::string /*key*/, std::string /*value*/> ExtractKV(
return {key, value};
}

IotaReplicaGroupList GetCollectiveDeviceList(
absl::string_view collective_device_list_unparsed) {
auto collective_device_list =
xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed);
CHECK_OK(collective_device_list);
CHECK(collective_device_list->iota_replica_group_list().has_value());

return *collective_device_list->iota_replica_group_list();
}

std::vector<IotaReplicaGroupList> GetCollectiveDeviceLists(
absl::string_view collective_device_lists_unparsed) {
std::vector<IotaReplicaGroupList> device_lists;
for (absl::string_view token :
absl::StrSplit(collective_device_lists_unparsed, ';')) {
device_lists.emplace_back(GetCollectiveDeviceList(token));
}
CHECK_GT(device_lists.size(), 0);
return device_lists;
}

std::vector<CollectivePerfTableGen::CollectiveType> ParseCollectives(
absl::string_view unparsed) {
std::vector<CollectivePerfTableGen::CollectiveType> types;
Expand Down Expand Up @@ -153,22 +130,49 @@ CollectivePerfTableGen::StepSpec ParseStepSpec(absl::string_view unparsed) {
return spec;
}

// Splits a ';'-separated device-list spec into its individual (still
// unparsed) collective device list strings. Parsing of each entry happens
// later, at table-generation time.
std::vector<std::string> CollectiveDeviceLists(
    absl::string_view device_list_unparsed) {
  std::vector<std::string> result;
  for (absl::string_view device_list :
       absl::StrSplit(device_list_unparsed, ';')) {
    result.emplace_back(device_list);
  }
  return result;
}

// Returns `collective_devices_spec_unparsed` unchanged when non-empty.
// Otherwise returns a default ';'-separated spec with two device lists:
// a single group covering all num_nodes * num_devices_per_host devices, and
// a transposed layout "[$2,$1]<=[$1,$2]T(1,0)" — presumably grouping devices
// with the same local index across nodes (TODO confirm against the parser).
std::string DefaultCollectiveDevicesIfEmpty(
    const std::string& collective_devices_spec_unparsed, int32_t num_nodes,
    int32_t num_devices_per_host) {
  if (collective_devices_spec_unparsed.empty()) {
    return absl::Substitute("[1,$0]<=[$0];[$2,$1]<=[$1,$2]T(1,0)",
                            num_devices_per_host * num_nodes, num_nodes,
                            num_devices_per_host);
  }
  return collective_devices_spec_unparsed;
}

} // namespace

// TODO(b/390097558): Add an option to generate perf table for collective which
// gets overlap to model resource contention.
int main(int argc, char* argv[]) {
// Default args.
int32_t num_nodes = 1;
int32_t num_devices_per_host = 8;
int32_t task_id = 0;
std::string collectives_unparsed;
std::string tensor_size_bytes_spec_unparsed;
std::string collectives_unparsed = "ALL_REDUCE,ALL_GATHER,REDUCE_SCATTER";
std::string tensor_size_bytes_spec_unparsed =
"start=1024,stop=2147483648,factor=2";
std::string collective_devices_spec_unparsed;
std::string coordinator_address = std::string(kDefaultCoordinatorAddress);
std::string output = std::string(CollectivePerfTableGen::Config::kStdout);

// Parse flags.
std::vector<tsl::Flag> flag_list = {
tsl::Flag("num_nodes", &num_nodes,
"Specifies number of processes across a distributed system."),
tsl::Flag("num_devices_per_host", &num_devices_per_host,
"Specified number of devices per host."),
tsl::Flag("task_id", &task_id,
"Specifies task identifier of this process. Must be unique "
"across the distributed system you run it on."),
Expand Down Expand Up @@ -204,8 +208,10 @@ int main(int argc, char* argv[]) {
cfg.task_id = task_id;
cfg.collective_types = ParseCollectives(collectives_unparsed);
cfg.tensor_size_bytes_spec = ParseStepSpec(tensor_size_bytes_spec_unparsed);
collective_devices_spec_unparsed = DefaultCollectiveDevicesIfEmpty(
collective_devices_spec_unparsed, num_nodes, num_devices_per_host);
cfg.replica_groups_list =
GetCollectiveDeviceLists(collective_devices_spec_unparsed);
CollectiveDeviceLists(collective_devices_spec_unparsed);
cfg.output = output;

std::unique_ptr<CollectivePerfTableGen> gen =
Expand Down
Loading
Loading