[XLA:GPU] Add python bindings to collective perf table generator.
Also add logical defaults for the search space, and pass replica groups in as strings instead of plumbing IotaReplicaGroupList through to the Python bindings.

PiperOrigin-RevId: 726907010
golechwierowicz authored and Google-ML-Automation committed Feb 14, 2025
1 parent 8ca669b commit e9063a9
Showing 6 changed files with 174 additions and 51 deletions.
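
For orientation, here is a minimal sketch of how the new bindings could be driven from Python, based on the Config, StepSpec, and CollectiveType symbols bound in collective_perf_table_gen_bindings.cc below. The import path and the replica-group strings are illustrative assumptions, not part of this commit:

    # Hypothetical driver script; the import path and the replica-group strings
    # are assumptions for illustration only.
    import collective_perf_table_gen_bindings as bindings

    config = bindings.Config()

    # The new defaults already cover ALL_REDUCE / ALL_GATHER / REDUCE_SCATTER and
    # a 1 KiB..2 GiB tensor-size sweep, so only the replica groups must be set.
    # Replica groups are now passed as strings rather than IotaReplicaGroupList.
    config.replica_groups_list = ["[1,8]<=[8]", "[8,1]<=[1,8]T(1,0)"]

    # Optionally narrow the sweep via StepSpec.
    spec = bindings.StepSpec()
    spec.start = 1024
    spec.stop = 1024 * 1024
    spec.factor = 2
    config.tensor_size_bytes_spec = spec

    # Compiles, profiles, and dumps the collective perf table to the default output.
    bindings.run(config)
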
14 changes: 12 additions & 2 deletions xla/tools/BUILD
@@ -19,7 +19,7 @@ load(
"if_google",
"tsl_gpu_library",
)
load("//xla/tsl:tsl.default.bzl", "filegroup")
load("//xla/tsl:tsl.default.bzl", "filegroup", "tsl_pybind_extension")
load(
"//xla/tsl/platform:build_config.bzl",
"tf_proto_library",
@@ -716,6 +716,7 @@ cc_library(
"//xla/pjrt:pjrt_client",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:backend",
"//xla/service:hlo_module_config",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/service/gpu/model:hlo_op_profiler_lib",
"//xla/service/gpu/model:hlo_op_profiles",
@@ -759,14 +760,23 @@ cc_library(
alwayslink = True,
)

tsl_pybind_extension(
name = "collective_perf_table_gen_bindings",
srcs = ["collective_perf_table_gen_bindings.cc"],
deps = [
":collective_perf_table_gen",
"@com_google_absl//absl/log:check",
"@nanobind",
],
)

xla_test(
name = "collective_perf_table_gen_test",
srcs = ["collective_perf_table_gen_test.cc"],
backends = ["gpu"],
local_defines = if_cuda(["GOOGLE_CUDA"]),
deps = [
":collective_perf_table_gen",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/tests:hlo_test_base",
53 changes: 43 additions & 10 deletions xla/tools/collective_perf_table_gen.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/collective_device_list.h"
@@ -38,6 +39,7 @@ limitations under the License.
#include "xla/service/gpu/model/hlo_op_profile.pb.h"
#include "xla/service/gpu/model/hlo_op_profiler.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tools/multihost_hlo_runner/create_client.h"
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"
#include "xla/tsl/platform/env.h"
@@ -186,7 +188,10 @@ std::unique_ptr<HloModule> CreateCollectiveModule(const StaticSpec& spec) {
std::string hlo =
GetHlo(spec.collective_type, input_dim, output_dim, spec.replica_groups);

auto parsed = ParseAndReturnUnverifiedModule(hlo);
HloModuleConfig config;
config.set_num_partitions(spec.replica_groups.num_devices_per_group() *
spec.replica_groups.num_replica_groups());
auto parsed = ParseAndReturnUnverifiedModule(hlo, config);
CHECK_OK(parsed.status());
return std::move(*parsed);
}
@@ -202,6 +207,16 @@ uint64_t GetNetworkThroughputBytesPerSec(absl::Duration runtime,
return tensor_size_bytes * 1e9 / absl::ToInt64Nanoseconds(runtime);
}

IotaReplicaGroupList GetCollectiveDeviceList(
absl::string_view collective_device_list_unparsed) {
auto collective_device_list =
xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed);
CHECK_OK(collective_device_list);
CHECK(collective_device_list->iota_replica_group_list().has_value());

return *collective_device_list->iota_replica_group_list();
}

} // namespace

/*static*/ std::unique_ptr<CollectivePerfTableGen>
@@ -224,7 +239,7 @@ std::unique_ptr<PjRtLoadedExecutable> CollectivePerfTableGen::Compile(
std::unique_ptr<HloModule> module) {
DebugOptions debug_opts;
FunctionalHloRunner::RawCompileOptions opts;
opts.num_partitions = 8;
opts.num_partitions = module->config().num_partitions();
opts.spmd_mode = FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning;
auto compile_opts = FunctionalHloRunner::CreateCompileOptions(
*pjrt_env_.client, opts, config_.task_id, config_.num_nodes);
@@ -250,18 +265,26 @@ CollectivePerfTableGen::ProfilingData CollectivePerfTableGen::Profile(
VLOG(1) << "Compiled module: "
<< executable->GetHloModules().value()[0]->ToString();

// We do not profile dry runs or on more than one task.
if (config_.dry_run) {
return {};
}

std::unique_ptr<HloOpProfiler::KernelTracer> tracer =
HloOpProfiler::GetKernelTracer();
if (config_.task_id == 0) {
std::unique_ptr<HloOpProfiler::KernelTracer> tracer =
HloOpProfiler::GetKernelTracer();
for (int i = 0; i < kNumProfilingRuns; ++i) {
Run(*executable);
}
return {
/*runtime=*/absl::Nanoseconds(
std::move(*tracer).getMedianKernelTimeNs()),
};
}
for (int i = 0; i < kNumProfilingRuns; ++i) {
Run(*executable);
}
return {
/*runtime=*/absl::Nanoseconds(std::move(*tracer).getMedianKernelTimeNs()),
};
return {};
}

DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
@@ -282,11 +305,14 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
for (int64_t tensor_size = tsize_spec.start; tensor_size <= tsize_spec.stop;
tensor_size = inc(tensor_size, tsize_spec)) {
for (CollectiveType collective_type : config_.collective_types) {
for (const IotaReplicaGroupList& replica_groups :
config_.replica_groups_list) {
for (absl::string_view replica_groups_raw : config_.replica_groups_list) {
CHECK(collective_type != CollectiveType::UNSPECIFIED);

StaticSpec spec{collective_type, replica_groups, tensor_size};
StaticSpec spec{
collective_type,
GetCollectiveDeviceList(replica_groups_raw),
tensor_size,
};
static_specs.push_back(spec);
}
}
@@ -313,6 +339,10 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
if (!config_.dry_run) {
profiled_data = Profile(std::move(spec.module));
}
if (profiled_data.runtime == absl::ZeroDuration()) {
VLOG(1) << "Size: " << static_spec.tensor_size_bytes << " too small.";
continue;
}
entry.set_network_throughput_bytes_per_sec(GetNetworkThroughputBytesPerSec(
profiled_data.runtime, static_spec.tensor_size_bytes));

@@ -332,6 +362,9 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {

absl::Status CollectivePerfTableGen::Dump(
const DeviceHloInstructionProfiles& table) {
if (config_.task_id != 0) {
return absl::OkStatus();
}
if (config_.output == CollectivePerfTableGen::Config::kStdout) {
LOG(INFO) << table.DebugString();
return absl::OkStatus();
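
As a quick sanity check on the GetNetworkThroughputBytesPerSec formula used above (bytes * 1e9 / runtime in nanoseconds), a worked example with made-up numbers:

    # Mirrors GetNetworkThroughputBytesPerSec; the payload size and runtime
    # below are illustrative, not measured values.
    tensor_size_bytes = 1 << 30          # 1 GiB payload
    runtime_ns = 10_000_000              # 10 ms median kernel time
    throughput = tensor_size_bytes * 1e9 / runtime_ns
    print(throughput)                    # ~1.07e11 bytes/s, i.e. roughly 107 GB/s
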
17 changes: 10 additions & 7 deletions xla/tools/collective_perf_table_gen.h
@@ -26,7 +26,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/service/backend.h"
@@ -43,10 +42,10 @@ namespace xla::gpu {
class CollectivePerfTableGen {
public:
struct StepSpec {
int64_t start = 0;
int64_t stop = -1;
int64_t start = 1024;
int64_t stop = 2ll * 1024 * 1024 * 1024;
int64_t step = 0;
int64_t factor = 0;
int64_t factor = 2;
};

enum class CollectiveType {
@@ -61,14 +60,18 @@ class CollectivePerfTableGen {

// Search space.
StepSpec tensor_size_bytes_spec;
std::vector<CollectiveType> collective_types;
std::vector<IotaReplicaGroupList> replica_groups_list;
std::vector<CollectiveType> collective_types = {
CollectiveType::ALL_REDUCE,
CollectiveType::ALL_GATHER,
CollectiveType::REDUCE_SCATTER,
};
std::vector<std::string> replica_groups_list;

// Execution opts.
bool dry_run = false;
std::string output = std::string(kStdout);
std::string coordinator_address = "";
absl::Duration connection_timeout = absl::Seconds(60);
absl::Duration connection_timeout = absl::Seconds(600);
uint16_t num_nodes = 1;
uint16_t task_id = 0;
};
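
The new StepSpec defaults above (start=1024, stop=2 GiB, factor=2) imply a geometric sweep. A small sketch of the sizes this generates, assuming the increment simply multiplies by `factor` as the ComputeTable() loop suggests:

    # Geometric sweep implied by the new defaults; the multiply-by-factor
    # increment is an assumption based on the ComputeTable() loop shown earlier.
    start, stop, factor = 1024, 2 * 1024 * 1024 * 1024, 2
    sizes = []
    size = start
    while size <= stop:
        sizes.append(size)
        size *= factor
    print(len(sizes))   # 22 points: 1 KiB, 2 KiB, ..., 2 GiB
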
74 changes: 74 additions & 0 deletions xla/tools/collective_perf_table_gen_bindings.cc
@@ -0,0 +1,74 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "absl/log/check.h"
#include "nanobind/nanobind.h"
#include "nanobind/nb_defs.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/tools/collective_perf_table_gen.h"

namespace nb = nanobind;

NB_MODULE(collective_perf_table_gen_bindings, m) {
// Bind the Config struct
nb::class_<xla::gpu::CollectivePerfTableGen::Config>(m, "Config")
.def(nb::init<>())
.def_rw("tensor_size_bytes_spec",
&xla::gpu::CollectivePerfTableGen::Config::tensor_size_bytes_spec)
.def_rw("collective_types",
&xla::gpu::CollectivePerfTableGen::Config::collective_types)
.def_rw("replica_groups_list",
&xla::gpu::CollectivePerfTableGen::Config::replica_groups_list)
.def_rw("dry_run", &xla::gpu::CollectivePerfTableGen::Config::dry_run)
.def_rw("output", &xla::gpu::CollectivePerfTableGen::Config::output)
.def_rw("coordinator_address",
&xla::gpu::CollectivePerfTableGen::Config::coordinator_address)
.def_rw("connection_timeout",
&xla::gpu::CollectivePerfTableGen::Config::connection_timeout)
.def_rw("num_nodes", &xla::gpu::CollectivePerfTableGen::Config::num_nodes)
.def_rw("task_id", &xla::gpu::CollectivePerfTableGen::Config::task_id);

// Bind the StepSpec struct
nb::class_<xla::gpu::CollectivePerfTableGen::StepSpec>(m, "StepSpec")
.def(nb::init<>())
.def_rw("start", &xla::gpu::CollectivePerfTableGen::StepSpec::start)
.def_rw("stop", &xla::gpu::CollectivePerfTableGen::StepSpec::stop)
.def_rw("step", &xla::gpu::CollectivePerfTableGen::StepSpec::step)
.def_rw("factor", &xla::gpu::CollectivePerfTableGen::StepSpec::factor);

// Bind the CollectiveType enum
nb::enum_<xla::gpu::CollectivePerfTableGen::CollectiveType>(m,
"CollectiveType")
.value("UNSPECIFIED",
xla::gpu::CollectivePerfTableGen::CollectiveType::UNSPECIFIED)
.value("ALL_REDUCE",
xla::gpu::CollectivePerfTableGen::CollectiveType::ALL_REDUCE)
.value("REDUCE_SCATTER",
xla::gpu::CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER)
.value("ALL_GATHER",
xla::gpu::CollectivePerfTableGen::CollectiveType::ALL_GATHER)
.export_values();

m.def("run", [](xla::gpu::CollectivePerfTableGen::Config config) -> void {
std::unique_ptr<xla::gpu::CollectivePerfTableGen> gen =
xla::gpu::CollectivePerfTableGen::Create(config);
auto table = gen->ComputeTable();
CHECK_OK(gen->Dump(table));
});
}
60 changes: 33 additions & 27 deletions xla/tools/collective_perf_table_gen_main.cc
@@ -29,8 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "absl/strings/substitute.h"
#include "xla/service/gpu/model/hlo_op_profile.pb.h"
#include "xla/tools/collective_perf_table_gen.h"
#include "xla/tsl/util/command_line_flags.h"
@@ -76,7 +75,6 @@ to 4 GPUs.

constexpr absl::string_view kDefaultCoordinatorAddress = "127.0.0.1:1234";

using ::xla::IotaReplicaGroupList;
using ::xla::gpu::CollectivePerfTableGen;
using ::xla::gpu::DeviceHloInstructionProfiles;

@@ -91,27 +89,6 @@ std::pair<std::string /*key*/, std::string /*value*/> ExtractKV(
return {key, value};
}

IotaReplicaGroupList GetCollectiveDeviceList(
absl::string_view collective_device_list_unparsed) {
auto collective_device_list =
xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed);
CHECK_OK(collective_device_list);
CHECK(collective_device_list->iota_replica_group_list().has_value());

return *collective_device_list->iota_replica_group_list();
}

std::vector<IotaReplicaGroupList> GetCollectiveDeviceLists(
absl::string_view collective_device_lists_unparsed) {
std::vector<IotaReplicaGroupList> device_lists;
for (absl::string_view token :
absl::StrSplit(collective_device_lists_unparsed, ';')) {
device_lists.emplace_back(GetCollectiveDeviceList(token));
}
CHECK_GT(device_lists.size(), 0);
return device_lists;
}

std::vector<CollectivePerfTableGen::CollectiveType> ParseCollectives(
absl::string_view unparsed) {
std::vector<CollectivePerfTableGen::CollectiveType> types;
@@ -153,22 +130,49 @@ CollectivePerfTableGen::StepSpec ParseStepSpec(absl::string_view unparsed) {
return spec;
}

std::vector<std::string> CollectiveDeviceLists(
absl::string_view device_list_unparsed) {
std::vector<std::string> result;
for (absl::string_view device_list :
absl::StrSplit(device_list_unparsed, ';')) {
result.emplace_back(device_list);
}
return result;
}

std::string DefaultCollectiveDevicesIfEmpty(
const std::string& collective_devices_spec_unparsed, int32_t num_nodes,
int32_t num_devices_per_host) {
if (collective_devices_spec_unparsed.empty()) {
return absl::Substitute("[1,$0]<=[$0];[$2,$1]<=[$1,$2]T(1,0)",
num_devices_per_host * num_nodes, num_nodes,
num_devices_per_host);
}
return collective_devices_spec_unparsed;
}

} // namespace

// TODO(b/390097558): Add an option to generate perf table for collective which
// gets overlap to model resource contention.
int main(int argc, char* argv[]) {
// Default args.
int32_t num_nodes = 1;
int32_t num_devices_per_host = 8;
int32_t task_id = 0;
std::string collectives_unparsed;
std::string tensor_size_bytes_spec_unparsed;
std::string collectives_unparsed = "ALL_REDUCE,ALL_GATHER,REDUCE_SCATTER";
std::string tensor_size_bytes_spec_unparsed =
"start=1024,stop=2147483648,factor=2";
std::string collective_devices_spec_unparsed;
std::string coordinator_address = std::string(kDefaultCoordinatorAddress);
std::string output = std::string(CollectivePerfTableGen::Config::kStdout);

// Parse flags.
std::vector<tsl::Flag> flag_list = {
tsl::Flag("num_nodes", &num_nodes,
"Specifies number of processes across a distributed system."),
tsl::Flag("num_devices_per_host", &num_devices_per_host,
"Specified number of devices per host."),
tsl::Flag("task_id", &task_id,
"Specifies task identifier of this process. Must be unique "
"across the distributed system you run it on."),
@@ -204,8 +208,10 @@ int main(int argc, char* argv[]) {
cfg.task_id = task_id;
cfg.collective_types = ParseCollectives(collectives_unparsed);
cfg.tensor_size_bytes_spec = ParseStepSpec(tensor_size_bytes_spec_unparsed);
collective_devices_spec_unparsed = DefaultCollectiveDevicesIfEmpty(
collective_devices_spec_unparsed, num_nodes, num_devices_per_host);
cfg.replica_groups_list =
GetCollectiveDeviceLists(collective_devices_spec_unparsed);
CollectiveDeviceLists(collective_devices_spec_unparsed);
cfg.output = output;

std::unique_ptr<CollectivePerfTableGen> gen =
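
For reference, the string that DefaultCollectiveDevicesIfEmpty produces for the default single-node, 8-GPUs-per-host setup; this simply mirrors the Substitute template shown above, with $0 = num_devices_per_host * num_nodes, $1 = num_nodes, $2 = num_devices_per_host:

    # Reproduces the absl::Substitute call in DefaultCollectiveDevicesIfEmpty
    # for the default flag values (num_nodes=1, num_devices_per_host=8).
    num_nodes, num_devices_per_host = 1, 8
    total = num_devices_per_host * num_nodes
    default_spec = (
        f"[1,{total}]<=[{total}];"
        f"[{num_devices_per_host},{num_nodes}]<=[{num_nodes},{num_devices_per_host}]T(1,0)"
    )
    print(default_spec)  # [1,8]<=[8];[8,1]<=[1,8]T(1,0)
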