From ac1e4f27e3ec4f73a81c204dc52ec4ec169f3b7c Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Wed, 12 Feb 2025 11:00:41 -0800 Subject: [PATCH] Add vlogging to aid debugging PiperOrigin-RevId: 726119979 --- xla/service/gpu/BUILD | 1 + xla/service/gpu/gpu_p2p_pipeliner.cc | 47 +++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 3e4959fb70a84..bf57cf0d196cb 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2221,6 +2221,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/xla/service/gpu/gpu_p2p_pipeliner.cc b/xla/service/gpu/gpu_p2p_pipeliner.cc index 74006627f87f3..c14e1c354934e 100644 --- a/xla/service/gpu/gpu_p2p_pipeliner.cc +++ b/xla/service/gpu/gpu_p2p_pipeliner.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" -#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/collective_conflict_analysis.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_pipeliner.h" @@ -220,21 +220,36 @@ absl::Status PostprocessRotatedP2P(HloInstruction* instr) { // conflicting collectives. static absl::Status PostProcessRotatedSendRecvOps( std::vector& rotated_send_recvs) { + VLOG(5) << "Post-processing rotated send/recv ops:"; + if (VLOG_IS_ON(5)) { + for (HloInstruction* instr : rotated_send_recvs) { + VLOG(5) << " - " << instr->ToShortString(); + } + } + // Convert to set for faster lookup. absl::flat_hash_set rotated_send_recvs_set( rotated_send_recvs.begin(), rotated_send_recvs.end()); // Add control dependencies from conflicting collectives to rotated send/recv // ops. - for (HloInstruction* instr : rotated_send_recvs) { - CHECK(instr->opcode() == HloOpcode::kRecv || - instr->opcode() == HloOpcode::kSend); - HloComputation* parent = instr->parent(); + for (HloInstruction* rotated_instr : rotated_send_recvs) { + VLOG(5) << "Working on " << rotated_instr->ToShortString(); + CHECK(rotated_instr->opcode() == HloOpcode::kRecv || + rotated_instr->opcode() == HloOpcode::kSend); + HloComputation* parent = rotated_instr->parent(); + int64_t num_conflicting_collectives = 0; for (HloInstruction* conflicting_collective : - FindAllConflictingCollectives(parent, {instr})) { + FindAllConflictingCollectives(parent, {rotated_instr})) { if (rotated_send_recvs_set.contains(conflicting_collective)) continue; - TF_RETURN_IF_ERROR(conflicting_collective->AddControlDependencyTo(instr)); + num_conflicting_collectives++; + TF_RETURN_IF_ERROR( + conflicting_collective->AddControlDependencyTo(rotated_instr)); + VLOG(5) << "Adding control dependency from " + << conflicting_collective->ToShortString() << " to " + << rotated_instr->ToShortString(); } + VLOG(5) << "Conflicting collectives: " << num_conflicting_collectives; } return absl::OkStatus(); @@ -286,6 +301,13 @@ static absl::Status AddControlDependencies( static absl::Status PostProcessPeeledSendRecvOps( std::vector& peeled_send_recvs) { + VLOG(5) << "Post-processing peeled send/recv ops:"; + if (VLOG_IS_ON(5)) { + for (HloInstruction* instr : peeled_send_recvs) { + VLOG(5) << " - " << instr->ToShortString(); + } + } + // Convert to set for faster lookup. absl::flat_hash_set peeled_send_recvs_set; peeled_send_recvs_set.insert(peeled_send_recvs.begin(), @@ -294,6 +316,7 @@ static absl::Status PostProcessPeeledSendRecvOps( // Add control dependencies between conflicting collectives and peeled // send/recv ops. for (HloInstruction* peeled_instr : peeled_send_recvs) { + VLOG(5) << "Working on " << peeled_instr->ToShortString(); CHECK(peeled_instr->opcode() == HloOpcode::kRecv || peeled_instr->opcode() == HloOpcode::kSend); @@ -303,6 +326,8 @@ static absl::Status PostProcessPeeledSendRecvOps( if (peeled_send_recvs_set.contains(instr)) continue; unpeeled_conflicting_collectives.insert(instr); } + VLOG(5) << "#Conflicting collectives: " + << unpeeled_conflicting_collectives.size(); // Find the while loop. CHECK_EQ(peeled_instr->user_count(), 1); @@ -311,6 +336,7 @@ static absl::Status PostProcessPeeledSendRecvOps( CHECK_EQ(tuple_op->user_count(), 1); HloInstruction* while_op = tuple_op->users().front(); CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(5) << "While loop: " << while_op->ToShortString(); // We separate unpeeled conflicting collectives into two categories: those // dominating the while loop (while loop has a data dependency on them), and @@ -318,6 +344,7 @@ static absl::Status PostProcessPeeledSendRecvOps( std::vector dominating_unpeeled_conflicting_collectives; for (HloInstruction* instr : while_op->parent()->MakeInstructionPostOrderFrom(*while_op)) { + VLOG(5) << " post order instr: " << instr->ToShortString() << "\n"; if (unpeeled_conflicting_collectives.contains(instr)) { dominating_unpeeled_conflicting_collectives.push_back(instr); unpeeled_conflicting_collectives.erase(instr); @@ -390,11 +417,17 @@ absl::StatusOr GpuP2PPipeliner::Run( TF_ASSIGN_OR_RETURN( bool changed, CollectivePipeliner(config).Run(module, execution_threads)); + VLOG(5) << "After pipelining, before post-processing:"; + XLA_VLOG_LINES(5, module->ToString()); + // Post-process rotated and peeled send/recv ops to add control dependencies // with conflicting collectives. TF_RETURN_IF_ERROR(PostProcessRotatedSendRecvOps(rotated_send_recvs)); TF_RETURN_IF_ERROR(PostProcessPeeledSendRecvOps(peeled_send_recvs)); + VLOG(5) << "After post-processing:"; + XLA_VLOG_LINES(5, module->ToString()); + return changed; }