Reenable pipeline parallelism test after fix
PiperOrigin-RevId: 726514634
frgossen authored and Google-ML-Automation committed Feb 13, 2025
1 parent db332ed commit d7a8923
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions xla/tests/collective_pipeline_parallelism_test.cc
@@ -1341,15 +1341,18 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
     after_all_bwd = token[] after-all()
     bwd_recv = (f32[16], u32[], token[]) recv(after_all_bwd),
-        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}, control-predecessors={fwd_send_done}
+        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}},
+        control-predecessors={fwd_send_done, fwd_send}
     bwd_recv_done = (f32[16], token[]) recv-done(bwd_recv),
         frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
     bwd_send = (f32[16], u32[], token[]) send(next_stage_slice, after_all_bwd),
-        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}, control-predecessors={bwd_recv_done}
+        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}},
+        control-predecessors={bwd_recv_done, bwd_recv}
     bwd_send_done = token[] send-done(bwd_send)
     fwd_recv = (f32[16], u32[], token[]) recv(after_all_fwd),
-        frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, control-predecessors={bwd_send_done}
+        frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}},
+        control-predecessors={bwd_send_done, bwd_send}
     fwd_recv_done = (f32[16], token[]) recv-done(fwd_recv),
         frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
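Both this hunk and the next make the same fix: each control-predecessors list now names the communication op's start instruction (fwd_send, bwd_recv, bwd_send) in addition to its -done counterpart, so the backward and forward transfers are serialized end to end. For readers unfamiliar with the attribute, here is a minimal sketch of how control-predecessors orders two instructions that share no data dependency; the module and instruction names are hypothetical, not taken from this test:

    // Minimal sketch, assuming only the HLO text syntax used in the diff
    // above. "c" consumes nothing produced by "b", yet
    // control-predecessors={b} still forces "b" to be scheduled first.
    constexpr char kControlDepSketch[] = R"(
    HloModule control_dep_sketch

    ENTRY main {
      a = f32[4] parameter(0)
      b = f32[4] add(a, a)
      c = f32[4] multiply(a, a), control-predecessors={b}
      ROOT r = f32[4] add(b, c)
    })";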
@@ -1375,15 +1378,16 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
     after_all_bwd = token[] after-all()
     bwd_recv = (f32[16], u32[], token[]) recv(after_all_bwd),
         frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
-    bwd_recv_done = (f32[16], token[]) recv-done(bwd_recv),
-        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
+    bwd_recv_done = (f32[16], token[]) recv-done(bwd_recv)
     bwd_send = (f32[16], u32[], token[]) send(input_slice, after_all_bwd),
-        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}, control-predecessors={bwd_recv_done}
+        frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}},
+        control-predecessors={bwd_recv_done, bwd_recv}
     bwd_send_done = token[] send-done(bwd_send)
     after_all_fwd = token[] after-all()
     fwd_recv = (f32[16], u32[], token[]) recv(after_all_fwd),
-        frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, control-predecessors={bwd_send_done}
+        frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}},
+        control-predecessors={bwd_send_done, bwd_send}
     fwd_recv_done = (f32[16], token[]) recv-done(fwd_recv),
         frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
@@ -1450,12 +1454,6 @@ XLA_TEST_P(CollectivePipelineParallelismTest,
 }
 )";
 
-  // TODO(b/393216980): Enable this test when bug is fixed.
-  if (xla_gpu_experimental_pipeline_parallelism_opt_level_ !=
-      DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) {
-    GTEST_SKIP();
-  }
-
   const int64_t kNumReplicas = 4;
   if (test_runner().device_count() < kNumReplicas) {
     GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
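The last hunk removes the temporary guard that skipped this test whenever the experimental pipeline-parallelism optimization was enabled; only the device-count guard remains. A sketch of the resulting guard pattern, assuming the fixture members named in the diff (the test body itself is elided):

    // Sketch only: assumes the CollectivePipelineParallelismTest fixture and
    // its test_runner() accessor, as named in the diff above.
    XLA_TEST_P(CollectivePipelineParallelismTest, GuardPatternSketch) {
      // Removed by this commit: skip unless the optimization was disabled.
      //   if (xla_gpu_experimental_pipeline_parallelism_opt_level_ !=
      //       DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) {
      //     GTEST_SKIP();
      //   }

      // Kept: skip on machines with fewer devices than replicas.
      const int64_t kNumReplicas = 4;
      if (test_runner().device_count() < kNumReplicas) {
        GTEST_SKIP() << "Test requires at least " << kNumReplicas
                     << " devices";
      }
      // ... build and run the module above with kNumReplicas replicas ...
    }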
