From 6da089db33ff5cb1974846d8496e3e7b5e14eb40 Mon Sep 17 00:00:00 2001
From: Greg Olechwierowicz
Date: Fri, 14 Feb 2025 07:22:06 -0800
Subject: [PATCH] [XLA:GPU] Add merge functionality to collective perf table gen.

Convenience function to merge multiple perf tables as one.

PiperOrigin-RevId: 726919223
---
 xla/tools/BUILD                             |   2 +
 xla/tools/collective_perf_table_gen.cc      | 103 ++++++++++++++++++++
 xla/tools/collective_perf_table_gen.h       |   6 +
 xla/tools/collective_perf_table_gen_main.cc |  12 ++-
 4 files changed, 122 insertions(+), 1 deletion(-)

diff --git a/xla/tools/BUILD b/xla/tools/BUILD
index 09d547b98a969..afe40bf6ab277 100644
--- a/xla/tools/BUILD
+++ b/xla/tools/BUILD
@@ -724,6 +724,8 @@ cc_library(
         "//xla/tools/multihost_hlo_runner:functional_hlo_runner",
         "//xla/tsl/platform:env",
         "//xla/tsl/platform:errors",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
diff --git a/xla/tools/collective_perf_table_gen.cc b/xla/tools/collective_perf_table_gen.cc
index bc71ab02dafb8..d438b4df63f4f 100644
--- a/xla/tools/collective_perf_table_gen.cc
+++ b/xla/tools/collective_perf_table_gen.cc
@@ -15,12 +15,15 @@ limitations under the License.
 
 #include "xla/tools/collective_perf_table_gen.h"
 
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
+#include "absl/hash/hash.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
@@ -71,6 +74,35 @@ struct ExplicitSpec {
   std::unique_ptr<HloModule> module;
 };
 
+// One profiled instruction read back from a perf table, together with the
+// device it was measured on. Hashing and equality look only at the
+// (device_info, fingerprint) pair, which is what makes a hash set of these
+// deduplicate repeated measurements of the same instruction on one device.
+struct ProfilingResult {
+  std::string device_info;
+  HloInstructionProto hlo_proto;
+  std::vector<HloInstructionProto> operands;
+  std::string fingerprint;
+  int64_t clock_cycles = 0;
+  int64_t flops = 0;
+  int64_t network_throughput = 0;  // Bytes per second.
+
+  struct Hash {
+    size_t operator()(const ProfilingResult& profiling_result) const {
+      return absl::HashOf(profiling_result.device_info,
+                          profiling_result.fingerprint);
+    }
+  };
+
+  struct Eq {
+    bool operator()(const ProfilingResult& lhs,
+                    const ProfilingResult& rhs) const {
+      return lhs.device_info == rhs.device_info &&
+             lhs.fingerprint == rhs.fingerprint;
+    }
+  };
+};
+
 int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type,
                     int64_t tensor_size_bytes,
                     IotaReplicaGroupList replica_groups) {
@@ -400,4 +432,75 @@ absl::Status CollectivePerfTableGen::Dump(
   return absl::OkStatus();
 }
 
+DeviceHloInstructionProfiles CollectivePerfTableGen::Merge(
+    absl::string_view merge_path) {
+  std::vector<std::string> filenames;
+  CHECK_OK(
+      tsl::Env::Default()->GetChildren(std::string(merge_path), &filenames));
+
+  // Keyed on (device, fingerprint): the first occurrence of a pair wins and
+  // later duplicates are dropped by `insert`.
+  absl::flat_hash_set<ProfilingResult, ProfilingResult::Hash,
+                      ProfilingResult::Eq>
+      profiling_results;
+  uint64_t profiling_results_counter = 0;
+  for (const std::string& filename : filenames) {
+    const std::string profile_path = absl::StrCat(merge_path, "/", filename);
+    DeviceHloInstructionProfiles partial_profile;
+
+    CHECK_OK(tsl::Env::Default()->FileExists(profile_path));
+    // Children that are not parsable profiles are skipped, not fatal.
+    if (!tsl::ReadTextOrBinaryProto(tsl::Env::Default(), profile_path,
+                                    &partial_profile)
+             .ok()) {
+      LOG(WARNING) << "Cannot read :" << profile_path;
+      continue;
+    }
+
+    for (const auto& [device_descriptor, data] : partial_profile.entries()) {
+      for (const HloInstructionProfile& profile : data.entries()) {
+        CHECK(!profile.fingerprint().empty())
+            << "Expected fingerprint to deduplicate: " << profile.DebugString();
+
+        // `profile` is const, so every field is copied.
+        profiling_results.insert(ProfilingResult{
+            device_descriptor,
+            profile.instruction(),
+            {profile.operands().begin(), profile.operands().end()},
+            profile.fingerprint(),
+            profile.clock_cycles(),
+            profile.flops(),
+            profile.network_throughput_bytes_per_sec(),
+        });
+        ++profiling_results_counter;
+      }
+    }
+  }
+  LOG(INFO) << "Merging and deduplication entries count. Before "
+            << profiling_results_counter << ", after "
+            << profiling_results.size() << ".";
+
+  DeviceHloInstructionProfiles result;
+  for (const ProfilingResult& profiling_result : profiling_results) {
+    // Set elements are const, so the fields are copied; `operator[]` creates
+    // the per-device entry on first use.
+    HloInstructionProfile& profile_proto =
+        *(*result.mutable_entries())[profiling_result.device_info]
+             .add_entries();
+    *profile_proto.mutable_instruction() = profiling_result.hlo_proto;
+    for (const auto& op : profiling_result.operands) {
+      *profile_proto.add_operands() = op;
+    }
+    profile_proto.set_flops(profiling_result.flops);
+    profile_proto.set_clock_cycles(profiling_result.clock_cycles);
+    profile_proto.set_fingerprint(profiling_result.fingerprint);
+    // Carry the measured throughput over as well; without it the merged table
+    // loses the collective bandwidth data.
+    profile_proto.set_network_throughput_bytes_per_sec(
+        profiling_result.network_throughput);
+  }
+
+  return result;
+}
+
 }  // namespace xla::gpu
