Skip to content

Commit

Permalink
[XLA] Googly changes
Browse files Browse the repository at this point in the history
Support fall through of handling custom calls when running dynamic dimension inference. We still want to handle the generic cases if we add handlers for specialized cases.

PiperOrigin-RevId: 726487576
  • Loading branch information
vsytch authored and Google-ML-Automation committed Feb 13, 2025
1 parent 5e7a670 commit 7254a48
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
6 changes: 4 additions & 2 deletions xla/service/dynamic_dimension_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,11 @@ absl::Status DynamicDimensionInferenceVisitor::HandleCustomCall(
return absl::OkStatus();
}

bool handled = false;
if (custom_call_handler_) {
TF_RETURN_IF_ERROR(custom_call_handler_(hlo, parent_));
} else {
handled = custom_call_handler_(hlo, parent_);
}
if (!handled) {
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo,
[&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
Expand Down
2 changes: 1 addition & 1 deletion xla/service/dynamic_dimension_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DynamicDimensionInference {
kIgnore,
};
using CustomCallInferenceHandler =
std::function<absl::Status(HloInstruction*, DynamicDimensionInference*)>;
std::function<bool(HloInstruction*, DynamicDimensionInference*)>;

// Generate an assertion which fails the execution if the instruction value is
// false.
Expand Down
2 changes: 1 addition & 1 deletion xla/service/dynamic_dimension_inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
CHECK(inference != nullptr);
CHECK(Cast<HloCustomCallInstruction>(hlo) != nullptr);
handler_called = true;
return absl::OkStatus();
return hlo->IsCustomCall("MyCustomOp");
};
TF_ASSERT_OK(RunInference(/*op_supports_dynamism_handler=*/nullptr, handler));

Expand Down
8 changes: 5 additions & 3 deletions xla/service/dynamic_padder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,25 @@ OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) {
return OpDynamismSupport::kNoSupport;
}

absl::Status CustomCallDynamicDimensionInference(
bool CustomCallDynamicDimensionInference(
HloInstruction* hlo, DynamicDimensionInference* inferencer) {
if (hlo->custom_call_target() == "OpWithDynamicLowering") {
if (hlo->shape().IsTuple()) {
// Use the operand's dynamic size as output dynamic size.
HloInstruction* dynamic_size =
inferencer->GetDynamicSize(hlo->mutable_operand(0), {1}, 0);
inferencer->SetDynamicSize(hlo, {1}, 0, dynamic_size);
return true;
} else {
// Use the operand's dynamic size as output dynamic size.
HloInstruction* dynamic_size =
inferencer->GetDynamicSize(hlo->mutable_operand(0), {}, 0);
inferencer->SetDynamicSize(hlo, {}, 0, dynamic_size);
return true;
}
}

return absl::OkStatus();
return false;
}

class DynamicPadderTest : public HloTestBase {
Expand Down Expand Up @@ -685,7 +687,7 @@ ENTRY main {
};
auto custom_call_handler = [](HloInstruction* hlo,
DynamicDimensionInference* inference) {
return absl::OkStatus();
return false;
};
TF_ASSERT_OK(
RunPadder(
Expand Down

0 comments on commit 7254a48

Please sign in to comment.