diff --git a/xla/tools/BUILD b/xla/tools/BUILD index 4d5ac95543835..09d547b98a969 100644 --- a/xla/tools/BUILD +++ b/xla/tools/BUILD @@ -19,7 +19,7 @@ load( "if_google", "tsl_gpu_library", ) -load("//xla/tsl:tsl.default.bzl", "filegroup") +load("//xla/tsl:tsl.default.bzl", "filegroup", "tsl_pybind_extension") load( "//xla/tsl/platform:build_config.bzl", "tf_proto_library", @@ -716,6 +716,7 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:backend", + "//xla/service:hlo_module_config", "//xla/service/gpu/model:hlo_op_profile_proto_cc", "//xla/service/gpu/model:hlo_op_profiler_lib", "//xla/service/gpu/model:hlo_op_profiles", @@ -759,6 +760,16 @@ cc_library( alwayslink = True, ) +tsl_pybind_extension( + name = "collective_perf_table_gen_bindings", + srcs = ["collective_perf_table_gen_bindings.cc"], + deps = [ + ":collective_perf_table_gen", + "@com_google_absl//absl/log:check", + "@nanobind", + ], +) + xla_test( name = "collective_perf_table_gen_test", srcs = ["collective_perf_table_gen_test.cc"], @@ -766,7 +777,6 @@ xla_test( local_defines = if_cuda(["GOOGLE_CUDA"]), deps = [ ":collective_perf_table_gen", - "//xla/hlo/ir:hlo", "//xla/service/gpu/model:hlo_op_profile_proto_cc", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:hlo_test_base", diff --git a/xla/tools/collective_perf_table_gen.cc b/xla/tools/collective_perf_table_gen.cc index a55e5e72bb275..bc71ab02dafb8 100644 --- a/xla/tools/collective_perf_table_gen.cc +++ b/xla/tools/collective_perf_table_gen.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" #include "xla/hlo/ir/collective_device_list.h" @@ -38,6 +39,7 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/service/gpu/model/hlo_op_profiler.h" #include "xla/service/gpu/model/hlo_op_profiles.h" +#include "xla/service/hlo_module_config.h" #include "xla/tools/multihost_hlo_runner/create_client.h" #include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h" #include "xla/tsl/platform/env.h" @@ -186,7 +188,10 @@ std::unique_ptr CreateCollectiveModule(const StaticSpec& spec) { std::string hlo = GetHlo(spec.collective_type, input_dim, output_dim, spec.replica_groups); - auto parsed = ParseAndReturnUnverifiedModule(hlo); + HloModuleConfig config; + config.set_num_partitions(spec.replica_groups.num_devices_per_group() * + spec.replica_groups.num_replica_groups()); + auto parsed = ParseAndReturnUnverifiedModule(hlo, config); CHECK_OK(parsed.status()); return std::move(*parsed); } @@ -202,6 +207,16 @@ uint64_t GetNetworkThroughputBytesPerSec(absl::Duration runtime, return tensor_size_bytes * 1e9 / absl::ToInt64Nanoseconds(runtime); } +IotaReplicaGroupList GetCollectiveDeviceList( + absl::string_view collective_device_list_unparsed) { + auto collective_device_list = + xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed); + CHECK_OK(collective_device_list); + CHECK(collective_device_list->iota_replica_group_list().has_value()); + + return *collective_device_list->iota_replica_group_list(); +} + } // namespace /*static*/ std::unique_ptr @@ -224,7 +239,7 @@ std::unique_ptr CollectivePerfTableGen::Compile( std::unique_ptr module) { DebugOptions debug_opts; FunctionalHloRunner::RawCompileOptions opts; - opts.num_partitions = 8; + opts.num_partitions = module->config().num_partitions(); opts.spmd_mode = FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning; auto compile_opts = FunctionalHloRunner::CreateCompileOptions( *pjrt_env_.client, opts, config_.task_id, config_.num_nodes); @@ -250,18 +265,26 @@ CollectivePerfTableGen::ProfilingData CollectivePerfTableGen::Profile( VLOG(1) << "Compiled module: " << executable->GetHloModules().value()[0]->ToString(); + // We do not profile dry runs or on more than one tasks. if (config_.dry_run) { return {}; } - std::unique_ptr tracer = - HloOpProfiler::GetKernelTracer(); + if (config_.task_id == 0) { + std::unique_ptr tracer = + HloOpProfiler::GetKernelTracer(); + for (int i = 0; i < kNumProfilingRuns; ++i) { + Run(*executable); + } + return { + /*runtime=*/absl::Nanoseconds( + std::move(*tracer).getMedianKernelTimeNs()), + }; + } for (int i = 0; i < kNumProfilingRuns; ++i) { Run(*executable); } - return { - /*runtime=*/absl::Nanoseconds(std::move(*tracer).getMedianKernelTimeNs()), - }; + return {}; } DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() { @@ -282,11 +305,14 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() { for (int64_t tensor_size = tsize_spec.start; tensor_size <= tsize_spec.stop; tensor_size = inc(tensor_size, tsize_spec)) { for (CollectiveType collective_type : config_.collective_types) { - for (const IotaReplicaGroupList& replica_groups : - config_.replica_groups_list) { + for (absl::string_view replica_groups_raw : config_.replica_groups_list) { CHECK(collective_type != CollectiveType::UNSPECIFIED); - StaticSpec spec{collective_type, replica_groups, tensor_size}; + StaticSpec spec{ + collective_type, + GetCollectiveDeviceList(replica_groups_raw), + tensor_size, + }; static_specs.push_back(spec); } } @@ -313,6 +339,10 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() { if (!config_.dry_run) { profiled_data = Profile(std::move(spec.module)); } + if (profiled_data.runtime == absl::ZeroDuration()) { + VLOG(1) << "Size: " << static_spec.tensor_size_bytes << " too small."; + continue; + } entry.set_network_throughput_bytes_per_sec(GetNetworkThroughputBytesPerSec( profiled_data.runtime, static_spec.tensor_size_bytes)); @@ -332,6 +362,9 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() { absl::Status CollectivePerfTableGen::Dump( const DeviceHloInstructionProfiles& table) { + if (config_.task_id != 0) { + return absl::OkStatus(); + } if (config_.output == CollectivePerfTableGen::Config::kStdout) { LOG(INFO) << table.DebugString(); return absl::OkStatus(); diff --git a/xla/tools/collective_perf_table_gen.h b/xla/tools/collective_perf_table_gen.h index 5e87f358b7e51..e941812485ad6 100644 --- a/xla/tools/collective_perf_table_gen.h +++ b/xla/tools/collective_perf_table_gen.h @@ -26,7 +26,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/backend.h" @@ -43,10 +42,10 @@ namespace xla::gpu { class CollectivePerfTableGen { public: struct StepSpec { - int64_t start = 0; - int64_t stop = -1; + int64_t start = 1024; + int64_t stop = 2ll * 1024 * 1024 * 1024; int64_t step = 0; - int64_t factor = 0; + int64_t factor = 2; }; enum class CollectiveType { @@ -61,14 +60,18 @@ class CollectivePerfTableGen { // Search space. StepSpec tensor_size_bytes_spec; - std::vector collective_types; - std::vector replica_groups_list; + std::vector collective_types = { + CollectiveType::ALL_REDUCE, + CollectiveType::ALL_GATHER, + CollectiveType::REDUCE_SCATTER, + }; + std::vector replica_groups_list; // Execution opts. bool dry_run = false; std::string output = std::string(kStdout); std::string coordinator_address = ""; - absl::Duration connection_timeout = absl::Seconds(60); + absl::Duration connection_timeout = absl::Seconds(600); uint16_t num_nodes = 1; uint16_t task_id = 0; }; diff --git a/xla/tools/collective_perf_table_gen_bindings.cc b/xla/tools/collective_perf_table_gen_bindings.cc new file mode 100644 index 0000000000000..61eb545b9f319 --- /dev/null +++ b/xla/tools/collective_perf_table_gen_bindings.cc @@ -0,0 +1,74 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/log/check.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/tools/collective_perf_table_gen.h" + +namespace nb = nanobind; + +NB_MODULE(collective_perf_table_gen_bindings, m) { + // Bind the Config struct + nb::class_(m, "Config") + .def(nb::init<>()) + .def_rw("tensor_size_bytes_spec", + &xla::gpu::CollectivePerfTableGen::Config::tensor_size_bytes_spec) + .def_rw("collective_types", + &xla::gpu::CollectivePerfTableGen::Config::collective_types) + .def_rw("replica_groups_list", + &xla::gpu::CollectivePerfTableGen::Config::replica_groups_list) + .def_rw("dry_run", &xla::gpu::CollectivePerfTableGen::Config::dry_run) + .def_rw("output", &xla::gpu::CollectivePerfTableGen::Config::output) + .def_rw("coordinator_address", + &xla::gpu::CollectivePerfTableGen::Config::coordinator_address) + .def_rw("connection_timeout", + &xla::gpu::CollectivePerfTableGen::Config::connection_timeout) + .def_rw("num_nodes", &xla::gpu::CollectivePerfTableGen::Config::num_nodes) + .def_rw("task_id", &xla::gpu::CollectivePerfTableGen::Config::task_id); + + // Bind the StepSpec struct + nb::class_(m, "StepSpec") + .def(nb::init<>()) + .def_rw("start", &xla::gpu::CollectivePerfTableGen::StepSpec::start) + .def_rw("stop", &xla::gpu::CollectivePerfTableGen::StepSpec::stop) + .def_rw("step", &xla::gpu::CollectivePerfTableGen::StepSpec::step) + .def_rw("factor", &xla::gpu::CollectivePerfTableGen::StepSpec::factor); + + // Bind the CollectiveType enum + nb::enum_(m, + "CollectiveType") + .value("UNSPECIFIED", + xla::gpu::CollectivePerfTableGen::CollectiveType::UNSPECIFIED) + .value("ALL_REDUCE", + xla::gpu::CollectivePerfTableGen::CollectiveType::ALL_REDUCE) + .value("REDUCE_SCATTER", + xla::gpu::CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER) + .value("ALL_GATHER", + xla::gpu::CollectivePerfTableGen::CollectiveType::ALL_GATHER) + .export_values(); + + m.def("run", [](xla::gpu::CollectivePerfTableGen::Config config) -> void { + std::unique_ptr gen = + xla::gpu::CollectivePerfTableGen::Create(config); + auto table = gen->ComputeTable(); + CHECK_OK(gen->Dump(table)); + }); +} diff --git a/xla/tools/collective_perf_table_gen_main.cc b/xla/tools/collective_perf_table_gen_main.cc index 9c4655d0d5517..4a3deca1b5d92 100644 --- a/xla/tools/collective_perf_table_gen_main.cc +++ b/xla/tools/collective_perf_table_gen_main.cc @@ -29,8 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "xla/hlo/ir/collective_device_list.h" -#include "xla/hlo/parser/hlo_parser.h" +#include "absl/strings/substitute.h" #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/tools/collective_perf_table_gen.h" #include "xla/tsl/util/command_line_flags.h" @@ -76,7 +75,6 @@ to 4 GPUs. constexpr absl::string_view kDefaultCoordinatorAddress = "127.0.0.1:1234"; -using ::xla::IotaReplicaGroupList; using ::xla::gpu::CollectivePerfTableGen; using ::xla::gpu::DeviceHloInstructionProfiles; @@ -91,27 +89,6 @@ std::pair ExtractKV( return {key, value}; } -IotaReplicaGroupList GetCollectiveDeviceList( - absl::string_view collective_device_list_unparsed) { - auto collective_device_list = - xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed); - CHECK_OK(collective_device_list); - CHECK(collective_device_list->iota_replica_group_list().has_value()); - - return *collective_device_list->iota_replica_group_list(); -} - -std::vector GetCollectiveDeviceLists( - absl::string_view collective_device_lists_unparsed) { - std::vector device_lists; - for (absl::string_view token : - absl::StrSplit(collective_device_lists_unparsed, ';')) { - device_lists.emplace_back(GetCollectiveDeviceList(token)); - } - CHECK_GT(device_lists.size(), 0); - return device_lists; -} - std::vector ParseCollectives( absl::string_view unparsed) { std::vector types; @@ -153,22 +130,49 @@ CollectivePerfTableGen::StepSpec ParseStepSpec(absl::string_view unparsed) { return spec; } +std::vector CollectiveDeviceLists( + absl::string_view device_list_unparsed) { + std::vector result; + for (absl::string_view device_list : + absl::StrSplit(device_list_unparsed, ';')) { + result.emplace_back(device_list); + } + return result; +} + +std::string DefaultCollectiveDevicesIfEmpty( + const std::string& collective_devices_spec_unparsed, int32_t num_nodes, + int32_t num_devices_per_host) { + if (collective_devices_spec_unparsed.empty()) { + return absl::Substitute("[1,$0]<=[$0];[$2,$1]<=[$1,$2]T(1,0)", + num_devices_per_host * num_nodes, num_nodes, + num_devices_per_host); + } + return collective_devices_spec_unparsed; +} + } // namespace // TODO(b/390097558): Add an option to generate perf table for collective which // gets overlap to model resource contention. int main(int argc, char* argv[]) { + // Default args. int32_t num_nodes = 1; + int32_t num_devices_per_host = 8; int32_t task_id = 0; - std::string collectives_unparsed; - std::string tensor_size_bytes_spec_unparsed; + std::string collectives_unparsed = "ALL_REDUCE,ALL_GATHER,REDUCE_SCATTER"; + std::string tensor_size_bytes_spec_unparsed = + "start=1024,stop=2147483648,factor=2"; std::string collective_devices_spec_unparsed; std::string coordinator_address = std::string(kDefaultCoordinatorAddress); std::string output = std::string(CollectivePerfTableGen::Config::kStdout); + // Parse flags. std::vector flag_list = { tsl::Flag("num_nodes", &num_nodes, "Specifies number of processes across a distributed system."), + tsl::Flag("num_devices_per_host", &num_devices_per_host, + "Specified number of devices per host."), tsl::Flag("task_id", &task_id, "Specifies task identifier of this process. Must be unique " "across the distributed system you run it on."), @@ -204,8 +208,10 @@ int main(int argc, char* argv[]) { cfg.task_id = task_id; cfg.collective_types = ParseCollectives(collectives_unparsed); cfg.tensor_size_bytes_spec = ParseStepSpec(tensor_size_bytes_spec_unparsed); + collective_devices_spec_unparsed = DefaultCollectiveDevicesIfEmpty( + collective_devices_spec_unparsed, num_nodes, num_devices_per_host); cfg.replica_groups_list = - GetCollectiveDeviceLists(collective_devices_spec_unparsed); + CollectiveDeviceLists(collective_devices_spec_unparsed); cfg.output = output; std::unique_ptr gen = diff --git a/xla/tools/collective_perf_table_gen_test.cc b/xla/tools/collective_perf_table_gen_test.cc index b7d6d1dc63d64..04642bdd1549b 100644 --- a/xla/tools/collective_perf_table_gen_test.cc +++ b/xla/tools/collective_perf_table_gen_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "xla/hlo/ir/collective_device_list.h" #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/tests/hlo_test_base.h" @@ -59,8 +58,7 @@ TEST_F(CollectivePerfTableGenTest, ConstantStepGeneratesConfigs) { CollectivePerfTableGen::CollectiveType::ALL_GATHER, CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER, }; - IotaReplicaGroupList iota(1, 1); - cfg_.replica_groups_list = {iota}; + cfg_.replica_groups_list.emplace_back("[1,1]<=[1]"); CollectivePerfTableGen::StepSpec spec{ /*start=*/4, /*stop=*/20, @@ -83,8 +81,7 @@ TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) { CollectivePerfTableGen::CollectiveType::ALL_GATHER, CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER, }; - IotaReplicaGroupList iota(1, 1); - cfg_.replica_groups_list = {iota}; + cfg_.replica_groups_list.emplace_back("[1,1]<=[1]"); CollectivePerfTableGen::StepSpec spec{ /*start=*/4, /*stop=*/32,