diff --git a/xla/tools/collective_perf_table_gen.h b/xla/tools/collective_perf_table_gen.h
index e941812485ad6..90d22309ac1aa 100644
--- a/xla/tools/collective_perf_table_gen.h
+++ b/xla/tools/collective_perf_table_gen.h
@@ -92,6 +92,12 @@ class CollectivePerfTableGen {
   // content (but not deduplicating).
   absl::Status Dump(const DeviceHloInstructionProfiles& table);
 
+  // Merges all of the profiled files under `merge_path`, deduplicates them
+  // by (device, fingerprint) and returns them as a single
+  // `DeviceHloInstructionProfiles` proto. Files that cannot be parsed are
+  // skipped with a warning.
+  DeviceHloInstructionProfiles Merge(absl::string_view merge_path);
+
  private:
   explicit CollectivePerfTableGen(Config config, PjRtEnvironment&& pjrt_env)
       : config_(std::move(config)),
diff --git a/xla/tools/collective_perf_table_gen_main.cc b/xla/tools/collective_perf_table_gen_main.cc
index 4a3deca1b5d92..6ac81080510af 100644
--- a/xla/tools/collective_perf_table_gen_main.cc
+++ b/xla/tools/collective_perf_table_gen_main.cc
@@ -166,6 +166,7 @@ int main(int argc, char* argv[]) {
   std::string collective_devices_spec_unparsed;
   std::string coordinator_address = std::string(kDefaultCoordinatorAddress);
   std::string output = std::string(CollectivePerfTableGen::Config::kStdout);
+  std::string merge_path;
 
   // Parse flags.
   std::vector<tsl::Flag> flag_list = {
@@ -193,6 +194,10 @@ int main(int argc, char* argv[]) {
                 "Output mode for the program. If set to 'stdout' performance table "
                 "will be printed to the standard output. If given a file with .pbtxt "
                 "or .pb extension it will append the contents to that file."),
+      tsl::Flag("merge_path", &merge_path,
+                "Path to DeviceHloInstructionProfiles files. When specified it "
+                "will merge all of the profiled files and write them to a "
+                "single file specified by `output`."),
   };
 
   std::string kUsageString =
@@ -216,7 +221,12 @@ int main(int argc, char* argv[]) {
 
   std::unique_ptr<CollectivePerfTableGen> gen =
       CollectivePerfTableGen::Create(cfg);
-  DeviceHloInstructionProfiles profiles = gen->ComputeTable();
+  DeviceHloInstructionProfiles profiles;
+  if (merge_path.empty()) {
+    profiles = gen->ComputeTable();
+  } else {
+    profiles = gen->Merge(merge_path);
+  }
   CHECK_OK(gen->Dump(profiles));
   return 0;
 }