diff --git a/xla/backends/gpu/codegen/BUILD b/xla/backends/gpu/codegen/BUILD
index 149f6c97a6261..7b4df316f7a53 100644
--- a/xla/backends/gpu/codegen/BUILD
+++ b/xla/backends/gpu/codegen/BUILD
@@ -186,8 +186,8 @@ xla_test(
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_proto_cc",
         "//xla/service:hlo_runner_interface",
-        "//xla/service/gpu:backend_configs_cc",
         "//xla/service/gpu:gpu_executable",
+        "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:device_memory",
diff --git a/xla/backends/gpu/codegen/custom.cc b/xla/backends/gpu/codegen/custom.cc
index 5d2cb1ef9e000..3d365d8006754 100644
--- a/xla/backends/gpu/codegen/custom.cc
+++ b/xla/backends/gpu/codegen/custom.cc
@@ -192,11 +192,18 @@ absl::Status CollectSliceInfo(
   if (arg_slice_instr == nullptr) {
     return absl::OkStatus();
   }
+  std::optional<HloInstruction*> async_caller = std::nullopt;
+  if (fusion_instr.parent()->IsAsyncComputation()) {
+    async_caller = fusion_instr.parent()->AsyncStart();
+  }
   std::vector<DynamicSliceThunk::Offset> arg_offsets;
   for (auto idx_op : arg_slice_instr->index_operands()) {
     const auto* param = Cast<HloParameterInstruction>(idx_op);
-    const auto* offset_value = fusion_instr.operand(param->parameter_number());
+    const HloInstruction* offset_value =
+        async_caller.has_value()
+            ? (*async_caller)->operand(param->parameter_number())
+            : fusion_instr.operand(param->parameter_number());
 
     VLOG(2) << "Offset value:" << offset_value->ToString();
 
@@ -893,7 +900,9 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
                       fusion_instr.backend_config<GpuBackendConfig>());
   const std::string fusion_name =
       backend_config.fusion_backend_config().custom_fusion_config().name();
-  TF_RET_CHECK(isDynamic == (fusion_name == "dynamic_address_computation"))
+  TF_RET_CHECK(isDynamic ==
+               (fusion_name ==
+                kDynamicSliceFusionWithDynamicAddressComputationConfigName))
       << "Dynamic index operation found in a fusion instruction that is not "
          "labelled dynamic_address_computation";
 
@@ -961,13 +970,21 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
                        /*destination_value=*/nullptr});
     auto collective_start_thunk =
         std::make_unique<NcclReduceScatterStartThunk>(thunk_info, instr,
                                                       buffers);
-    auto collective_done_thunk = std::make_unique<NcclCollectiveDoneThunk>(
-        /*kind=*/collective_done_thunk_kind,
-        /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(instr),
-        /*async_events=*/collective_start_thunk->async_events(),
-        /*async_stream_kind=*/AsyncStreamKind::kCollective);
+    std::shared_ptr<NcclCollectiveThunk::AsyncEvents> async_events =
+        collective_start_thunk->async_events();
     seq.emplace_back(std::move(collective_start_thunk));
-    seq.emplace_back(std::move(collective_done_thunk));
+    // If the fusion is async, we do not emit the done thunk here; it is
+    // emitted when the enclosing async-done instruction is lowered.
+    if (fusion_instr.parent()->IsAsyncComputation()) {
+      ir_emitter_context.collectives_async_events().insert(
+          {fusion_instr.parent()->AsyncStart(), async_events});
+    } else {
+      auto collective_done_thunk = std::make_unique<NcclCollectiveDoneThunk>(
+          /*kind=*/collective_done_thunk_kind,
+          /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(instr),
+          /*async_events=*/async_events,
+          /*async_stream_kind=*/AsyncStreamKind::kCollective);
+      seq.emplace_back(std::move(collective_done_thunk));
+    }
   } else {
     return implementable_status;
   }
diff --git a/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc b/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
index e6309f38da668..f00427275bf7f 100644
--- a/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
+++ b/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
@@ -16,10 +16,10 @@ limitations under the License.
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
+#include 
 
 #include "absl/algorithm/container.h"
 #include "absl/status/status.h"
 #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
@@ -35,8 +35,8 @@ limitations under the License.
 #include "xla/hlo/testlib/filecheck.h"
 #include "xla/service/custom_call_target_registry.h"
 #include "xla/service/executable.h"
-#include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_executable.h"
+#include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_module_config.h"
@@ -83,6 +83,14 @@ namespace xla {
 namespace gpu {
 namespace {
 
+using ::testing::ElementsAre;
+using ::testing::Optional;
+using ::testing::VariantWith;
+
+MATCHER_P(ThunkKindIs, kind, "") {
+  return ExplainMatchResult(::testing::Eq(kind), arg->kind(), result_listener);
+}
+
 class DynamicSliceFusionTest : public HloTestBase {
  public:
   HloModuleConfig GetModuleConfigWithoutCommandBuffer() {
@@ -107,17 +115,8 @@ class DynamicSliceFusionTest : public HloTestBase {
       if (!computation->IsFusionComputation()) {
         continue;
       }
-      auto backend_config = computation->FusionInstruction()
-                                ->backend_config<GpuBackendConfig>();
-      if (backend_config.ok()) {
-        const FusionBackendConfig& fusion_backend_config =
-            backend_config.value().fusion_backend_config();
-        const std::string name =
-            fusion_backend_config.custom_fusion_config().name();
-        if (name == "dynamic_address_computation" ||
-            name == "address_computation") {
-          computations.push_back(computation);
-        }
+      if (IsDynamicSliceFusion(computation->FusionInstruction())) {
+        computations.push_back(computation);
       }
     }
     return computations;
@@ -3138,11 +3137,22 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterSlice) {
                           test_runner_as_hlo_runner().ExecutableFromWrapped(
                               wrapped_executable.get()));
   GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec);
-  ASSERT_EQ(gpu_exec->GetThunk().thunks().size(), 2ul);
-  auto& rs_start_thunk = gpu_exec->GetThunk().thunks()[0];
-  auto& rs_done_thunk = gpu_exec->GetThunk().thunks()[1];
-  ASSERT_EQ(rs_start_thunk->kind(), Thunk::kNcclReduceScatterStart);
-  ASSERT_EQ(rs_done_thunk->kind(), Thunk::kNcclReduceScatterDone);
+
+  // The pattern here is a static slice feeding a reduce-scatter operation.
+  // With this pattern the offset can be computed at compile time, so no
+  // dynamic slice thunk is needed to compute it at runtime. We therefore
+  // expect to see kNcclReduceScatterStart and kNcclReduceScatterDone thunks.
+  // We also expect surrounding kWaitForStreams thunks, because a dynamic
+  // slice fusion with a collective hero is converted into an async operation.
+  ASSERT_EQ(gpu_exec->GetThunk().thunks().size(), 4ul);
+  EXPECT_THAT(gpu_exec->GetThunk().thunks(),
+              ElementsAre(ThunkKindIs(Thunk::kWaitForStreams),
+                          ThunkKindIs(Thunk::kNcclReduceScatterStart),
+                          ThunkKindIs(Thunk::kNcclReduceScatterDone),
+                          ThunkKindIs(Thunk::kWaitForStreams)));
 
   ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
   EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
@@ -3279,6 +3289,87 @@ TEST_F(DynamicSliceFusionTest,
                                                 /*run_hlo_passes=*/false,
                                                 /*use_threads=*/true, error));
 }
 
+TEST_F(DynamicSliceFusionTest,
+       AsyncDynamicSliceFusionWithCollectiveOverlapsWithComputeThunk) {
+  const char* hlo = R"(
+    HloModule test-clone, replica_count=2
+
+    add {
+      x = s32[] parameter(0)
+      y = s32[] parameter(1)
+      ROOT add = s32[] add(x, y)
+    }
+
+    dynamic-slice-fusion {
+      p1 = s32[2,2,32]{2,1,0} parameter(1)
+      p0 = s32[8,32]{1,0} parameter(0)
+      slice = s32[4,32]{1,0} slice(p0), slice={[4:8], [0:32]}
+      rs = s32[2,32]{1,0} reduce-scatter(slice), replica_groups={{0,1}}, dimensions={0}, to_apply=add
+      bitcast = s32[1,2,32]{2,1,0} bitcast(rs)
+      p2 = s32[] parameter(2)
+      p3 = s32[] parameter(3)
+      ROOT dynamic-update-slice = s32[2,2,32]{2,1,0} dynamic-update-slice(p1, bitcast, p2, p3, p3)
+    }
+
+    ENTRY main {
+      source = s32[8,32]{1,0} parameter(1)
+      destination = s32[2,2,32]{2,1,0} parameter(0)
+      copy = s32[2,2,32]{2,1,0} copy(destination)
+      c1 = s32[] constant(1)
+      c0 = s32[] constant(0)
+      fusion-start = ((s32[8,32]{1,0}, s32[2,2,32]{2,1,0}, s32[], s32[]), s32[2,2,32]{2,1,0}, u32[]) fusion-start(source, copy, c1, c0), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
+      fusion-done = s32[2,2,32]{2,1,0} fusion-done(fusion-start), backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
+      a = s32[1024,1024]{1,0} parameter(2)
+      b = s32[1024,1024]{1,0} parameter(3)
+      dot = s32[1024,1024]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT tuple = (s32[2,2,32]{2,1,0}, s32[1024,1024]{1,0}) tuple(fusion-done, dot)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
+                          ParseAndReturnVerifiedModule(hlo));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<OpaqueExecutable> wrapped_exec,
+      CreateExecutable(hlo_module->Clone(), /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
+                          test_runner_as_hlo_runner().ExecutableFromWrapped(
+                              std::move(wrapped_exec)));
+  GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
+  const ThunkSequence& thunks = gpu_exec->GetThunk().thunks();
+
+  // This is only needed to ensure that the next checks don't fail.
+  ASSERT_EQ(thunks.size(), 6);
+
+  // In the following checks, only the order of the thunks matters.
+  EXPECT_THAT(thunks,
+              ElementsAre(ThunkKindIs(Thunk::kCopy),
+                          ThunkKindIs(Thunk::kWaitForStreams),
+                          ThunkKindIs(Thunk::kDynamicSlice),
+                          ThunkKindIs(Thunk::kKernel),
+                          ThunkKindIs(Thunk::kNcclReduceScatterDone),
+                          ThunkKindIs(Thunk::kWaitForStreams)));
+
+  // Check that the dynamic slice thunk only produces a start thunk, and not a
+  // done thunk.
+  DynamicSliceThunk* dynamic_slice_thunk =
+      dynamic_cast<DynamicSliceThunk*>(thunks[2].get());
+  ASSERT_NE(dynamic_slice_thunk, nullptr);
+  const SequentialThunk* embedded_thunk = dynamic_cast<const SequentialThunk*>(
+      dynamic_slice_thunk->embedded_thunk());
+  ASSERT_NE(embedded_thunk, nullptr);
+  EXPECT_THAT(embedded_thunk->thunks(),
+              ElementsAre(ThunkKindIs(Thunk::kNcclReduceScatterStart)));
+
+  // Check that the offsets were propagated as constants, and not as device
+  // allocated buffers.
+  auto offsets = dynamic_slice_thunk->get_offsets();
+  EXPECT_THAT(offsets,
+              ElementsAre(std::nullopt,
+                          Optional(ElementsAre(VariantWith<int64_t>(1),
+                                               VariantWith<int64_t>(0),
+                                               VariantWith<int64_t>(0)))));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/backends/gpu/codegen/fusions.cc b/xla/backends/gpu/codegen/fusions.cc
index 8a648dfe5696f..9d35434115736 100644
--- a/xla/backends/gpu/codegen/fusions.cc
+++ b/xla/backends/gpu/codegen/fusions.cc
@@ -89,8 +89,12 @@ std::unique_ptr<FusionInterface> GetFusionEmitter(
   switch (analysis.GetEmitterFusionKind()) {
     case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: {
-      const auto& config = backend_config.custom_fusion_config();
-      if (absl::StrContains(config.name(), "address_computation")) {
+      const absl::string_view config_name =
+          backend_config.custom_fusion_config().name();
+      if (config_name ==
+              kDynamicSliceFusionWithStaticAddressComputationConfigName ||
+          config_name ==
+              kDynamicSliceFusionWithDynamicAddressComputationConfigName) {
         return std::make_unique<DynamicSliceFusion>(analysis);
       }
       return std::make_unique<CustomFusion>();
diff --git a/xla/hlo/parser/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc
index 8163bf0019048..ef5180aefc5bb 100644
--- a/xla/hlo/parser/hlo_parser.cc
+++ b/xla/hlo/parser/hlo_parser.cc
@@ -1473,18 +1473,37 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
   }
   if (metadata) {
     instruction->set_metadata(*metadata);
+    if (instruction->IsAsynchronous()) {
+      instruction->async_wrapped_instruction()->set_metadata(*metadata);
+    }
   }
   if (original_value) {
     instruction->set_original_value(*original_value);
+    if (instruction->IsAsynchronous()) {
+      instruction->async_wrapped_instruction()->set_original_value(
+          *original_value);
+    }
   }
   if (backend_config) {
-    instruction->set_raw_backend_config_string(std::move(*backend_config));
+    instruction->set_raw_backend_config_string(*backend_config);
+    if (instruction->IsAsynchronous()) {
+      instruction->async_wrapped_instruction()->set_raw_backend_config_string(
+          *backend_config);
+    }
   }
   if (frontend_attributes) {
     instruction->set_frontend_attributes(*frontend_attributes);
+    if (instruction->IsAsynchronous()) {
+      instruction->async_wrapped_instruction()->set_frontend_attributes(
+          *frontend_attributes);
+    }
   }
   if (statistics_viz) {
     instruction->set_statistics_viz(*statistics_viz);
+    if (instruction->IsAsynchronous()) {
+      instruction->async_wrapped_instruction()->set_statistics_viz(
+          *statistics_viz);
+    }
   }
 
   return AddInstruction(name, instruction, name_loc);
diff --git a/xla/hlo/parser/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc
index c4a281ce737d9..750d223634253 100644
--- a/xla/hlo/parser/hlo_parser_test.cc
+++ b/xla/hlo/parser/hlo_parser_test.cc
@@ -5924,5 +5924,57 @@ ENTRY main {
                     "error: unexpected attribute \"result_accuracy\"");
 }
 
+TEST_F(HloParserTest,
+       AsyncInstructionWithAttributesShouldPropagateToWrappedInstruction) {
+  const char* hlo = R"(
+    HloModule test
+    ENTRY main {
+      a = s32[] parameter(0), origin={{"v1"}}
+      b = s32[] parameter(1), origin={{"v2"}}
+      add-start = ((s32[], s32[]), s32[], s32[]) add-start(a, b),
+        metadata={op_type="add" op_name="sample name" source_file="path/to/test.cc" source_line=68},
+        backend_config="foo\" bar",
+        frontend_attributes={attr_a="test_a",attr_b="b"},
+        statistics={visualizing_index=1,stat-1=33,stat-2=44}
+      ROOT add-done = s32[] add-done(add-start), origin={{"v3"}}
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> m,
+                          ParseAndReturnVerifiedModule(hlo));
+  // Check the wrapped instruction.
+  HloInstruction* wrapped_instr =
+      m->entry_computation()->root_instruction()->async_wrapped_instruction();
+  EXPECT_EQ(wrapped_instr->metadata().op_name(), "sample name");
+  EXPECT_EQ(wrapped_instr->metadata().op_type(), "add");
+  EXPECT_EQ(wrapped_instr->metadata().source_file(), "path/to/test.cc");
+  EXPECT_EQ(wrapped_instr->metadata().source_line(), 68);
+  EXPECT_EQ(wrapped_instr->raw_backend_config_string(), "foo\" bar");
+  EXPECT_EQ(wrapped_instr->frontend_attributes().map().size(), 2);
+  EXPECT_EQ(wrapped_instr->frontend_attributes().map().at("attr_a"), "test_a");
+  EXPECT_EQ(wrapped_instr->frontend_attributes().map().at("attr_b"), "b");
+  EXPECT_EQ(wrapped_instr->statistics_viz().stat_index_to_visualize(), 1);
+  EXPECT_EQ(wrapped_instr->statistics_viz().statistics_size(), 2);
+  EXPECT_EQ(wrapped_instr->statistics_viz().statistics(0).stat_name(),
+            "stat-1");
+  EXPECT_EQ(wrapped_instr->statistics_viz().statistics(0).stat_val(), 33);
+  EXPECT_EQ(wrapped_instr->statistics_viz().statistics(1).stat_name(),
+            "stat-2");
+  EXPECT_EQ(wrapped_instr->statistics_viz().statistics(1).stat_val(), 44);
+  EXPECT_EQ(OriginalValueToString(*wrapped_instr->original_value()),
+            "{\"v3\"}");
+  // Check the async-start and async-done instructions.
+  HloInstruction* async_done = m->entry_computation()->root_instruction();
+  HloInstruction* async_start = async_done->async_chain_start();
+  EXPECT_EQ(async_start->metadata().DebugString(),
+            wrapped_instr->metadata().DebugString());
+  EXPECT_EQ(async_start->raw_backend_config_string(),
+            wrapped_instr->raw_backend_config_string());
+  EXPECT_EQ(async_start->frontend_attributes().DebugString(),
+            wrapped_instr->frontend_attributes().DebugString());
+  EXPECT_EQ(async_start->statistics_viz().DebugString(),
+            wrapped_instr->statistics_viz().DebugString());
+  EXPECT_EQ(OriginalValueToString(*async_done->original_value()),
+            OriginalValueToString(*wrapped_instr->original_value()));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index 62141945e2f57..08fae4015c367 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -385,6 +385,7 @@ cc_library(
         "//xla/ffi:ffi_api",
         "//xla/ffi/api:c_api",
         "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_traversal",
         "//xla/mlir/utils:error_util",
         "//xla/mlir_hlo:transforms_gpu_passes",
         "//xla/service:buffer_assignment",
@@ -640,6 +641,7 @@ cc_library(
     hdrs = ["ir_emission_utils.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
+        ":backend_configs_cc",
        ":target_util",
        "//xla:literal",
        "//xla:shape_util",
@@ -1487,6 +1489,7 @@ cc_library(
         "//xla/hlo/translate/hlo_to_mhlo:hlo_utils",
         "//xla/hlo/translate/mhlo_to_hlo:location_exporter",
         "//xla/hlo/utils:hlo_query",
+        "//xla/hlo/utils:hlo_traversal",
         "//xla/pjrt/distributed:key_value_store_interface",
         "//xla/service/gpu/autotuning:autotuner_util",
         "//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner",
@@ -1562,6 +1565,7 @@ cc_library(
         "//xla/service:buffer_assignment",
         "//xla/service:buffer_value",
         "//xla/service:call_inliner",
+        "//xla/service:collective_ops_utils",
"//xla/service:collective_permute_decomposer", "//xla/service:collective_pipeliner", "//xla/service:collective_utils", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 04d3e43f9e9ea..10b59d1c4405f 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -122,6 +122,7 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" #include "xla/hlo/transforms/while_loop_trip_count_annotator.h" +#include "xla/hlo/utils/hlo_traversal.h" #include "xla/maybe_owning.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_reduce_reassociate.h" @@ -130,6 +131,7 @@ limitations under the License. #include "xla/service/batchnorm_expander.h" #include "xla/service/buffer_assignment.h" #include "xla/service/call_inliner.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/collective_permute_decomposer.h" #include "xla/service/collective_pipeliner.h" #include "xla/service/collective_utils.h" @@ -1248,6 +1250,17 @@ absl::Status RunDynamicSliceFusionPasses(HloModule* hlo_module, TF_ASSIGN_OR_RETURN(se::Platform * platform, se::PlatformManager::PlatformWithId(platform_id)); pipeline.AddPass(platform->Name()); + pipeline.AddPass([](const HloInstruction* instr) { + if (!IsDynamicSliceFusion(instr)) { + return false; + } + std::optional hero_op = HloBfsFindIf( + {instr->fused_instructions_computation()->root_instruction()}, + [](const HloInstruction* instr) -> bool { + return IsCollective(instr); + }); + return hero_op.has_value(); + }); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index e7616435bbf5b..a83adec1c415a 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -1790,6 +1790,72 @@ ROOT tmp_11 = f64[3,2]{1,0} reshape(tmp_10) })"); } +TEST_F(GpuCompilerTest, + DynamicSliceFusionWithCollectiveShouldWrapInAsyncAndTestE2E) { + const char* hlo = R"( + HloModule test, replica_count=2 + add { + x = s32[] parameter(0) + y = s32[] parameter(1) + ROOT add = s32[] add(x, y) + } + ENTRY main { + destination = s32[2,2,32] parameter(0) + c1 = s32[] constant(1) + c0 = s32[] constant(0) + c4 = s32[] constant(4) + source = s32[8,32] parameter(1) + a = s32[1024,1024] parameter(2) + b = s32[1024,1024] parameter(3) + slice = s32[4,32] slice(source), slice={[4:8], [0:32]} + rs = s32[2,32] reduce-scatter(slice), replica_groups={{0,1}}, dimensions={0}, to_apply=add + reshape = s32[1,2,32] reshape(rs) + dus = s32[2,2,32] dynamic-update-slice(destination, reshape, c1, c0, c0) + dot = s32[1024,1024] dot(a,b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT tuple = tuple(dus,dot) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + std::unique_ptr m_ref = m->Clone(); + m->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(true); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr wrapped_exec, + CreateExecutable(m->Clone(), /*run_hlo_passes=*/true)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr exec, + test_runner_as_hlo_runner().ExecutableFromWrapped( + std::move(wrapped_exec))); + const char* kExpected = R"( + // CHECK: dynamic-slice-fusion{{.+}} { + // CHECK: %[[slice:.+]] = {{.+}} slice({{.+}}), slice={[4:8], [0:32]} + // CHECK: %[[rs:.+]] = {{.+}} reduce-scatter(%[[slice]]), + // CHECK-SAME{LITERAL}: 
+    // CHECK:   %[[bitcast:.+]] = {{.+}} bitcast(%[[rs]])
+    // CHECK:   ROOT {{.+}} = {{.+}} dynamic-update-slice({{.+}}, %[[bitcast]], {{.+}})
+    // CHECK: ENTRY
+    // CHECK:   %[[fusion_start:.+]] = {{.+}} fusion-start({{.+}}), kind=kCustom, {{.+}}"name":"dynamic_address_computation"
+    // CHECK-NEXT: %[[wrapped_dot:.+]] = {{.+}} fusion({{.+}}), kind=kLoop
+    // CHECK-NEXT: %[[fusion_done:.+]] = {{.+}} fusion-done(%[[fusion_start]]), {{.+}}"name":"dynamic_address_computation"
+    // CHECK:   ROOT {{.+}} = {{.+}} tuple(%[[fusion_done]], %[[wrapped_dot]])
+  )";
+  EXPECT_THAT(
+      RunFileCheck(exec->module().ToString(HloPrintOptions{}
+                                               .set_print_operand_shape(false)
+                                               .set_print_metadata(false)),
+                   kExpected),
+      ::tsl::testing::IsOkAndHolds(true));
+
+  if (test_runner().device_count() < 2) {
+    GTEST_SKIP() << "Skipping test as it requires at least 2 devices.";
+  }
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(m), std::move(m_ref),
+                                                /*run_hlo_passes=*/true,
+                                                /*use_threads=*/true,
+                                                std::nullopt));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc
index 8fc966cbf79ba..2f0e5fa6db657 100644
--- a/xla/service/gpu/ir_emission_utils.cc
+++ b/xla/service/gpu/ir_emission_utils.cc
@@ -54,6 +54,7 @@ limitations under the License.
 #include "xla/literal.h"
 #include "xla/primitive_util.h"
 #include "xla/service/buffer_assignment.h"
+#include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/target_util.h"
 #include "xla/service/llvm_ir/llvm_type_conversion_util.h"
 #include "xla/service/llvm_ir/llvm_util.h"
@@ -638,5 +639,30 @@ absl::StatusOr<std::string> GetProtoFingerprint(
   return absl::WebSafeBase64Escape(result);
 }
 
+std::optional<std::string> GetCustomFusionConfigName(
+    const HloInstruction* instr) {
+  if (instr->opcode() != HloOpcode::kFusion ||
+      instr->fusion_kind() != HloInstruction::FusionKind::kCustom) {
+    return std::nullopt;
+  }
+  absl::StatusOr<GpuBackendConfig> backend_config =
+      instr->backend_config<GpuBackendConfig>();
+  if (!backend_config.ok() || !backend_config->has_fusion_backend_config()) {
+    return std::nullopt;
+  }
+  const FusionBackendConfig& fusion_backend_config =
+      backend_config->fusion_backend_config();
+  if (!fusion_backend_config.has_custom_fusion_config()) {
+    return std::nullopt;
+  }
+  return fusion_backend_config.custom_fusion_config().name();
+}
+
+bool IsDynamicSliceFusion(const HloInstruction* instr) {
+  std::optional<std::string> name = GetCustomFusionConfigName(instr);
+  return name == kDynamicSliceFusionWithStaticAddressComputationConfigName ||
+         name == kDynamicSliceFusionWithDynamicAddressComputationConfigName;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/service/gpu/ir_emission_utils.h b/xla/service/gpu/ir_emission_utils.h
index 3432cec278bee..3f65c32ab700e 100644
--- a/xla/service/gpu/ir_emission_utils.h
+++ b/xla/service/gpu/ir_emission_utils.h
@@ -88,6 +88,29 @@ inline constexpr absl::string_view kUncompilableFusion =
 
 inline constexpr absl::string_view kTopKCustomCallTarget = "__gpu$TopK";
 
+// The name of the custom fusion config for dynamic slice fusion with static
+// slices, such that the offset can be computed at compile time.
+inline constexpr absl::string_view
+    kDynamicSliceFusionWithStaticAddressComputationConfigName =
+        "address_computation";
+// The name of the custom fusion config for dynamic slice fusion with dynamic
+// slices, such that the offset is computed at runtime.
+inline constexpr absl::string_view
+    kDynamicSliceFusionWithDynamicAddressComputationConfigName =
+        "dynamic_address_computation";
+
+// Returns the name of the custom fusion config if the given instruction is a
+// custom fusion and has a custom fusion name, otherwise returns std::nullopt.
+// The custom fusion name is the value of
+// instr.backend_config().fusion_backend_config().custom_fusion_config().name().
+// If any link in that chain is missing, returns std::nullopt.
+std::optional<std::string> GetCustomFusionConfigName(
+    const HloInstruction* instr);
+
+// Returns true if the given instruction is a custom fusion for dynamic slice
+// fusion. This is determined by checking the name of the custom fusion
+// config.
+bool IsDynamicSliceFusion(const HloInstruction* instr);
+
 // Returns true if `hlo` will be implemented as a call to a cuSolver routine.
 //
 // This returns true if `hlo` is a CustomCall HLO with a call target equal to
diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
index b5b2ad5afccb0..5f8552b560408 100644
--- a/xla/service/gpu/ir_emitter_unnested.cc
+++ b/xla/service/gpu/ir_emitter_unnested.cc
@@ -80,6 +80,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/hlo/utils/hlo_traversal.h"
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
@@ -2574,6 +2575,19 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(
   return absl::OkStatus();
 }
 
+// If the fusion instruction is a dynamic-slice-fusion instruction with a
+// collective hero operation, returns that collective operation. Returns
+// std::nullopt otherwise.
+std::optional<const HloInstruction*> GetCollectiveHeroForDynamicSliceFusion(
+    const HloFusionInstruction* instruction) {
+  if (!IsDynamicSliceFusion(instruction)) {
+    return std::nullopt;
+  }
+  return HloBfsFindIf(
+      {instruction->fused_instructions_computation()->root_instruction()},
+      [](const HloInstruction* instr) { return IsCollective(instr); });
+}
+
 absl::Status IrEmitterUnnested::EmitHloInstruction(
     const HloInstruction* instr) {
   switch (instr->opcode()) {
@@ -2609,7 +2623,27 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
         return EmitNcclAsyncDone(Thunk::kNcclRaggedAllToAllDone, instr);
       case HloOpcode::kCollectiveBroadcast:
         return EmitNcclAsyncDone(Thunk::kNcclCollectiveBroadcastDone, instr);
-      case HloOpcode::kFusion:
+      case HloOpcode::kFusion: {
+        auto collective_hero = GetCollectiveHeroForDynamicSliceFusion(
+            Cast<HloFusionInstruction>(wrapped));
+        if (collective_hero.has_value()) {
+          switch ((*collective_hero)->opcode()) {
+            case HloOpcode::kReduceScatter:
+              TF_RETURN_IF_ERROR(
+                  EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, instr));
+              break;
+            default:
+              return absl::InternalError(
+                  absl::StrFormat("Unhandled collective in dynamic slice "
+                                  "fusion instruction: %s",
+                                  (*collective_hero)->ToString()));
+          }
+        }
+        // We still want to emit the stream done thunk.
+        [[clang::fallthrough]];
+      }
       case HloOpcode::kCall:
       case HloOpcode::kCustomCall: {
         // Wait until the concurrent stream has finished.
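
Taken together with the `EmitCollective` changes in custom.cc above, the emission of one collective is now split across the async pair: the `fusion-start` side emits the collective start thunk and parks its `AsyncEvents` in `collectives_async_events()` keyed by the `async-start` instruction, while the `fusion-done` side emits only the matching done thunk. A minimal sketch of that pairing, reusing names from this patch (`fusion` stands for the wrapped fusion instruction; the surrounding plumbing is elided):

    // Producer side (EmitCollective): for an async fusion, record the start
    // thunk's events instead of emitting a done thunk.
    if (fusion_instr.parent()->IsAsyncComputation()) {
      ir_emitter_context.collectives_async_events().insert(
          {fusion_instr.parent()->AsyncStart(), async_events});
    }

    // Consumer side (EmitHloInstruction, at the matching fusion-done): find
    // the collective hero and emit only the done thunk for its kind.
    if (auto hero = GetCollectiveHeroForDynamicSliceFusion(fusion)) {
      TF_RETURN_IF_ERROR(
          EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, instr));
    }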
diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD
index 9b45c3096f8e9..153555d4cd12c 100644
--- a/xla/service/gpu/transforms/BUILD
+++ b/xla/service/gpu/transforms/BUILD
@@ -694,6 +694,7 @@ xla_test(
         "//xla/hlo/testlib:filecheck",
         "//xla/hlo/testlib:verified_hlo_module",
         "//xla/service:executable",
+        "//xla/service:hlo_module_config",
         "//xla/service:hlo_runner_interface",
         "//xla/service/gpu:gpu_device_info_for_tests",
         "//xla/service/gpu:gpu_executable",
diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.cc b/xla/service/gpu/transforms/command_buffer_scheduling.cc
index f0f84e291e46f..ba06effb219cc 100644
--- a/xla/service/gpu/transforms/command_buffer_scheduling.cc
+++ b/xla/service/gpu/transforms/command_buffer_scheduling.cc
@@ -109,7 +109,17 @@ static bool IsAsyncStartCommand(const HloInstruction* hlo,
     return config.enabled_commands.contains(DebugOptions::CUBLAS);
   }
 
   if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) {
-    return config.enabled_commands.contains(DebugOptions::FUSION);
+    // We currently only support static address computations in command
+    // buffers.
+    if (IsDynamicSliceFusion(hlo->async_wrapped_instruction())) {
+      bool is_static_ds_fusion =
+          GetCustomFusionConfigName(hlo->async_wrapped_instruction()) ==
+          kDynamicSliceFusionWithStaticAddressComputationConfigName;
+      return is_static_ds_fusion && config.enabled_commands.contains(
+                                        DebugOptions::DYNAMIC_SLICE_FUSION);
+    } else {
+      return config.enabled_commands.contains(DebugOptions::FUSION);
+    }
   }
 
   if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter ||
       hlo->async_wrapped_opcode() == HloOpcode::kAllToAll) {
@@ -136,7 +146,17 @@ static bool IsAsyncDoneCommand(const HloInstruction* hlo,
     return config.enabled_commands.contains(DebugOptions::CUBLAS);
   }
 
   if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) {
-    return config.enabled_commands.contains(DebugOptions::FUSION);
+    // We currently only support static address computations in command
+    // buffers.
+    if (IsDynamicSliceFusion(hlo->async_wrapped_instruction())) {
+      bool is_static_ds_fusion =
+          GetCustomFusionConfigName(hlo->async_wrapped_instruction()) ==
+          kDynamicSliceFusionWithStaticAddressComputationConfigName;
+      return is_static_ds_fusion && config.enabled_commands.contains(
+                                        DebugOptions::DYNAMIC_SLICE_FUSION);
+    } else {
+      return config.enabled_commands.contains(DebugOptions::FUSION);
+    }
   }
 
   if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter ||
       hlo->async_wrapped_opcode() == HloOpcode::kAllToAll) {
@@ -241,9 +261,7 @@ static bool IsCommand(const HloInstruction* hlo,
     if (backend_config.kind() == kCuDnnFusionKind) {
       return config.enabled_commands.contains(DebugOptions::CUDNN);
     }
-    const auto& custom_config = backend_config.custom_fusion_config();
-    if ((custom_config.name() == "address_computation") ||
-        (custom_config.name() == "dynamic_address_computation")) {
+    if (IsDynamicSliceFusion(fusion)) {
       auto fusion_analysis =
           HloFusionAnalysis::Create(*hlo, config.device_description);
       const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
          });
       const HloInstruction* hero = &hero_adaptor->instruction();
-      if (custom_config.name() == "address_computation") {
+      const absl::string_view config_name =
+          backend_config.custom_fusion_config().name();
+      if (config_name ==
+          kDynamicSliceFusionWithStaticAddressComputationConfigName) {
         return IsCommand(hero, config) || IsAsyncStartCommand(hero, config);
       } else {
         // DynamicSliceFusionRewriter currently only rewrites for dynamic slice
@@ -380,7 +401,9 @@ CommandBufferScheduling::CollectCommandBufferSequences(
           const FusionBackendConfig& backend_config =
               gpu_config->fusion_backend_config();
           const auto& custom_config = backend_config.custom_fusion_config();
-          if (custom_config.name() != "dynamic_address_computation") return true;
+          if (custom_config.name() !=
+              kDynamicSliceFusionWithDynamicAddressComputationConfigName) {
+            return true;
+          }
 
           auto* fused_computation = fusion->called_computation();
           return !absl::c_any_of(
diff --git a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
index 6400e72b5e05c..3c60b7f6ec225 100644
--- a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
+++ b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "xla/service/executable.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner_interface.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -1119,83 +1120,6 @@ TEST_F(CommandBufferSchedulingTest, AsyncAlltoAll) { }); } -TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionDynamicSlicing) { - if (backend().platform()->Name() == "Host") { - GTEST_SKIP() << "GPU support required for this test"; - } - const char* hlo = R"( - HloModule jit_slice, replica_count=2 - - add { - a = s32[] parameter(0) - b = s32[] parameter(1) - ROOT add = add(a,b) - } - - ENTRY main.9 { - p0 = s32[2,8,32]{2,1,0} parameter(0) - p1 = s32[8,32]{1,0} parameter(1) - c0 = s32[] constant(0) - c1 = s32[] constant(1) - slice = s32[1,8,32]{2,1,0} dynamic-slice(p0, c1, c0, c0), dynamic_slice_sizes={1,8,32} - input = s32[8,32]{1,0} reshape(slice) - rs = s32[4,32] reduce-scatter(input), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add - ROOT dus = s32[8,32] dynamic-update-slice(p1, rs, c0, c0) - })"; - TF_ASSERT_OK_AND_ASSIGN(auto original_module, - ParseAndReturnVerifiedModule(hlo)); - DebugOptions& original_options = - original_module->mutable_config().mutable_debug_options(); - original_options.set_xla_gpu_enable_dynamic_slice_fusion(true); - - TF_ASSERT_OK_AND_ASSIGN(auto m, - GetOptimizedModule(std::move(original_module))); - - HloModuleConfig config(m->config()); - DebugOptions options(config.debug_options()); - options.set_xla_gpu_graph_min_graph_size(0); - - auto check = [&m, this](DebugOptions options) -> absl::Status { - auto m_clone = m->Clone(); - HloModuleConfig config(m_clone->config()); - config.set_debug_options(options); - m_clone->set_config(config); - TF_ASSIGN_OR_RETURN(std::unique_ptr wrapped_exec, - CreateExecutable(std::move(m_clone), false)); - TF_ASSIGN_OR_RETURN(std::unique_ptr exec, - test_runner_as_hlo_runner().ExecutableFromWrapped( - std::move(wrapped_exec))); - auto gpu_exec = std::unique_ptr( - static_cast(exec.release())); - TF_RET_CHECK(llvm::any_of(gpu_exec->GetThunk().thunks(), - [](const std::unique_ptr& thunk) { - return thunk->kind() == Thunk::kDynamicSlice; - })); - return absl::OkStatus(); - }; - - // With dynamic slicing, no matter what, there should be no command buffer. 
-  // Case 1: FUSION on, COLLECTIVES on
-  options.clear_xla_gpu_enable_command_buffer();
-  options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
-  options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
-  TF_ASSERT_OK(check(options));
-
-  // Case 2: FUSION off, COLLECTIVES off
-  options.clear_xla_gpu_enable_command_buffer();
-  TF_ASSERT_OK(check(options));
-
-  // Case 3: FUSION off, COLLECTIVES on
-  options.clear_xla_gpu_enable_command_buffer();
-  options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
-  TF_ASSERT_OK(check(options));
-
-  // Case 4: FUSION on, COLLECTIVES off
-  options.clear_xla_gpu_enable_command_buffer();
-  options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
-  TF_ASSERT_OK(check(options));
-}
-
 TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) {
   if (backend().platform()->Name() == "Host" || backend().device_count() < 2) {
     GTEST_SKIP() << "At least two GPUs required for this test";
   }
   const char* hlo = R"(
     HloModule jit_slice, replica_count=2
 
     add {
       a = s32[] parameter(0)
       b = s32[] parameter(1)
       ROOT add = add(a,b)
     }
 
     ENTRY main.9 {
       p0 = s32[2,8,32]{2,1,0} parameter(0)
       p1 = s32[8,32]{1,0} parameter(1)
+      a = s32[128,128] parameter(2)
+      b = s32[128,128] parameter(3)
       c0 = s32[] constant(0)
       c1 = s32[] constant(1)
       slice = s32[1,8,32]{2,1,0} slice(p0), slice={[1:2], [0:8], [0:32]}
       input = s32[8,32]{1,0} reshape(slice)
-      ROOT rs = s32[4,32] reduce-scatter(input), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+      rs = s32[4,32] reduce-scatter(input), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+      dot = s32[128,128] dot(a,b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT tuple = tuple(rs, dot)
     })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, GetOptimizedModule(hlo));
-
-  HloModuleConfig config(m->config());
-  DebugOptions options(config.debug_options());
-
+  HloModuleConfig config;
+  DebugOptions options;
+  options.set_xla_gpu_enable_dynamic_slice_fusion(true);
   options.set_xla_gpu_graph_min_graph_size(0);
+  config.set_debug_options(options);
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(hlo, config));
+  TF_ASSERT_OK_AND_ASSIGN(m, GetOptimizedModule(std::move(m)));
 
   auto get_exec = [&m, this](DebugOptions options)
       -> absl::StatusOr<std::unique_ptr<GpuExecutable>> {
-    auto m_clone = m->Clone();
-    HloModuleConfig config(m_clone->config());
-    config.set_debug_options(options);
-    m_clone->set_config(config);
+    std::unique_ptr<HloModule> m_clone = m->Clone();
+    m_clone->mutable_config().set_debug_options(options);
     TF_ASSIGN_OR_RETURN(std::unique_ptr<OpaqueExecutable> wrapped_exec,
                         CreateExecutable(std::move(m_clone), false));
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> exec,
@@ -1241,34 +1170,18 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) {
         static_cast<GpuExecutable*>(exec.release()));
   };
 
-  // FUSION on, COLLECTIVES on -> command buffer
+  // DYNAMIC_SLICE_FUSION on, FUSION on
   {
     options.clear_xla_gpu_enable_command_buffer();
+    options.add_xla_gpu_enable_command_buffer(
+        DebugOptions::DYNAMIC_SLICE_FUSION);
     options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
-    options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
     TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
     Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
     ASSERT_EQ(child->kind(), Thunk::kCommandBuffer);
   }
 
-  // FUSION off, COLLECTIVES off -> no command buffer because collective hero.
-  {
-    options.clear_xla_gpu_enable_command_buffer();
-    TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
-    Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
-    ASSERT_NE(child->kind(), Thunk::kCommandBuffer);
-  }
-
-  // FUSION off, COLLECTIVES on -> command buffer because static slices.
-  {
-    options.clear_xla_gpu_enable_command_buffer();
-    options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
-    TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
-    Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
-    ASSERT_EQ(child->kind(), Thunk::kCommandBuffer);
-  }
-
-  // FUSION on, COLLECTIVES off -> no command buffer because collective hero.
+  // DYNAMIC_SLICE_FUSION off, FUSION on
   {
     options.clear_xla_gpu_enable_command_buffer();
     options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
@@ -1279,12 +1192,12 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) {
 
   // Finally compare with/without command buffer.
   options.clear_xla_gpu_enable_command_buffer();
-  auto m_ref = m->Clone();
-  config.set_debug_options(options);
-  m_ref->set_config(config);
-
-  config.set_debug_options(GetDebugOptionsForTest());
-  m->set_config(config);
+  m->mutable_config().set_debug_options(options);
+  std::unique_ptr<HloModule> m_ref = m->Clone();
+  m->mutable_config().mutable_debug_options().add_xla_gpu_enable_command_buffer(
+      DebugOptions::DYNAMIC_SLICE_FUSION);
+  m->mutable_config().mutable_debug_options().add_xla_gpu_enable_command_buffer(
+      DebugOptions::FUSION);
   ASSERT_TRUE(RunAndCompareTwoModulesReplicated(std::move(m_ref), std::move(m),
                                                 false, true, std::nullopt));
 }
@@ -1336,5 +1249,68 @@ TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) {
 )");
 }
 
+TEST_F(CommandBufferSchedulingTest,
+       DynamicSliceFusionWithDynamicAddressesNotACommand) {
+  // Command buffer support for dynamic address computation is not implemented
+  // yet; once codegen supports it, this test can be removed.
+  if (backend().platform()->Name() == "Host") {
+    GTEST_SKIP() << "This test requires GPU.";
+  }
+  if (test_runner().device_count() < 2) {
+    GTEST_SKIP() << "Skipping test as it requires at least 2 devices.";
+  }
+  const char* hlo = R"(
+    HloModule test, replica_count=2
+    add {
+      x = s32[] parameter(0)
+      y = s32[] parameter(1)
+      ROOT add = s32[] add(x, y)
+    }
+    ENTRY main {
+      destination = s32[2,2,32] parameter(0)
+      c1 = s32[] constant(1)
+      c0 = s32[] constant(0)
+      c4 = s32[] constant(4)
+      source = s32[8,32] parameter(1)
+      a = s32[1024,1024] parameter(2)
+      b = s32[1024,1024] parameter(3)
+      slice = s32[4,32] slice(source), slice={[4:8], [0:32]}
+      rs = s32[2,32] reduce-scatter(slice), replica_groups={{0,1}}, dimensions={0}, to_apply=add
+      reshape = s32[1,2,32] reshape(rs)
+      dus = s32[2,2,32] dynamic-update-slice(destination, reshape, c1, c0, c0)
+      dot = s32[1024,1024] dot(a,b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT tuple = tuple(dus,dot)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> m,
+                          ParseAndReturnVerifiedModule(hlo));
+  auto m_ref = m->Clone();
+  m->mutable_config().mutable_debug_options().add_xla_gpu_enable_command_buffer(
+      DebugOptions::DYNAMIC_SLICE_FUSION);
+  m->mutable_config()
+      .mutable_debug_options()
+      .set_xla_gpu_enable_dynamic_slice_fusion(true);
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m_opt,
+                          GetOptimizedModule(m->Clone()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<OpaqueExecutable> wrapped_exec,
+      CreateExecutable(std::move(m_opt), /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
+                          test_runner_as_hlo_runner().ExecutableFromWrapped(
+                              std::move(wrapped_exec)));
+  HloInstruction* fusion_start =
+      FindInstruction(&exec->module(), HloOpcode::kAsyncStart);
+  HloInstruction* fusion_done =
+      FindInstruction(&exec->module(), HloOpcode::kAsyncDone);
+  ASSERT_NE(fusion_start, nullptr);
+  ASSERT_NE(fusion_done, nullptr);
+  EXPECT_EQ(fusion_start->parent(), exec->module().entry_computation());
+  EXPECT_EQ(fusion_done->parent(), exec->module().entry_computation());
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(m_ref), std::move(m),
+                                                /*run_hlo_passes=*/true,
+                                                /*use_threads=*/true,
+                                                /*error=*/std::nullopt));
+}
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
index fd36117c90f8d..41dc03d9954e3 100644
--- a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
+++ b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
@@ -506,8 +506,9 @@ absl::StatusOr<HloInstruction*> CreateFusionInstruction(
       *gpu_config.mutable_fusion_backend_config();
   backend_config.set_kind("__custom_fusion");
   CustomFusionConfig config;
-  config.set_name(dynamic ? "dynamic_address_computation"
-                          : "address_computation");
+  config.set_name(std::string(
+      dynamic ? kDynamicSliceFusionWithDynamicAddressComputationConfigName
+              : kDynamicSliceFusionWithStaticAddressComputationConfigName));
   *backend_config.mutable_custom_fusion_config() = config;
   TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config)));
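
The call sites above (fusions.cc, command_buffer_scheduling.cc, and the rewriter) now converge on the same classification idiom instead of comparing raw strings. A minimal sketch of that idiom, built only from the helpers this patch adds to ir_emission_utils.h; the helper name itself is hypothetical:

    // Hypothetical helper, not part of the patch: narrow a dynamic slice
    // fusion to the static variant, whose offsets are known at compile time.
    bool IsStaticDynamicSliceFusion(const HloInstruction* instr) {
      return IsDynamicSliceFusion(instr) &&
             GetCustomFusionConfigName(instr) ==
                 kDynamicSliceFusionWithStaticAddressComputationConfigName;
    }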