Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #22640: [NFC] Deduplicate functions between HLO runners. #22742

Merged
merged 1 commit into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions xla/service/hlo_runner_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,6 @@ HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
}

namespace {

// Creates an HloModule from the given proto.
absl::StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
const HloProto& proto, const DebugOptions& debug_options) {
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
debug_options));
TF_ASSIGN_OR_RETURN(auto module,
HloModule::CreateFromProto(proto.hlo_module(), config));
return std::move(module);
}
template <class T>
std::vector<T*> MakePointerVector(absl::Span<T> input_vec) {
std::vector<T*> output_pointers;
Expand All @@ -65,22 +54,31 @@ std::vector<T*> MakePointerVector(absl::Span<T> input_vec) {

} // namespace

absl::StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::CreateModuleFromProto(const HloModuleProto& proto,
const DebugOptions& debug_options) {
TF_ASSIGN_OR_RETURN(
HloModuleConfig config,
HloModule::CreateModuleConfigFromProto(proto, debug_options));
return HloModule::CreateFromProto(proto, config);
}

/*static*/ absl::StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromBinaryProtoFile(
const std::string& filename, const DebugOptions& debug_options) {
absl::string_view filename, const DebugOptions& debug_options) {
HloProto proto;
TF_RETURN_IF_ERROR(
tsl::ReadBinaryProto(tsl::Env::Default(), filename, &proto));
return HloProtoToModule(proto, debug_options);
tsl::ReadBinaryProto(tsl::Env::Default(), std::string(filename), &proto));
return CreateModuleFromProto(proto.hlo_module(), debug_options);
}

/*static*/ absl::StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromHloTextFile(const std::string& filename,
HloRunnerInterface::ReadModuleFromHloTextFile(absl::string_view filename,
const DebugOptions& debug_options,
const HloParserOptions& options) {
std::string hlo_string;
TF_RETURN_IF_ERROR(
tsl::ReadFileToString(tsl::Env::Default(), filename, &hlo_string));
TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(),
std::string(filename), &hlo_string));
HloModuleConfig config;
config.set_debug_options(debug_options);
return ParseAndReturnUnverifiedModule(hlo_string, config, options);
Expand Down
13 changes: 10 additions & 3 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,17 @@ class HloRunnerInterface {
static absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
absl::string_view hlo_string, const DebugOptions& debug_options);

// Creates an HloModule from the given proto.
static absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
const HloModuleProto& proto,
const DebugOptions& debug_options = DebugOptions::default_instance());

// Reads the proto file in xla.HloProto format, creates and returns the
// HloModule.
static absl::StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromBinaryProtoFile(const std::string& filename,
const DebugOptions& debug_options);
ReadModuleFromBinaryProtoFile(
absl::string_view filename,
const DebugOptions& debug_options = DebugOptions::default_instance());

// Reads the proto file in xla.HloModule format, creates and returns the
// HloModule.
Expand All @@ -229,7 +235,8 @@ class HloRunnerInterface {
// Reads the hlo text dump file in HloModule::ToString format, creates and
// returns the HloModule.
static absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
const std::string& filename, const DebugOptions& debug_options,
absl::string_view filename,
const DebugOptions& debug_options = DebugOptions::default_instance(),
const HloParserOptions& options = HloParserOptions());

