Skip to content

Commit

Permalink
[XLA:GPU] Add merge functionality to collective perf table gen.
Browse files Browse the repository at this point in the history
Convenience function to merge multiple perf tables as one.

PiperOrigin-RevId: 726919223
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Feb 14, 2025
1 parent e9063a9 commit 6da089d
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 1 deletion.
2 changes: 2 additions & 0 deletions xla/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,8 @@ cc_library(
"//xla/tools/multihost_hlo_runner:functional_hlo_runner",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
101 changes: 101 additions & 0 deletions xla/tools/collective_perf_table_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ limitations under the License.

#include "xla/tools/collective_perf_table_gen.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -71,6 +74,31 @@ struct ExplicitSpec {
std::unique_ptr<HloModule> module;
};

struct ProfilingResult {
std::string device_info;
HloInstructionProto hlo_proto;
std::vector<HloInstructionProto> operands;
std::string fingerprint;
int64_t clock_cycles;
int64_t flops;
int64_t network_throughput;

struct Hash {
size_t operator()(const ProfilingResult& profiling_result) const {
return absl::HashOf(profiling_result.device_info,
profiling_result.fingerprint);
}
};

struct Eq {
bool operator()(const ProfilingResult& lhs,
const ProfilingResult& rhs) const {
return lhs.device_info == rhs.device_info &&
lhs.fingerprint == rhs.fingerprint;
}
};
};

int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type,
int64_t tensor_size_bytes,
IotaReplicaGroupList replica_groups) {
Expand Down Expand Up @@ -400,4 +428,77 @@ absl::Status CollectivePerfTableGen::Dump(
return absl::OkStatus();
}

DeviceHloInstructionProfiles CollectivePerfTableGen::Merge(
absl::string_view merge_path) {
DeviceHloInstructionProfiles result;
std::vector<std::string> filenames;
CHECK_OK(
tsl::Env::Default()->GetChildren(std::string(merge_path), &filenames));

absl::flat_hash_set<ProfilingResult, ProfilingResult::Hash,
ProfilingResult::Eq>
profiling_results;
uint64_t profiling_results_counter = 0;
for (const std::string& filename : filenames) {
// Read file.
std::string profile_path = absl::StrCat(merge_path, "/", filename);
DeviceHloInstructionProfiles partial_profile;

CHECK_OK(tsl::Env::Default()->FileExists(profile_path));
if (!tsl::ReadTextOrBinaryProto(tsl::Env::Default(), profile_path,
&partial_profile)
.ok()) {
LOG(WARNING) << "Cannot read :" << profile_path;
continue;
}

for (auto& [device_descriptor, data] : partial_profile.entries()) {
for (const HloInstructionProfile& profile : data.entries()) {
CHECK(!profile.fingerprint().empty())
<< "Expected fingerprint to deduplicate: " << profile.DebugString();

ProfilingResult profiling_result{
device_descriptor,
std::move(profile.instruction()),
{
profile.operands().begin(),
profile.operands().end(),
},
std::move(profile.fingerprint()),
profile.clock_cycles(),
profile.flops(),
profile.network_throughput_bytes_per_sec(),
};
profiling_results.insert(profiling_result);
profiling_results_counter++;
}
}
}
LOG(INFO) << "Merging and deduplication entries count. Before "
<< profiling_results_counter << ", after "
<< profiling_results.size() << ".";

for (const ProfilingResult& profiling_result : profiling_results) {
std::string device_descriptor = profiling_result.device_info;
if (!result.mutable_entries()->contains(device_descriptor)) {
result.mutable_entries()->insert({device_descriptor, {}});
}

HloInstructionProfile profile_proto;
*profile_proto.mutable_instruction() =
std::move(profiling_result.hlo_proto);
for (auto op : profiling_result.operands) {
*profile_proto.add_operands() = std::move(op);
}
profile_proto.set_flops(profiling_result.flops);
profile_proto.set_clock_cycles(profiling_result.clock_cycles);
profile_proto.set_fingerprint(profiling_result.fingerprint);

*result.mutable_entries()->at(device_descriptor).add_entries() =
std::move(profile_proto);
}

return result;
}

} // namespace xla::gpu
5 changes: 5 additions & 0 deletions xla/tools/collective_perf_table_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class CollectivePerfTableGen {
// content (but not deduplicating).
absl::Status Dump(const DeviceHloInstructionProfiles& table);

// Merges all of the profiled files under `merge_path`, deduplicates them
// based on fingerprint and writes them to a single
// `DeviceHloInstructionProfiles` proto.
DeviceHloInstructionProfiles Merge(absl::string_view merge_path);

private:
explicit CollectivePerfTableGen(Config config, PjRtEnvironment&& pjrt_env)
: config_(std::move(config)),
Expand Down
12 changes: 11 additions & 1 deletion xla/tools/collective_perf_table_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ int main(int argc, char* argv[]) {
std::string collective_devices_spec_unparsed;
std::string coordinator_address = std::string(kDefaultCoordinatorAddress);
std::string output = std::string(CollectivePerfTableGen::Config::kStdout);
std::string merge_path;

// Parse flags.
std::vector<tsl::Flag> flag_list = {
Expand Down Expand Up @@ -193,6 +194,10 @@ int main(int argc, char* argv[]) {
"Output mode for the program. If set to 'stdout' performance table "
"will be printed to the standard output. If given a file with .pbtxt "
"or .pb extension it will append the contents to that file."),
tsl::Flag("merge_path", &merge_path,
"Path to DeviceHloInstructionProfiles files. When specified it "
"will merge all of the profiled files and write them to a "
"single file specified by `output`."),
};

std::string kUsageString =
Expand All @@ -216,7 +221,12 @@ int main(int argc, char* argv[]) {

std::unique_ptr<CollectivePerfTableGen> gen =
CollectivePerfTableGen::Create(cfg);
DeviceHloInstructionProfiles profiles = gen->ComputeTable();
DeviceHloInstructionProfiles profiles;
if (merge_path.empty()) {
profiles = gen->ComputeTable();
} else {
profiles = gen->Merge(merge_path);
};
CHECK_OK(gen->Dump(profiles));
return 0;
}

0 comments on commit 6da089d

Please sign in to comment.