Skip to content

Commit

Permalink
Add vlogging to aid debugging
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 726119979
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 12, 2025
1 parent f995626 commit ac1e4f2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2221,6 +2221,7 @@ cc_library(
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down
47 changes: 40 additions & 7 deletions xla/service/gpu/gpu_p2p_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
Expand All @@ -32,7 +33,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/service/collective_conflict_analysis.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/collective_pipeliner.h"
Expand Down Expand Up @@ -220,21 +220,36 @@ absl::Status PostprocessRotatedP2P(HloInstruction* instr) {
// conflicting collectives.
static absl::Status PostProcessRotatedSendRecvOps(
std::vector<HloInstruction*>& rotated_send_recvs) {
VLOG(5) << "Post-processing rotated send/recv ops:";
if (VLOG_IS_ON(5)) {
for (HloInstruction* instr : rotated_send_recvs) {
VLOG(5) << " - " << instr->ToShortString();
}
}

// Convert to set for faster lookup.
absl::flat_hash_set<HloInstruction*> rotated_send_recvs_set(
rotated_send_recvs.begin(), rotated_send_recvs.end());

// Add control dependencies from conflicting collectives to rotated send/recv
// ops.
for (HloInstruction* instr : rotated_send_recvs) {
CHECK(instr->opcode() == HloOpcode::kRecv ||
instr->opcode() == HloOpcode::kSend);
HloComputation* parent = instr->parent();
for (HloInstruction* rotated_instr : rotated_send_recvs) {
VLOG(5) << "Working on " << rotated_instr->ToShortString();
CHECK(rotated_instr->opcode() == HloOpcode::kRecv ||
rotated_instr->opcode() == HloOpcode::kSend);
HloComputation* parent = rotated_instr->parent();
int64_t num_conflicting_collectives = 0;
for (HloInstruction* conflicting_collective :
FindAllConflictingCollectives(parent, {instr})) {
FindAllConflictingCollectives(parent, {rotated_instr})) {
if (rotated_send_recvs_set.contains(conflicting_collective)) continue;
TF_RETURN_IF_ERROR(conflicting_collective->AddControlDependencyTo(instr));
num_conflicting_collectives++;
TF_RETURN_IF_ERROR(
conflicting_collective->AddControlDependencyTo(rotated_instr));
VLOG(5) << "Adding control dependency from "
<< conflicting_collective->ToShortString() << " to "
<< rotated_instr->ToShortString();
}
VLOG(5) << "Conflicting collectives: " << num_conflicting_collectives;
}

return absl::OkStatus();
Expand Down Expand Up @@ -286,6 +301,13 @@ static absl::Status AddControlDependencies(

static absl::Status PostProcessPeeledSendRecvOps(
std::vector<HloInstruction*>& peeled_send_recvs) {
VLOG(5) << "Post-processing peeled send/recv ops:";
if (VLOG_IS_ON(5)) {
for (HloInstruction* instr : peeled_send_recvs) {
VLOG(5) << " - " << instr->ToShortString();
}
}

// Convert to set for faster lookup.
absl::flat_hash_set<HloInstruction*> peeled_send_recvs_set;
peeled_send_recvs_set.insert(peeled_send_recvs.begin(),
Expand All @@ -294,6 +316,7 @@ static absl::Status PostProcessPeeledSendRecvOps(
// Add control dependencies between conflicting collectives and peeled
// send/recv ops.
for (HloInstruction* peeled_instr : peeled_send_recvs) {
VLOG(5) << "Working on " << peeled_instr->ToShortString();
CHECK(peeled_instr->opcode() == HloOpcode::kRecv ||
peeled_instr->opcode() == HloOpcode::kSend);

Expand All @@ -303,6 +326,8 @@ static absl::Status PostProcessPeeledSendRecvOps(
if (peeled_send_recvs_set.contains(instr)) continue;
unpeeled_conflicting_collectives.insert(instr);
}
VLOG(5) << "#Conflicting collectives: "
<< unpeeled_conflicting_collectives.size();

// Find the while loop.
CHECK_EQ(peeled_instr->user_count(), 1);
Expand All @@ -311,13 +336,15 @@ static absl::Status PostProcessPeeledSendRecvOps(
CHECK_EQ(tuple_op->user_count(), 1);
HloInstruction* while_op = tuple_op->users().front();
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
VLOG(5) << "While loop: " << while_op->ToShortString();

// We separate unpeeled conflicting collectives into two categories: those
// dominating the while loop (while loop has a data dependency on them), and
// those that don't.
std::vector<HloInstruction*> dominating_unpeeled_conflicting_collectives;
for (HloInstruction* instr :
while_op->parent()->MakeInstructionPostOrderFrom(*while_op)) {
VLOG(5) << " post order instr: " << instr->ToShortString() << "\n";
if (unpeeled_conflicting_collectives.contains(instr)) {
dominating_unpeeled_conflicting_collectives.push_back(instr);
unpeeled_conflicting_collectives.erase(instr);
Expand Down Expand Up @@ -390,11 +417,17 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run(
TF_ASSIGN_OR_RETURN(
bool changed, CollectivePipeliner(config).Run(module, execution_threads));

VLOG(5) << "After pipelining, before post-processing:";
XLA_VLOG_LINES(5, module->ToString());

// Post-process rotated and peeled send/recv ops to add control dependencies
// with conflicting collectives.
TF_RETURN_IF_ERROR(PostProcessRotatedSendRecvOps(rotated_send_recvs));
TF_RETURN_IF_ERROR(PostProcessPeeledSendRecvOps(peeled_send_recvs));

VLOG(5) << "After post-processing:";
XLA_VLOG_LINES(5, module->ToString());

return changed;
}

Expand Down

0 comments on commit ac1e4f2

Please sign in to comment.