Skip to content

Commit

Permalink
PR #22554: [GPU] Fix cuDNN fusion compiler support of control predece…
Browse files Browse the repository at this point in the history
…ssors.

Imported from GitHub PR #22554

Control predecessors have to be handled like in https://github.com/openxla/xla/blob/28887817aa29aef860211b131e4f6901ef590d4c/xla/service/gpu/transforms/fusion_wrapper.cc#L133-L137 to make the removal of the original instruction safe.
Copybara import of the project:

--
7a98896 by Ilia Sergachev <[email protected]>:

[GPU] Fix cuDNN fusion compiler support of control predecessors.

Merging this change closes #22554

COPYBARA_INTEGRATE_REVIEW=#22554 from openxla:fix_cudnn_fusion_compilation 7a98896
PiperOrigin-RevId: 726463516
  • Loading branch information
sergachev authored and Google-ML-Automation committed Feb 13, 2025
1 parent 7c8a422 commit 7251cc7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 9 additions & 2 deletions xla/backends/gpu/codegen/cudnn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,17 @@ fusion1 {
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
n {
p = f32[32,64] parameter(0)
n = f32[32,64] negate(p)
}
ENTRY e {
p0 = f32[32,96] parameter(0)
p1 = f32[96,64] parameter(1)
ROOT _ = f32[32,64] fusion(p0, p1), kind=kCustom, calls=fusion1,
f = f32[32,64] fusion(p0, p1), kind=kCustom, calls=fusion1,
backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
n = f32[32,64] fusion(f), kind=kLoop, calls=n, control-predecessors={f}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kHloText));
Expand All @@ -212,10 +218,11 @@ ENTRY e {
TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::GetTupleElement(m::Fusion())));
GmockMatch(m::Fusion(m::GetTupleElement(m::Fusion()))));
EXPECT_THAT(module->entry_computation()
->root_instruction()
->operand(0)
->operand(0)
->fused_instructions_computation()
->root_instruction(),
GmockMatch(m::Tuple(m::Dot(), m::CustomCall())));
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/transforms/cudnn_fusion_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ absl::StatusOr<HloInstruction*> AddWorkspace(HloInstruction& fusion,
computation->set_root_instruction(output_tuple, true);
HloInstruction* new_fusion = fusion.parent()->AddInstruction(
fusion.CloneWithNewShape(output_tuple->shape()));
TF_RETURN_IF_ERROR(new_fusion->CopyAllControlDepsFrom(&fusion));
TF_RETURN_IF_ERROR(fusion.DropAllControlDeps());
TF_RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(new_fusion, 0))));
TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion));
Expand Down

0 comments on commit 7251cc7

Please sign in to comment.