diff --git a/xla/service/collective_permute_decomposer.cc b/xla/service/collective_permute_decomposer.cc index c6a97b4441f27..229e7415e8398 100644 --- a/xla/service/collective_permute_decomposer.cc +++ b/xla/service/collective_permute_decomposer.cc @@ -36,7 +36,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/call_graph.h" #include "xla/service/collective_conflict_analysis.h" #include "xla/service/collective_ops_utils.h" @@ -47,7 +46,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/errors.h" -#include "xla/xla_data.pb.h" namespace xla { @@ -186,7 +184,10 @@ static absl::StatusOr DecomposeCollectivePermute( DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) { TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send_done)); } - TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); + if (pipeline_parallelism_opt_level == + DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) { + TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); + } if (!pipeline_decision.empty()) { send->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision); diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc index 9187c541d2b13..cf3e4b6deaf4c 100644 --- a/xla/service/collective_permute_decomposer_test.cc +++ b/xla/service/collective_permute_decomposer_test.cc @@ -1147,6 +1147,7 @@ TEST_F(DecomposerTest, OneSendRecvWithIndirectlyConflictingCollectives) { ASSERT_THAT(cp_cycle2, NotNull()); ASSERT_THAT(cp_fwd_recv_done, NotNull()); ASSERT_THAT(cp_fwd_send_done, NotNull()); + ASSERT_THAT(cp_fwd_recv_done->control_predecessors(), ElementsAre()); EXPECT_THAT(cp_fwd_send_done->control_predecessors(), ElementsAre(cp_fwd_recv_done)); EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));