// Creates a runner-internal executable object given an HLO module and returns
Expand Down
1 change: 1 addition & 0 deletions xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ cc_library(
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_runner_interface",
"//xla/tests:test_utils",
"//xla/tools:hlo_control_flow_flattening",
"//xla/tsl/platform:env",
Expand Down
72 changes: 17 additions & 55 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ limitations under the License.
#include "xla/service/computation_placer.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tests/test_utils.h"
Expand All @@ -81,23 +82,6 @@ limitations under the License.
namespace xla {

namespace {
// Creates an HloModule from the given proto.
absl::StatusOr<std::unique_ptr<HloModule>> HloTextToModule(
absl::string_view hlo_text) {
return ParseAndReturnUnverifiedModule(hlo_text);
}

// Creates an HloModule from the given proto.
absl::StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
const HloModuleProto& proto) {
TF_ASSIGN_OR_RETURN(
HloModuleConfig config,
HloModule::CreateModuleConfigFromProto(proto, DebugOptions()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(proto, config));
return std::move(module);
}

template <typename ElementType>
void PopulateWithSameValue(Literal* literal, ElementType val) {
for (ElementType& element : literal->data<ElementType>()) {
Expand Down Expand Up @@ -449,16 +433,18 @@ FunctionalHloRunner::LoadHloModuleAndArguments(absl::string_view hlo_file,
HloModuleAndArguments hlo_module_and_arguments;
switch (input_format) {
case InputFormat::kText: {
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module,
ReadModuleFromHloTextFile(hlo_file));
TF_ASSIGN_OR_RETURN(
hlo_module_and_arguments.hlo_module,
HloRunnerInterface::ReadModuleFromHloTextFile(hlo_file));
} break;
case InputFormat::kProtoText: {
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module,
ReadModuleFromTextProtoFile(hlo_file));
} break;
case InputFormat::kProtoBinary: {
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module,
ReadModuleFromBinaryProtoFile(hlo_file));
TF_ASSIGN_OR_RETURN(
hlo_module_and_arguments.hlo_module,
HloRunnerInterface::ReadModuleFromBinaryProtoFile(hlo_file));
} break;
case InputFormat::kSnapshotProtoBinary: {
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments,
Expand Down Expand Up @@ -586,29 +572,12 @@ absl::Status FunctionalHloRunner::LoadAndCompile(
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<HloModule>>
FunctionalHloRunner::ReadModuleFromHloTextFile(absl::string_view hlo_file) {
std::string hlo_string;
TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(),
std::string(hlo_file), &hlo_string));
return ParseAndReturnUnverifiedModule(
hlo_string, {}, HloParserOptions().set_keep_module_auto_layouts(true));
}

absl::StatusOr<std::unique_ptr<HloModule>>
FunctionalHloRunner::ReadModuleFromBinaryProtoFile(absl::string_view hlo_file) {
HloProto proto;
TF_RETURN_IF_ERROR(
tsl::ReadBinaryProto(tsl::Env::Default(), std::string(hlo_file), &proto));
return HloProtoToModule(proto.hlo_module());
}

absl::StatusOr<std::unique_ptr<HloModule>>
FunctionalHloRunner::ReadModuleFromTextProtoFile(absl::string_view hlo_file) {
HloProto proto;
TF_RETURN_IF_ERROR(
tsl::ReadTextProto(tsl::Env::Default(), std::string(hlo_file), &proto));
return HloProtoToModule(proto.hlo_module());
return HloRunnerInterface::CreateModuleFromProto(proto.hlo_module());
}

absl::StatusOr<FunctionalHloRunner::HloModuleAndArguments>
Expand All @@ -624,8 +593,9 @@ FunctionalHloRunner::ReadModuleFromSnapshotBinaryProtoFile(
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.arguments.front()[i],
Literal::CreateFromProto(proto.arguments()[i]));
}
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module,
HloProtoToModule(proto.hlo().hlo_module()));
TF_ASSIGN_OR_RETURN(
hlo_module_and_arguments.hlo_module,
HloRunnerInterface::CreateModuleFromProto(proto.hlo().hlo_module()));
return hlo_module_and_arguments;
}

Expand All @@ -636,8 +606,9 @@ FunctionalHloRunner::ReadModuleFromUnoptimizedSnapshotBinaryProtoFile(
HloModuleAndArguments hlo_module_and_arguments;
TF_RETURN_IF_ERROR(
tsl::ReadBinaryProto(tsl::Env::Default(), std::string(hlo_file), &proto));
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module,
HloProtoToModule(proto.hlo_module()));
TF_ASSIGN_OR_RETURN(
hlo_module_and_arguments.hlo_module,
HloRunnerInterface::CreateModuleFromProto(proto.hlo_module()));

for (const auto& arguments : proto.partitions()) {
hlo_module_and_arguments.arguments.emplace_back();
Expand All @@ -659,8 +630,9 @@ FunctionalHloRunner::ReadModuleFromUnoptimizedSnapshotTextProtoFile(
HloModuleAndArguments hlo_module_and_arguments;
TF_RETURN_IF_ERROR(
tsl::ReadTextProto(tsl::Env::Default(), std::string(hlo_file), &proto));
TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module,
HloProtoToModule(proto.hlo_module()));
TF_ASSIGN_OR_RETURN(
hlo_module_and_arguments.hlo_module,
HloRunnerInterface::CreateModuleFromProto(proto.hlo_module()));

for (const auto& arguments : proto.partitions()) {
hlo_module_and_arguments.arguments.emplace_back();
Expand All @@ -675,16 +647,6 @@ FunctionalHloRunner::ReadModuleFromUnoptimizedSnapshotTextProtoFile(
return hlo_module_and_arguments;
}

absl::StatusOr<std::unique_ptr<HloModule>>
FunctionalHloRunner::ReadModuleFromString(absl::string_view hlo_text) {
return HloTextToModule(hlo_text);
}

absl::StatusOr<std::unique_ptr<HloModule>>
FunctionalHloRunner::ReadModuleFromProto(const HloModuleProto& proto) {
return HloProtoToModule(proto);
}

absl::StatusOr<FunctionalHloRunner::PerDeviceLiteralVecType>
FunctionalHloRunner::CompileAndRun(PjRtClient& client,
const DebugOptions& debug_options,
Expand Down
10 changes: 0 additions & 10 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,6 @@ class FunctionalHloRunner {
const RunningOptions& running_options,
std::minstd_rand0* engine = nullptr);

static absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
absl::string_view hlo_file);
static absl::StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromBinaryProtoFile(absl::string_view hlo_file);
static absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
absl::string_view hlo_file);

Expand All @@ -421,12 +417,6 @@ class FunctionalHloRunner {
static absl::StatusOr<HloModuleAndArguments> LoadHloModuleAndArguments(
absl::string_view hlo_file, InputFormat input_format);

static absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromString(
absl::string_view hlo_text);

static absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromProto(
const HloModuleProto& proto);

// This would ideally be private, but we need it for the implementation of
// MultihostHloRunner.
static absl::Status PrepareHloModuleForCompilation(
Expand Down
Loading