diff --git a/xla/service/hlo_cost_analysis.cc b/xla/service/hlo_cost_analysis.cc index 39cea851b106b..0cfd7c2b500ee 100644 --- a/xla/service/hlo_cost_analysis.cc +++ b/xla/service/hlo_cost_analysis.cc @@ -103,7 +103,7 @@ absl::Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { auto [it_ignored, inserted] = hlo_properties_.emplace(hlo, std::move(current_properties_)); current_properties_ = Properties(); - TF_RET_CHECK(inserted); + TF_RET_CHECK(inserted) << hlo->name() << " already exists in hlo_properties_"; return absl::OkStatus(); } diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index 4d1442dd174e3..d5f0eb6992fd6 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -107,6 +107,7 @@ xla_cc_test( "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", + "//xla/service/cost_modelling:op_cost", "//xla/service/heap_simulator", "//xla/service/heap_simulator:allocation_block", "//xla/tests:test_utils", @@ -184,6 +185,7 @@ cc_library( "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", + "//xla/service/cost_modelling:op_cost", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", @@ -351,8 +353,8 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service:call_graph", "//xla/service:hlo_buffer", - "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", + "//xla/service/cost_modelling:op_cost", "//xla/service/heap_simulator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", @@ -372,6 +374,7 @@ xla_cc_test( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", + "//xla/service/cost_modelling:op_cost", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", @@ -417,6 +420,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", + "//xla/service/cost_modelling:op_cost", "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", @@ -466,6 +470,7 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service:call_graph", "//xla/service:hlo_cost_analysis", + "//xla/service/cost_modelling:op_cost", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -532,6 +537,7 @@ xla_cc_test( "//xla/service:buffer_value", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", + "//xla/service/cost_modelling:op_cost", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", @@ -649,6 +655,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", + "//xla/service/cost_modelling:op_cost", "//xla/tests:hlo_test_base", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index 3e3559e17d76c..08a73ee273aaa 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -2211,16 +2211,14 @@ MsaAlgorithm::GetInefficientAllocationSites( if (!allocation->is_copy_like_allocation()) { const HloPosition& defining_position = allocation->defining_position(); - int64_t accessed = - options_.cost_analysis->base_costs().OutputBytesAccessed( - *defining_position.instruction, defining_position.index); + int64_t accessed = options_.cost_analysis->OutputBytesAccessed( + *defining_position.instruction, defining_position.index); VLOG(3) << " pos: " << defining_position.ToString() << ", accessed: " << accessed << " / " << size; } for (const HloUse& use : allocation->uses()) { - int64_t accessed = - options_.cost_analysis->base_costs().OperandBytesAccessed( - *use.instruction, use.operand_number, use.operand_index); + int64_t accessed = options_.cost_analysis->OperandBytesAccessed( + *use.instruction, use.operand_number, use.operand_index); VLOG(3) << " use: " << use.ToString() << ", accessed: " << accessed << " / " << size; } @@ -2248,15 +2246,14 @@ MsaAlgorithm::GetInefficientAllocationSites( copy_bytes += size; } if (position_memory_space == MemorySpace::kAlternate) { - use_bytes += options_.cost_analysis->base_costs().OutputBytesAccessed( + use_bytes += options_.cost_analysis->OutputBytesAccessed( *allocation->defining_position().instruction, allocation->defining_position().index); } if (allocation->memory_space() == MemorySpace::kAlternate) { for (const HloUse& use : allocation->uses()) { - use_bytes += - options_.cost_analysis->base_costs().OperandBytesAccessed( - *use.instruction, use.operand_number, use.operand_index); + use_bytes += options_.cost_analysis->OperandBytesAccessed( + *use.instruction, use.operand_number, use.operand_index); } } } @@ -4569,10 +4566,10 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { << options_.cost_analysis->GetAlternateMemoryBenefit( request.use->hlo_use); VLOG(3) << "Definition bytes accessed = " - << options_.cost_analysis->base_costs().OutputBytesAccessed( + << options_.cost_analysis->OutputBytesAccessed( *defining_position.instruction, defining_position.index) << ", use bytes accessed = " - << options_.cost_analysis->base_costs().OperandBytesAccessed( + << options_.cost_analysis->OperandBytesAccessed( *use.instruction, use.operand_number, use.operand_index); } diff --git a/xla/service/memory_space_assignment/cost_analysis.cc b/xla/service/memory_space_assignment/cost_analysis.cc index f4927c8ecd11c..0b44c1c241a88 100644 --- a/xla/service/memory_space_assignment/cost_analysis.cc +++ b/xla/service/memory_space_assignment/cost_analysis.cc @@ -33,11 +33,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" -#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -45,40 +44,8 @@ limitations under the License. namespace xla { namespace memory_space_assignment { -HloCostAnalysisCosts::HloCostAnalysisCosts( - const HloCostAnalysis& hlo_cost_analysis) - : hlo_cost_analysis_(hlo_cost_analysis) {} - -float HloCostAnalysisCosts::BytesAccessed(const HloInstruction& instruction) { - return static_cast(hlo_cost_analysis_.bytes_accessed(instruction)); -} - -float HloCostAnalysisCosts::OperandBytesAccessed( - const HloInstruction& instruction, int64_t operand_num, - const ShapeIndex& shape_index) { - return static_cast(hlo_cost_analysis_.operand_bytes_accessed( - instruction, operand_num, shape_index)); -} - -float HloCostAnalysisCosts::OutputBytesAccessed( - const HloInstruction& instruction, const ShapeIndex& shape_index) { - return static_cast( - hlo_cost_analysis_.output_bytes_accessed(instruction, shape_index)); -} - -float HloCostAnalysisCosts::ComputeSeconds(const HloInstruction& instruction) { - return std::max( - std::max( - hlo_cost_analysis_.min_latency_seconds(HloCostAnalysis::kFlopsKey), - static_cast(hlo_cost_analysis_.flop_count(instruction)) / - hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey)), - static_cast(hlo_cost_analysis_.transcendental_count(instruction)) / - hlo_cost_analysis_.per_second_rate( - HloCostAnalysis::kTranscendentalsKey)); -} - /*static*/ absl::StatusOr> CostAnalysis::Create( - BaseCosts& base_costs, const CostAnalysisOptions& options, + OpCostManager& op_cost_manager, const CostAnalysisOptions& options, const HloModule& module) { TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); TF_ASSIGN_OR_RETURN(auto hlo_live_range, @@ -87,7 +54,7 @@ float HloCostAnalysisCosts::ComputeSeconds(const HloInstruction& instruction) { auto call_graph = CallGraph::Build(&module); // Using `new` to access a non-public constructor. return absl::WrapUnique( - new CostAnalysis(base_costs, options, std::move(alias_analysis), + new CostAnalysis(op_cost_manager, options, std::move(alias_analysis), std::move(hlo_live_range), std::move(call_graph))); } @@ -104,6 +71,18 @@ double CostAnalysis::DefaultMemBandwidthBytesPerSecond( return options_.default_mem_bandwidth_bytes_per_second; } +float CostAnalysis::OperandBytesAccessed(const HloInstruction& instruction, + int64_t operand_num, + const ShapeIndex& shape_index) const { + return op_cost_manager_.OperandBytesAccessed(instruction, operand_num, + shape_index); +} + +float CostAnalysis::OutputBytesAccessed(const HloInstruction& instruction, + const ShapeIndex& shape_index) const { + return op_cost_manager_.OutputBytesAccessed(instruction, shape_index); +} + float CostAnalysis::GetAlternateMemoryBenefit( const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, CostAnalysis::Cache* cache) const { @@ -298,7 +277,7 @@ float CostAnalysis::GetDefaultMemoryAccessOverhead( // = (window_size / bytes_accessed) * compute_elapsed const float window_size_bytes = options_.pipeline_overhead_window_size_mib * 1024 * 1024; - const float bytes_accessed = base_costs_.BytesAccessed(instruction); + const float bytes_accessed = op_cost_manager_.TotalBytesAccessed(instruction); const float default_memory_bytes_accessed = bytes_accessed - GetBytesAccessedFromAlternateMemory( @@ -318,7 +297,7 @@ float CostAnalysis::GetDefaultMemoryBandwidthIdleTime( absl::Span> operands_in_alternate_mem, absl::Span outputs_in_alternate_mem) const { const float default_memory_bytes_accessed = - base_costs_.BytesAccessed(instruction) - + op_cost_manager_.TotalBytesAccessed(instruction) - GetBytesAccessedFromAlternateMemory( instruction, operands_in_alternate_mem, outputs_in_alternate_mem); const float elapsed_due_to_default_mem = @@ -334,14 +313,14 @@ float CostAnalysis::GetBytesAccessedFromAlternateMemory( absl::Span outputs_in_alternate_mem) const { float bytes_accessed_from_alternate_mem = 0.0; for (auto& operand : operands_in_alternate_mem) { - const float operand_bytes_accessed = base_costs_.OperandBytesAccessed( + const float operand_bytes_accessed = op_cost_manager_.OperandBytesAccessed( instruction, operand.first, operand.second); bytes_accessed_from_alternate_mem += operand_bytes_accessed; } for (auto& shape_idx : outputs_in_alternate_mem) { const float output_bytes_accessed = - base_costs_.OutputBytesAccessed(instruction, shape_idx); + op_cost_manager_.OutputBytesAccessed(instruction, shape_idx); bytes_accessed_from_alternate_mem += output_bytes_accessed; } return bytes_accessed_from_alternate_mem; @@ -370,7 +349,7 @@ float CostAnalysis::GetInstructionElapsedDueToCompute( if (ExcludeInstructionFromElapsed(instruction)) { return 0.0f; } - return base_costs_.ComputeSeconds(instruction); + return op_cost_manager_.ComputeSeconds(instruction); } float CostAnalysis::GetInstructionElapsedDueToMemory( @@ -380,7 +359,7 @@ float CostAnalysis::GetInstructionElapsedDueToMemory( if (ExcludeInstructionFromElapsed(instruction)) { return 0.0f; } - float total_bytes_accessed = base_costs_.BytesAccessed(instruction); + float total_bytes_accessed = op_cost_manager_.TotalBytesAccessed(instruction); float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory( instruction, operands_in_alternate_mem, outputs_in_alternate_mem); float elapsed_due_to_alternate_mem = @@ -398,7 +377,7 @@ float CostAnalysis::GetInstructionElapsedDueToMemory( if (ExcludeInstructionFromElapsed(instruction)) { return 0.0f; } - float total_bytes_accessed = base_costs_.BytesAccessed(instruction); + float total_bytes_accessed = op_cost_manager_.TotalBytesAccessed(instruction); float bytes_accessed_from_alternate_mem = 0.0; for (int operand_num = 0; operand_num < instruction.operand_count(); ++operand_num) { @@ -410,8 +389,8 @@ float CostAnalysis::GetInstructionElapsedDueToMemory( } if (is_in_alternate_mem(operand_num, index, subshape)) { bytes_accessed_from_alternate_mem += - base_costs_.OperandBytesAccessed(instruction, operand_num, - index); + op_cost_manager_.OperandBytesAccessed(instruction, operand_num, + index); } }); } @@ -422,7 +401,7 @@ float CostAnalysis::GetInstructionElapsedDueToMemory( } if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) { bytes_accessed_from_alternate_mem += - base_costs_.OutputBytesAccessed(instruction, index); + op_cost_manager_.OutputBytesAccessed(instruction, index); } }); float elapsed_due_to_alternate_mem = diff --git a/xla/service/memory_space_assignment/cost_analysis.h b/xla/service/memory_space_assignment/cost_analysis.h index 23857dc08b14f..bb44b74886e8e 100644 --- a/xla/service/memory_space_assignment/cost_analysis.h +++ b/xla/service/memory_space_assignment/cost_analysis.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_COST_ANALYSIS_H_ -#include #include #include #include @@ -32,8 +31,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -71,51 +70,6 @@ struct CostAnalysisOptions { [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }; }; -// An interface for getting basic HLO costs. -class BaseCosts { - public: - virtual ~BaseCosts() = default; - - // The number of operand and output bytes accessed by instruction. - virtual float BytesAccessed(const HloInstruction& instruction) = 0; - - // The number of bytes accessed by instruction, for operand operand_num, at - // shape_index. - virtual float OperandBytesAccessed(const HloInstruction& instruction, - int64_t operand_num, - const ShapeIndex& shape_index) = 0; - - // The number of bytes accessed by instruction, in its output, at shape_index. - virtual float OutputBytesAccessed(const HloInstruction& instruction, - const ShapeIndex& shape_index) = 0; - - // The compute cost of instruction. The compute cost assumes 0 memory transfer - // is required. - virtual float ComputeSeconds(const HloInstruction& instruction) = 0; - - protected: - BaseCosts() = default; -}; - -// An implementation of BaseCosts based on HloCostAnalysis. -class HloCostAnalysisCosts : public BaseCosts { - public: - explicit HloCostAnalysisCosts(const HloCostAnalysis& hlo_cost_analysis); - - ~HloCostAnalysisCosts() override = default; - - float BytesAccessed(const HloInstruction& instruction) override; - float OperandBytesAccessed(const HloInstruction& instruction, - int64_t operand_num, - const ShapeIndex& shape_index) override; - float OutputBytesAccessed(const HloInstruction& instruction, - const ShapeIndex& shape_index) override; - float ComputeSeconds(const HloInstruction& instruction) override; - - private: - const HloCostAnalysis& hlo_cost_analysis_; -}; - // A wrapper class around BaseCosts with additional knowledge about the // bandwidths of different memory spaces. class CostAnalysis { @@ -140,13 +94,18 @@ class CostAnalysis { virtual ~CostAnalysis() = default; static absl::StatusOr> Create( - BaseCosts& base_costs, const CostAnalysisOptions& options, + OpCostManager& op_cost_manager, const CostAnalysisOptions& options, const HloModule& module); - BaseCosts& base_costs() const { return base_costs_; } - int64_t GetShapeSizeBytes(const Shape& shape) const; + float OperandBytesAccessed(const HloInstruction& instruction, + int64_t operand_num, + const ShapeIndex& shape_index) const; + + float OutputBytesAccessed(const HloInstruction& instruction, + const ShapeIndex& shape_index) const; + double DefaultMemBandwidthBytesPerSecond( bool use_scaling_factor = false) const; @@ -278,18 +237,20 @@ class CostAnalysis { const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } protected: - CostAnalysis(BaseCosts& base_costs, const CostAnalysisOptions& options, + CostAnalysis(OpCostManager& op_cost_manager, + const CostAnalysisOptions& options, std::unique_ptr alias_analysis, std::unique_ptr hlo_live_range, std::unique_ptr call_graph) - : base_costs_(base_costs), + : op_cost_manager_(op_cost_manager), options_(options), alias_analysis_(std::move(alias_analysis)), hlo_live_range_(std::move(hlo_live_range)), call_graph_(std::move(call_graph)) {} private: - BaseCosts& base_costs_; + // A manager responsible for return basic cost metrics. + OpCostManager& op_cost_manager_; const CostAnalysisOptions options_; std::unique_ptr alias_analysis_; std::unique_ptr hlo_live_range_; diff --git a/xla/service/memory_space_assignment/cost_analysis_test.cc b/xla/service/memory_space_assignment/cost_analysis_test.cc index 9995ff88d9c05..aebbe0689080b 100644 --- a/xla/service/memory_space_assignment/cost_analysis_test.cc +++ b/xla/service/memory_space_assignment/cost_analysis_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -51,21 +52,27 @@ class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { options.set_transcendentals_per_second(16); options.set_flops_min_latency_second(1); hlo_cost_analysis_ = std::make_unique(options); - TF_RETURN_IF_ERROR( - module->entry_computation()->Accept(hlo_cost_analysis_.get())); - hlo_cost_analysis_costs_ = - std::make_unique( - *hlo_cost_analysis_); + hlo_cost_analysis_wrapper_ = + std::make_unique(*hlo_cost_analysis_); + op_cost_manager_ = std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper_), + /*enable_cache=*/false)); TF_ASSIGN_OR_RETURN( cost_analysis_, - CostAnalysis::Create(*hlo_cost_analysis_costs_, options_, *module)); + CostAnalysis::Create(*op_cost_manager_, options_, *module)); return absl::OkStatus(); } CostAnalysisOptions options_; std::unique_ptr hlo_cost_analysis_; - std::unique_ptr - hlo_cost_analysis_costs_; + std::unique_ptr hlo_cost_analysis_wrapper_; + std::unique_ptr op_cost_manager_; std::unique_ptr cost_analysis_; }; diff --git a/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc index 82290fe5ed066..4b16904915a54 100644 --- a/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc +++ b/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -496,7 +496,7 @@ void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( // Keep track of bytes accessed by this value. if (loop_index || prev_iteration_index) { - float bytes_accessed = cost_analysis_.base_costs().OutputBytesAccessed( + float bytes_accessed = cost_analysis_.OutputBytesAccessed( *position.instruction, position.index); pos_bytes += bytes_accessed; VLOG(3) << " accessed: " << bytes_accessed; @@ -526,7 +526,7 @@ void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( // Keep track of bytes accessed by this value. if (loop_index || next_iteration_index) { - float bytes_accessed = cost_analysis_.base_costs().OperandBytesAccessed( + float bytes_accessed = cost_analysis_.OperandBytesAccessed( *use.instruction, use.operand_number, use.operand_index); use_bytes += bytes_accessed; VLOG(3) << " accessed: " << bytes_accessed; diff --git a/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index 656dedb9a5613..cc91243642189 100644 --- a/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" @@ -284,12 +285,19 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { options.set_bytes_per_second(32); options.set_transcendentals_per_second(16); hlo_cost_analysis_ = std::make_unique(options); - TF_RETURN_IF_ERROR( - module->entry_computation()->Accept(hlo_cost_analysis_.get())); - hlo_cost_analysis_costs_ = - std::make_unique(*hlo_cost_analysis_); + hlo_cost_analysis_wrapper_ = + std::make_unique(*hlo_cost_analysis_); + op_cost_manager_ = std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper_), + /*enable_cache=*/false)); TF_ASSIGN_OR_RETURN(cost_analysis_, - CostAnalysis::Create(*hlo_cost_analysis_costs_, + CostAnalysis::Create(*op_cost_manager_, cost_analysis_options_, *module)); TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); TF_ASSIGN_OR_RETURN(live_range_, @@ -683,7 +691,8 @@ ENTRY Entry { Options options_; CostAnalysisOptions cost_analysis_options_; std::unique_ptr hlo_cost_analysis_; - std::unique_ptr hlo_cost_analysis_costs_; + std::unique_ptr hlo_cost_analysis_wrapper_; + std::unique_ptr op_cost_manager_; std::unique_ptr cost_analysis_; std::unique_ptr alias_analysis_; std::unique_ptr live_range_; diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index fcc329a9c28b6..884b6048cfb19 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -59,6 +59,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal_util.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_buffer.h" @@ -11323,14 +11324,22 @@ ENTRY main { properties[HloCostAnalysis::kBytesAccessedKey] = kBytesPerSecond; HloCostAnalysis hlo_cost_analysis(HloCostAnalysis::DefaultShapeSize, properties); + HloCostAnalysisWithAcceptState hlo_cost_analysis_wrapper(hlo_cost_analysis); CostAnalysisOptions cost_analysis_options; cost_analysis_options.default_mem_bandwidth_bytes_per_second = kBytesPerSecond; - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); - TF_ASSERT_OK_AND_ASSIGN( - auto cost_analysis, - FakeCostAnalysis::Create(hlo_cost_analysis_costs, *module, - cost_analysis_options)); + OpCostManager op_cost_manager( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeCostAnalysis::Create(op_cost_manager, *module, + cost_analysis_options)); cost_analysis->SetOverrideForGetInstructionElapsed( [](const HloInstruction& instruction) -> float { return 10.0; }); cost_analysis->SetOverrideForGetAsyncCopyElapsed( diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test_base.h b/xla/service/memory_space_assignment/memory_space_assignment_test_base.h index 66515e8b4c875..8eb5deef631ad 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test_base.h +++ b/xla/service/memory_space_assignment/memory_space_assignment_test_base.h @@ -34,6 +34,7 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/instruction_hoister.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" @@ -144,6 +145,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { } HloCostAnalysis hlo_cost_analysis(hlo_cost_options); + HloCostAnalysisWithAcceptState hlo_cost_analysis_wrapper(hlo_cost_analysis); for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); } @@ -157,10 +159,18 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { if (cost_analysis_options_override) { cost_analysis_options = *cost_analysis_options_override; } - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); + OpCostManager op_cost_manager( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); - auto status_or_cost_analysis = CostAnalysis::Create( - hlo_cost_analysis_costs, cost_analysis_options, *module); + auto status_or_cost_analysis = + CostAnalysis::Create(op_cost_manager, cost_analysis_options, *module); TF_CHECK_OK(status_or_cost_analysis.status()); auto cost_analysis = std::move(status_or_cost_analysis.value()); diff --git a/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc b/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc index 9676903fb2272..ab023d22739a3 100644 --- a/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc +++ b/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/memory_space_assignment/prefetch_interval_picker.h" +#include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/cost_analysis.h" @@ -71,11 +73,22 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { ParseAndReturnVerifiedModule(hlo_string)); HloCostAnalysis hlo_cost_analysis; + std::unique_ptr hlo_cost_analysis_wrapper = + std::make_unique(hlo_cost_analysis); + std::unique_ptr op_cost_manager = + std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); CostAnalysisOptions options; - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( auto cost_analysis, - FakeCostAnalysis::Create(hlo_cost_analysis_costs, *module, options)); + FakeCostAnalysis::Create(*op_cost_manager, *module, options)); CostAnalysisPrefetchIntervalPicker interval_picker( *cost_analysis, /*min_overlap_to_async_copy_ratio=*/1.0, @@ -171,11 +184,22 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { ParseAndReturnVerifiedModule(hlo_string)); HloCostAnalysis hlo_cost_analysis; + std::unique_ptr hlo_cost_analysis_wrapper = + std::make_unique(hlo_cost_analysis); + std::unique_ptr op_cost_manager = + std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); CostAnalysisOptions options; - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( auto cost_analysis, - FakeCostAnalysis::Create(hlo_cost_analysis_costs, *module, options)); + FakeCostAnalysis::Create(*op_cost_manager, *module, options)); CostAnalysisPrefetchIntervalPicker interval_picker( *cost_analysis, /*min_overlap_to_async_copy_ratio=*/1.0, @@ -254,12 +278,23 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); + HloCostAnalysis hlo_cost_analysis; + std::unique_ptr hlo_cost_analysis_wrapper = + std::make_unique(hlo_cost_analysis); + std::unique_ptr op_cost_manager = + std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); TF_ASSERT_OK_AND_ASSIGN( auto cost_analysis, - FakeCostAnalysis::Create(hlo_cost_analysis_costs, *module, options)); + FakeCostAnalysis::Create(*op_cost_manager, *module, options)); CostAnalysisPrefetchIntervalPicker interval_picker( *cost_analysis, /*min_overlap_to_async_copy_ratio=*/1.0, @@ -323,12 +358,23 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); + HloCostAnalysis hlo_cost_analysis; + std::unique_ptr hlo_cost_analysis_wrapper = + std::make_unique(hlo_cost_analysis); + std::unique_ptr op_cost_manager = + std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); TF_ASSERT_OK_AND_ASSIGN( auto cost_analysis, - FakeCostAnalysis::Create(hlo_cost_analysis_costs, *module, options)); + FakeCostAnalysis::Create(*op_cost_manager, *module, options)); CostAnalysisPrefetchIntervalPicker interval_picker( *cost_analysis, /*min_overlap_to_async_copy_ratio=*/1.0, @@ -369,12 +415,23 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); + HloCostAnalysis hlo_cost_analysis; + std::unique_ptr hlo_cost_analysis_wrapper = + std::make_unique(hlo_cost_analysis); + std::unique_ptr op_cost_manager = + std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper), + /*enable_cache=*/false)); TF_ASSERT_OK_AND_ASSIGN( auto cost_analysis, - FakeCostAnalysis::Create(hlo_cost_analysis_costs, *module, options)); + FakeCostAnalysis::Create(*op_cost_manager, *module, options)); cost_analysis->SetOverrideForGetInstructionElapsed( [](const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kTanh) { diff --git a/xla/service/memory_space_assignment/simulator_test.cc b/xla/service/memory_space_assignment/simulator_test.cc index 1e5e97a850479..a3e64c63eb553 100644 --- a/xla/service/memory_space_assignment/simulator_test.cc +++ b/xla/service/memory_space_assignment/simulator_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" @@ -82,18 +83,24 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { // Assume 1 byte per second for testing. tpu_device_options.set_bytes_per_second(1); hlo_cost_analysis_ = std::make_unique(tpu_device_options); - TF_RETURN_IF_ERROR( - module_->entry_computation()->Accept(hlo_cost_analysis_.get())); - hlo_cost_analysis_costs_ = - std::make_unique( - *hlo_cost_analysis_); + hlo_cost_analysis_wrapper_ = + std::make_unique(*hlo_cost_analysis_); + op_cost_manager_ = std::make_unique( + OpCostManager::Options{ + /*enable_cache=*/false, + /*enable_analysis_logging=*/false, + }, + OpCostManager::CalculationNode::CreateLeaf( + "HloCostAnalysis", + CreateHloCostAnalysisCalculator(*hlo_cost_analysis_wrapper_), + /*enable_cache=*/false)); CostAnalysisOptions cost_analysis_options; // Assume 2 byte per second for testing. cost_analysis_options.alternate_mem_bandwidth_bytes_per_second = 2; cost_analysis_options.default_mem_bandwidth_bytes_per_second = 1.0; TF_ASSIGN_OR_RETURN(cost_analysis_, - CostAnalysis::Create(*hlo_cost_analysis_costs_, + CostAnalysis::Create(*op_cost_manager_, cost_analysis_options, *module_)); TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module_.get())); @@ -107,8 +114,8 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { absl::flat_hash_map instruction_map_; std::unique_ptr hlo_cost_analysis_; - std::unique_ptr - hlo_cost_analysis_costs_; + std::unique_ptr hlo_cost_analysis_wrapper_; + std::unique_ptr op_cost_manager_; std::unique_ptr cost_analysis_; std::unique_ptr alias_analysis_; std::unique_ptr hlo_live_range_; diff --git a/xla/service/memory_space_assignment/testing_utils.h b/xla/service/memory_space_assignment/testing_utils.h index c1bf0f5a8648b..d6cdde8953a53 100644 --- a/xla/service/memory_space_assignment/testing_utils.h +++ b/xla/service/memory_space_assignment/testing_utils.h @@ -28,6 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" +#include "xla/service/cost_modelling/op_cost.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/shape.h" @@ -42,7 +43,7 @@ namespace memory_space_assignment { class FakeCostAnalysis : public CostAnalysis { public: static absl::StatusOr> Create( - HloCostAnalysisCosts& cost_analysis_costs, const HloModule& module, + OpCostManager& op_cost_manager, const HloModule& module, const CostAnalysisOptions& options) { TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); TF_ASSIGN_OR_RETURN(auto hlo_live_range, @@ -50,7 +51,7 @@ class FakeCostAnalysis : public CostAnalysis { module.entry_computation())); auto call_graph = CallGraph::Build(&module); return absl::WrapUnique(new FakeCostAnalysis( - cost_analysis_costs, options, std::move(alias_analysis), + op_cost_manager, options, std::move(alias_analysis), std::move(hlo_live_range), std::move(call_graph))); } @@ -104,12 +105,12 @@ class FakeCostAnalysis : public CostAnalysis { } protected: - FakeCostAnalysis(HloCostAnalysisCosts& cost_analysis_costs, + FakeCostAnalysis(OpCostManager& op_cost_manager, const CostAnalysisOptions& options, std::unique_ptr alias_analysis, std::unique_ptr hlo_live_range, std::unique_ptr call_graph) - : CostAnalysis(cost_analysis_costs, options, std::move(alias_analysis), + : CostAnalysis(op_cost_manager, options, std::move(alias_analysis), std::move(hlo_live_range), std::move(call_graph)) {} private: