Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Add merge functionality to collective perf table gen. #22740

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,8 @@ cc_library(
"//xla/tools/multihost_hlo_runner:functional_hlo_runner",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
101 changes: 101 additions & 0 deletions xla/tools/collective_perf_table_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ limitations under the License.

#include "xla/tools/collective_perf_table_gen.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -71,6 +74,31 @@ struct ExplicitSpec {
std::unique_ptr<HloModule> module;
};

struct ProfilingResult {
std::string device_info;
HloInstructionProto hlo_proto;
std::vector<HloInstructionProto> operands;
std::string fingerprint;
int64_t clock_cycles;
int64_t flops;
int64_t network_throughput;

struct Hash {
size_t operator()(const ProfilingResult& profiling_result) const {
return absl::HashOf(profiling_result.device_info,
profiling_result.fingerprint);
}
};

struct Eq {
bool operator()(const ProfilingResult& lhs,
const ProfilingResult& rhs) const {
return lhs.device_info == rhs.device_info &&
lhs.fingerprint == rhs.fingerprint;
}
};
};

int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type,
int64_t tensor_size_bytes,
IotaReplicaGroupList replica_groups) {
Expand Down Expand Up @@ -400,4 +428,77 @@ absl::Status CollectivePerfTableGen::Dump(
return absl::OkStatus();
}

DeviceHloInstructionProfiles CollectivePerfTableGen::Merge(
absl::string_view merge_path) {
DeviceHloInstructionProfiles result;
std::vector<std::string> filenames;
CHECK_OK(
tsl::Env::Default()->GetChildren(std::string(merge_path), &filenames));

absl::flat_hash_set<ProfilingResult, ProfilingResult::Hash,
ProfilingResult::Eq>
profiling_results;
uint64_t profiling_results_counter = 0;
for (const std::string& filename : filenames) {
// Read file.
std::string profile_path = absl::StrCat(merge_path, "/", filename);
DeviceHloInstructionProfiles partial_profile;

CHECK_OK(tsl::Env::Default()->FileExists(profile_path));
if (!tsl::ReadTextOrBinaryProto(tsl::Env::Default(), profile_path,
&partial_profile)
.ok()) {
LOG(WARNING) << "Cannot read :" << profile_path;
continue;
}

for (auto& [device_descriptor, data] : partial_profile.entries()) {
for (const HloInstructionProfile& profile : data.entries()) {
CHECK(!profile.fingerprint().empty())
<< "Expected fingerprint to deduplicate: " << profile.DebugString();

ProfilingResult profiling_result{
device_descriptor,
std::move(profile.instruction()),
{
profile.operands().begin(),
profile.operands().end(),
},
std::move(profile.fingerprint()),
profile.clock_cycles(),
profile.flops(),
profile.network_throughput_bytes_per_sec(),
};
profiling_results.insert(profiling_result);
profiling_results_counter++;
}
}
}
LOG(INFO) << "Merging and deduplication entries count. Before "
<< profiling_results_counter << ", after "
<< profiling_results.size() << ".";

for (const ProfilingResult& profiling_result : profiling_results) {
std::string device_descriptor = profiling_result.device_info;
if (!result.mutable_entries()->contains(device_descriptor)) {
result.mutable_entries()->insert({device_descriptor, {}});
}

HloInstructionProfile profile_proto;
*profile_proto.mutable_instruction() =
std::move(profiling_result.hlo_proto);
for (auto op : profiling_result.operands) {
*profile_proto.add_operands() = std::move(op);
}
profile_proto.set_flops(profiling_result.flops);
profile_proto.set_clock_cycles(profiling_result.clock_cycles);
profile_proto.set_fingerprint(profiling_result.fingerprint);

*result.mutable_entries()->at(device_descriptor).add_entries() =
std::move(profile_proto);
}

return result;
}

} // namespace xla::gpu
5 changes: 5 additions & 0 deletions xla/tools/collective_perf_table_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class CollectivePerfTableGen {
// content (but not deduplicating).
absl::Status Dump(const DeviceHloInstructionProfiles& table);

// Merges all of the profiled files under `merge_path`, deduplicates them
// based on fingerprint and writes them to a single
// `DeviceHloInstructionProfiles` proto.
DeviceHloInstructionProfiles Merge(absl::string_view merge_path);

private:
explicit CollectivePerfTableGen(Config config, PjRtEnvironment&& pjrt_env)
: config_(std::move(config)),
Expand Down
12 changes: 11 additions & 1 deletion xla/tools/collective_perf_table_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ int main(int argc, char* argv[]) {
std::string collective_devices_spec_unparsed;
std::string coordinator_address = std::string(kDefaultCoordinatorAddress);
std::string output = std::string(CollectivePerfTableGen::Config::kStdout);
std::string merge_path;

// Parse flags.
std::vector<tsl::Flag> flag_list = {
Expand Down Expand Up @@ -193,6 +194,10 @@ int main(int argc, char* argv[]) {
"Output mode for the program. If set to 'stdout' performance table "
"will be printed to the standard output. If given a file with .pbtxt "
"or .pb extension it will append the contents to that file."),
tsl::Flag("merge_path", &merge_path,
"Path to DeviceHloInstructionProfiles files. When specified it "
"will merge all of the profiled files and write them to a "
"single file specified by `output`."),
};

std::string kUsageString =
Expand All @@ -216,7 +221,12 @@ int main(int argc, char* argv[]) {

std::unique_ptr<CollectivePerfTableGen> gen =
CollectivePerfTableGen::Create(cfg);
DeviceHloInstructionProfiles profiles = gen->ComputeTable();
DeviceHloInstructionProfiles profiles;
if (merge_path.empty()) {
profiles = gen->ComputeTable();
} else {
profiles = gen->Merge(merge_path);
};
CHECK_OK(gen->Dump(profiles));
return 0;
}
Loading