Skip to content

Commit

Permalink
[XLA:GPU] Add collective perf table tool.
Browse files Browse the repository at this point in the history
This produces a derating curve of network throughput at an HLO op level.

PiperOrigin-RevId: 725749411
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Feb 11, 2025
1 parent 7d8266c commit 24d2bbe
Show file tree
Hide file tree
Showing 9 changed files with 848 additions and 1 deletion.
1 change: 0 additions & 1 deletion xla/hlo/parser/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ cc_library(
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:status",
],
)

Expand Down
21 changes: 21 additions & 0 deletions xla/hlo/parser/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ class HloParserImpl : public HloParser {
ParseConvolutionDimensionNumbersOnly();
absl::StatusOr<PaddingConfig> ParsePaddingConfigOnly();
absl::StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();
absl::StatusOr<CollectiveDeviceList> ParseCollectiveDeviceListOnly();

private:
// Types of attributes.
Expand Down Expand Up @@ -7143,6 +7144,20 @@ HloParserImpl::ParseReplicaGroupsOnly() {
return replica_groups;
}

// Parses the parser's whole input as one collective device list.
//
// Returns the parsed list on success. Returns an InvalidArgument status when
// the list itself fails to parse (the message carries GetError()) or when any
// token other than end-of-input follows the list.
absl::StatusOr<CollectiveDeviceList>
HloParserImpl::ParseCollectiveDeviceListOnly() {
  // Advance the lexer once before handing control to the list parser.
  lexer_.Lex();

  CollectiveDeviceList device_list;
  const bool parsed_ok = ParseCollectiveDeviceList(&device_list);
  if (!parsed_ok) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }

  // The list must be the only thing in the input.
  if (lexer_.GetKind() == TokKind::kEof) {
    return device_list;
  }
  return InvalidArgument(
      "Syntax error:\nExtra content after collective device list");
}

absl::StatusOr<Window> HloParserImpl::ParseWindowOnly() {
lexer_.Lex();
Window window;
Expand Down Expand Up @@ -7283,6 +7298,12 @@ absl::StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
return parser.ParseReplicaGroupsOnly();
}

// Free-function entry point: builds a parser over `str` and delegates to
// HloParserImpl::ParseCollectiveDeviceListOnly(), forwarding its result
// (value or error status) unchanged.
absl::StatusOr<CollectiveDeviceList> ParseCollectiveDeviceListOnly(
    absl::string_view str) {
  HloParserImpl device_list_parser(str);
  return device_list_parser.ParseCollectiveDeviceListOnly();
}

absl::StatusOr<Window> ParseWindow(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseWindowOnly();
Expand Down
4 changes: 4 additions & 0 deletions xla/hlo/parser/hlo_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ absl::StatusOr<Layout> ParseLayout(absl::string_view str);
absl::StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
absl::string_view str);

// Parses `str` as a single `CollectiveDeviceList` and returns it.
//
// The whole string must be consumed: a syntax error in the list, or any
// content remaining after it, yields an error status instead of a value.
absl::StatusOr<CollectiveDeviceList> ParseCollectiveDeviceListOnly(
absl::string_view str);

class HloParser {
public:
// Runs the parser and constructs the resulting HLO in the given (empty)
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/model/hlo_op_profile.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import "xla/service/hlo.proto";
// Profile record for a single HLO instruction.
message HloInstructionProfile {
// The instruction this record describes.
xla.HloInstructionProto instruction = 1;
// Protos for the operands of `instruction`.
// NOTE(review): presumably listed in operand order - confirm with the writer.
repeated xla.HloInstructionProto operands = 5;

// NOTE(review): presumably the measured cost of `instruction` in device clock
// cycles - confirm against the profiler that fills this in.
int64 clock_cycles = 2;
// NOTE(review): presumably an identifier for the profiled instruction; how it
// is computed is not visible here.
string fingerprint = 3;
// NOTE(review): by its name, a floating-point operation figure; whether it is
// a total count or a per-second rate is not visible here - confirm.
int64 flops = 4;
// Network throughput in bytes per second (unit taken from the field name).
// NOTE(review): presumably only meaningful for collective ops - confirm.
int64 network_throughput_bytes_per_sec = 6;
}

message HloInstructionProfileList {
Expand Down
81 changes: 81 additions & 0 deletions xla/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,87 @@ cc_library(
]),
)

# Library for the collective perf table generator: compiles
# collective_perf_table_gen.cc and exports collective_perf_table_gen.h.
cc_library(
name = "collective_perf_table_gen",
srcs = ["collective_perf_table_gen.cc"],
hdrs = ["collective_perf_table_gen.h"],
deps = [
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/utils:hlo_query",
"//xla/pjrt:pjrt_client",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:backend",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/service/gpu/model:hlo_op_profiler_lib",
"//xla/service/gpu/model:hlo_op_profiles",
"//xla/tools/multihost_hlo_runner:create_client",
"//xla/tools/multihost_hlo_runner:functional_hlo_runner",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
],
)

# Library built from collective_perf_table_gen_main.cc; the
# :collective_perf_table_gen_main binary depends on it and has no srcs of its
# own. Marked testonly and tagged "no_mac".
cc_library(
name = "collective_perf_table_gen_main_lib",
testonly = True,
srcs = ["collective_perf_table_gen_main.cc"],
compatible_with = None,
tags = [
"no_mac",
],
deps = [
":collective_perf_table_gen",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/tsl/util:command_line_flags",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:platform_port",
# The GPU plugin and CUDA platform are added only through if_cuda().
] + if_cuda([
"//xla/service:gpu_plugin",
"//xla/stream_executor:cuda_platform",
]),
# Link every object file of this library into dependents even when no symbol
# is referenced directly.
alwayslink = True,
)

# Test for :collective_perf_table_gen, restricted to the "gpu" backend.
# GOOGLE_CUDA is defined for the test sources only through if_cuda().
xla_test(
name = "collective_perf_table_gen_test",
srcs = ["collective_perf_table_gen_test.cc"],
backends = ["gpu"],
local_defines = if_cuda(["GOOGLE_CUDA"]),
deps = [
":collective_perf_table_gen",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/tests:hlo_test_base",
"@com_google_googletest//:gtest_main",
],
)

# Binary for the collective perf table tool. It declares no srcs; all code
# comes from :collective_perf_table_gen_main_lib. Marked testonly and tagged
# "no_mac", matching that library.
xla_cc_binary(
name = "collective_perf_table_gen_main",
testonly = True,
tags = [
"no_mac",
],
deps = [
":collective_perf_table_gen_main_lib",
],
)

xla_cc_binary(
name = "matmul_perf_table_gen_main",
testonly = True,
Expand Down
Loading

0 comments on commit 24d2bbe

Please sign in to comment.