diff --git a/xla/service/hlo_runner_interface.cc b/xla/service/hlo_runner_interface.cc index 78e94e4ad7b31..5f90ea9383a94 100644 --- a/xla/service/hlo_runner_interface.cc +++ b/xla/service/hlo_runner_interface.cc @@ -42,17 +42,6 @@ HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string, } namespace { - -// Creates an HloModule from the given proto. -absl::StatusOr> HloProtoToModule( - const HloProto& proto, const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN(HloModuleConfig config, - HloModule::CreateModuleConfigFromProto(proto.hlo_module(), - debug_options)); - TF_ASSIGN_OR_RETURN(auto module, - HloModule::CreateFromProto(proto.hlo_module(), config)); - return std::move(module); -} template std::vector MakePointerVector(absl::Span input_vec) { std::vector output_pointers; @@ -65,22 +54,31 @@ std::vector MakePointerVector(absl::Span input_vec) { } // namespace +absl::StatusOr> +HloRunnerInterface::CreateModuleFromProto(const HloModuleProto& proto, + const DebugOptions& debug_options) { + TF_ASSIGN_OR_RETURN( + HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto, debug_options)); + return HloModule::CreateFromProto(proto, config); +} + /*static*/ absl::StatusOr> HloRunnerInterface::ReadModuleFromBinaryProtoFile( - const std::string& filename, const DebugOptions& debug_options) { + absl::string_view filename, const DebugOptions& debug_options) { HloProto proto; TF_RETURN_IF_ERROR( - tsl::ReadBinaryProto(tsl::Env::Default(), filename, &proto)); - return HloProtoToModule(proto, debug_options); + tsl::ReadBinaryProto(tsl::Env::Default(), std::string(filename), &proto)); + return CreateModuleFromProto(proto.hlo_module(), debug_options); } /*static*/ absl::StatusOr> -HloRunnerInterface::ReadModuleFromHloTextFile(const std::string& filename, +HloRunnerInterface::ReadModuleFromHloTextFile(absl::string_view filename, const DebugOptions& debug_options, const HloParserOptions& options) { std::string hlo_string; - TF_RETURN_IF_ERROR( - tsl::ReadFileToString(tsl::Env::Default(), filename, &hlo_string)); + TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), + std::string(filename), &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); return ParseAndReturnUnverifiedModule(hlo_string, config, options); diff --git a/xla/service/hlo_runner_interface.h b/xla/service/hlo_runner_interface.h index ec6bae8e96cb5..f3139ee3b9c4d 100644 --- a/xla/service/hlo_runner_interface.h +++ b/xla/service/hlo_runner_interface.h @@ -214,11 +214,17 @@ class HloRunnerInterface { static absl::StatusOr> CreateModuleFromString( absl::string_view hlo_string, const DebugOptions& debug_options); + // Creates an HloModule from the given proto. + static absl::StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, + const DebugOptions& debug_options = DebugOptions::default_instance()); + // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. static absl::StatusOr> - ReadModuleFromBinaryProtoFile(const std::string& filename, - const DebugOptions& debug_options); + ReadModuleFromBinaryProtoFile( + absl::string_view filename, + const DebugOptions& debug_options = DebugOptions::default_instance()); // Reads the proto file in xla.HloModule format, creates and returns the // HloModule. @@ -229,7 +235,8 @@ class HloRunnerInterface { // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. static absl::StatusOr> ReadModuleFromHloTextFile( - const std::string& filename, const DebugOptions& debug_options, + absl::string_view filename, + const DebugOptions& debug_options = DebugOptions::default_instance(), const HloParserOptions& options = HloParserOptions()); // Creates a runner-internal executable object given an HLO module and returns diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index c2007bbbf69fa..a946b2cd1bfac 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -163,6 +163,7 @@ cc_library( "//xla/service:computation_placer_hdr", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "//xla/service:hlo_runner_interface", "//xla/tests:test_utils", "//xla/tools:hlo_control_flow_flattening", "//xla/tsl/platform:env", diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 3b1f0c17cfafe..55e8c1051b91f 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -62,6 +62,7 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_runner_interface.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tests/test_utils.h" @@ -81,23 +82,6 @@ limitations under the License. namespace xla { namespace { -// Creates an HloModule from the given proto. -absl::StatusOr> HloTextToModule( - absl::string_view hlo_text) { - return ParseAndReturnUnverifiedModule(hlo_text); -} - -// Creates an HloModule from the given proto. -absl::StatusOr> HloProtoToModule( - const HloModuleProto& proto) { - TF_ASSIGN_OR_RETURN( - HloModuleConfig config, - HloModule::CreateModuleConfigFromProto(proto, DebugOptions())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(proto, config)); - return std::move(module); -} - template void PopulateWithSameValue(Literal* literal, ElementType val) { for (ElementType& element : literal->data()) { @@ -449,16 +433,18 @@ FunctionalHloRunner::LoadHloModuleAndArguments(absl::string_view hlo_file, HloModuleAndArguments hlo_module_and_arguments; switch (input_format) { case InputFormat::kText: { - TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module, - ReadModuleFromHloTextFile(hlo_file)); + TF_ASSIGN_OR_RETURN( + hlo_module_and_arguments.hlo_module, + HloRunnerInterface::ReadModuleFromHloTextFile(hlo_file)); } break; case InputFormat::kProtoText: { TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module, ReadModuleFromTextProtoFile(hlo_file)); } break; case InputFormat::kProtoBinary: { - TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module, - ReadModuleFromBinaryProtoFile(hlo_file)); + TF_ASSIGN_OR_RETURN( + hlo_module_and_arguments.hlo_module, + HloRunnerInterface::ReadModuleFromBinaryProtoFile(hlo_file)); } break; case InputFormat::kSnapshotProtoBinary: { TF_ASSIGN_OR_RETURN(hlo_module_and_arguments, @@ -586,29 +572,12 @@ absl::Status FunctionalHloRunner::LoadAndCompile( return absl::OkStatus(); } -absl::StatusOr> -FunctionalHloRunner::ReadModuleFromHloTextFile(absl::string_view hlo_file) { - std::string hlo_string; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), - std::string(hlo_file), &hlo_string)); - return ParseAndReturnUnverifiedModule( - hlo_string, {}, HloParserOptions().set_keep_module_auto_layouts(true)); -} - -absl::StatusOr> -FunctionalHloRunner::ReadModuleFromBinaryProtoFile(absl::string_view hlo_file) { - HloProto proto; - TF_RETURN_IF_ERROR( - tsl::ReadBinaryProto(tsl::Env::Default(), std::string(hlo_file), &proto)); - return HloProtoToModule(proto.hlo_module()); -} - absl::StatusOr> FunctionalHloRunner::ReadModuleFromTextProtoFile(absl::string_view hlo_file) { HloProto proto; TF_RETURN_IF_ERROR( tsl::ReadTextProto(tsl::Env::Default(), std::string(hlo_file), &proto)); - return HloProtoToModule(proto.hlo_module()); + return HloRunnerInterface::CreateModuleFromProto(proto.hlo_module()); } absl::StatusOr @@ -624,8 +593,9 @@ FunctionalHloRunner::ReadModuleFromSnapshotBinaryProtoFile( TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.arguments.front()[i], Literal::CreateFromProto(proto.arguments()[i])); } - TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module, - HloProtoToModule(proto.hlo().hlo_module())); + TF_ASSIGN_OR_RETURN( + hlo_module_and_arguments.hlo_module, + HloRunnerInterface::CreateModuleFromProto(proto.hlo().hlo_module())); return hlo_module_and_arguments; } @@ -636,8 +606,9 @@ FunctionalHloRunner::ReadModuleFromUnoptimizedSnapshotBinaryProtoFile( HloModuleAndArguments hlo_module_and_arguments; TF_RETURN_IF_ERROR( tsl::ReadBinaryProto(tsl::Env::Default(), std::string(hlo_file), &proto)); - TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module, - HloProtoToModule(proto.hlo_module())); + TF_ASSIGN_OR_RETURN( + hlo_module_and_arguments.hlo_module, + HloRunnerInterface::CreateModuleFromProto(proto.hlo_module())); for (const auto& arguments : proto.partitions()) { hlo_module_and_arguments.arguments.emplace_back(); @@ -659,8 +630,9 @@ FunctionalHloRunner::ReadModuleFromUnoptimizedSnapshotTextProtoFile( HloModuleAndArguments hlo_module_and_arguments; TF_RETURN_IF_ERROR( tsl::ReadTextProto(tsl::Env::Default(), std::string(hlo_file), &proto)); - TF_ASSIGN_OR_RETURN(hlo_module_and_arguments.hlo_module, - HloProtoToModule(proto.hlo_module())); + TF_ASSIGN_OR_RETURN( + hlo_module_and_arguments.hlo_module, + HloRunnerInterface::CreateModuleFromProto(proto.hlo_module())); for (const auto& arguments : proto.partitions()) { hlo_module_and_arguments.arguments.emplace_back(); @@ -675,16 +647,6 @@ FunctionalHloRunner::ReadModuleFromUnoptimizedSnapshotTextProtoFile( return hlo_module_and_arguments; } -absl::StatusOr> -FunctionalHloRunner::ReadModuleFromString(absl::string_view hlo_text) { - return HloTextToModule(hlo_text); -} - -absl::StatusOr> -FunctionalHloRunner::ReadModuleFromProto(const HloModuleProto& proto) { - return HloProtoToModule(proto); -} - absl::StatusOr FunctionalHloRunner::CompileAndRun(PjRtClient& client, const DebugOptions& debug_options, diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 06699aa1363fe..c3eb33b45473a 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -404,10 +404,6 @@ class FunctionalHloRunner { const RunningOptions& running_options, std::minstd_rand0* engine = nullptr); - static absl::StatusOr> ReadModuleFromHloTextFile( - absl::string_view hlo_file); - static absl::StatusOr> - ReadModuleFromBinaryProtoFile(absl::string_view hlo_file); static absl::StatusOr> ReadModuleFromTextProtoFile( absl::string_view hlo_file); @@ -421,12 +417,6 @@ class FunctionalHloRunner { static absl::StatusOr LoadHloModuleAndArguments( absl::string_view hlo_file, InputFormat input_format); - static absl::StatusOr> ReadModuleFromString( - absl::string_view hlo_text); - - static absl::StatusOr> ReadModuleFromProto( - const HloModuleProto& proto); - // This would ideally be private, but we need it for the implementation of // MultihostHloRunner. static absl::Status PrepareHloModuleForCompilation(