Skip to content

Commit

Permalink
Do not add control dependency from send to recv-done in decomposed co…
Browse files Browse the repository at this point in the history
…llective-permute

This control dependency is not needed in GPU.

PiperOrigin-RevId: 726643669
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 13, 2025
1 parent 405efd8 commit 5322b44
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 4 additions & 3 deletions xla/service/collective_permute_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/service/call_graph.h"
#include "xla/service/collective_conflict_analysis.h"
#include "xla/service/collective_ops_utils.h"
Expand All @@ -47,7 +46,6 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/xla_data.pb.h"

namespace xla {

Expand Down Expand Up @@ -186,7 +184,10 @@ static absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) {
TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send_done));
}
TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done));
if (pipeline_parallelism_opt_level ==
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) {
TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done));
}

if (!pipeline_decision.empty()) {
send->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
Expand Down
1 change: 1 addition & 0 deletions xla/service/collective_permute_decomposer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ TEST_F(DecomposerTest, OneSendRecvWithIndirectlyConflictingCollectives) {
ASSERT_THAT(cp_cycle2, NotNull());
ASSERT_THAT(cp_fwd_recv_done, NotNull());
ASSERT_THAT(cp_fwd_send_done, NotNull());
ASSERT_THAT(cp_fwd_recv_done->control_predecessors(), ElementsAre());
EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
ElementsAre(cp_fwd_recv_done));
EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));
Expand Down

0 comments on commit 5322b44

Please sign in to comment.