Skip to content

Commit

Permalink
AdvancedMatchShapeCoveringDynamicIndexInstruction now simulates inde…
Browse files Browse the repository at this point in the history
…x updates.

PiperOrigin-RevId: 726165292
  • Loading branch information
Joshua Wang authored and Google-ML-Automation committed Feb 12, 2025
1 parent 648cf17 commit 803886b
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 77 deletions.
142 changes: 66 additions & 76 deletions xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/service/while_loop_unroller.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
Expand Down Expand Up @@ -580,6 +581,10 @@ std::optional<int64_t> MatchShapeCoveringDynamicIndexInstruction(

// TODO(b/393399049): Replace MatchShapeCoveringDynamicInstruction with this
// one.
// Compared to the MatchShapeCoveringDynamicInstruction() method above, this
// implementation determines whether the (single) dynamic dimension is fully
// coverd by simulating the loop and noting which indices have been covered at
// any point.
std::optional<int64_t> AdvancedMatchShapeCoveringDynamicIndexInstruction(
const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode,
const WhileLoopConfig& config) {
Expand All @@ -600,105 +605,90 @@ std::optional<int64_t> AdvancedMatchShapeCoveringDynamicIndexInstruction(
} else {
return std::nullopt;
}
const HloInstruction* operand = instr->operand(0);
if (input != nullptr && operand != input) {

if (input != nullptr && input != instr->operand(0)) {
VLOG(3) << "Input of dynamic index instruction is not the given operand.";
return std::nullopt;
}
input = instr->operand(0);
const Shape& input_shape = input->shape();

std::optional<int64_t> dynamic_index;
std::optional<Range> dynamic_index_range;
for (int64_t start_index = start_indices_offset;
start_index < instr->operand_count(); ++start_index) {
const HloInstruction* index = instr->operand(start_index);
// All constants must be zero in order to slice the entire shape.
if (Match(index, match::ConstantScalar())) {
std::optional<int64_t> offset =
LiteralUtil::LiteralAsScalarInt64(index->literal());
if (offset.has_value() && offset.value() != 0) {
VLOG(3) << "Constant index " << start_index << " is not zero.";
return std::nullopt;
}
continue;
}
const int64_t num_indices = slice_shape->dimensions_size();
CHECK_EQ(num_indices, input_shape.dimensions_size());
CHECK_EQ(num_indices, instr->operand_count() - start_indices_offset);

// Try to compute a Range for this interval based on the loop induction
// variable's Range.
std::optional<Range> index_range =
IdentifyRangeAsFunctionOfInductionVar(index, config);
if (index_range != std::nullopt && !index_range->IsSingleValue()) {
// In order to cover the whole shape only a single non-constant index is
// allowed.
if (dynamic_index != std::nullopt) {
VLOG(3) << "Multiple non-constant indices.";
return std::nullopt;
}
dynamic_index = start_index - start_indices_offset;
dynamic_index_range = index_range;
std::vector<int64_t> dynamic_indices;
for (int64_t index = 0; index < num_indices; ++index) {
int64_t start_index_offset = start_indices_offset + index;
const HloInstruction* start_index = instr->operand(start_index_offset);

if (!Match(start_index, match::ConstantScalar())) {
dynamic_indices.push_back(index);
continue;
}

VLOG(3) << "Index is neither constant nor a function of loop induction "
"var.";
return std::nullopt;
// This is a non-dynamic index. It must start at zero and have a slice
// size matching the input size.
if (!Match(start_index, match::ConstantScalar(0))) {
VLOG(3) << "Non-dynamic-index dimensions must start at zero; "
"nonzero at index "
<< index;
return std::nullopt;
}
if (slice_shape->dimensions(index) != input_shape.dimensions(index)) {
VLOG(3) << "The slice sizes must match the input shape on "
"non-dynamic-index dimensions; mismatch at index "
<< index;
return std::nullopt;
}
}

if (dynamic_index == std::nullopt) {
if (dynamic_indices.empty()) {
VLOG(3) << "No dynamic index found.";
return std::nullopt;
}

const ConstantValue& min_index_touched = dynamic_index_range->min();
const ConstantValue operand_first_index = ConstantValue::GetZero(
min_index_touched.GetBitwidth(), min_index_touched.IsSigned());
if (min_index_touched.gt(operand_first_index)) {
VLOG(3) << "The dynamic_index must cover index zero, but it begins at "
<< min_index_touched.ToString();
if (dynamic_indices.size() >= 2) {
VLOG(3) << "Too many dynamic indices; found " << dynamic_indices.size();
return std::nullopt;
}

const ConstantValue slice_size =
ConstantValue::Get(slice_shape->dimensions(dynamic_index.value()),
dynamic_index_range->max()->GetBitwidth(),
dynamic_index_range->max()->IsSigned());
const ConstantValue max_index_touched_plus_one =
dynamic_index_range->max()->add(slice_size);
const Shape& operand_shape = operand->shape();
const ConstantValue operand_last_index_plus_one =
ConstantValue::Get(operand_shape.dimensions(dynamic_index.value()),
dynamic_index_range->max()->GetBitwidth(),
dynamic_index_range->max()->IsSigned());
if (max_index_touched_plus_one.lt(operand_last_index_plus_one)) {
const ConstantValue constant_one =
ConstantValue::GetOne(dynamic_index_range->max()->GetBitwidth(),
dynamic_index_range->max()->IsSigned());
VLOG(3) << "The dynamic_index must cover index "
<< operand_last_index_plus_one.sub(constant_one).ToString()
<< " but the last value it takes on is "
<< dynamic_index_range->max()->ToString()
<< " and the slice size is " << slice_size.ToString()
<< " so it only reaches "
<< max_index_touched_plus_one.sub(constant_one).ToString();
std::optional<int64_t> dynamic_index = dynamic_indices[0];
std::optional<Range> dynamic_index_range =
IdentifyRangeAsFunctionOfInductionVar(
instr->operand(start_indices_offset + dynamic_indices[0]), config);
if (dynamic_index_range == std::nullopt ||
!dynamic_index_range->IsBounded() ||
!dynamic_index_range->IsStepKnown()) {
VLOG(3) << "Could not compute compact dynamic index range.";
return std::nullopt;
}

if (dynamic_index_range->step()->gt(slice_size)) {
VLOG(3) << "The dynamic_index has a step size of "
<< dynamic_index_range->step()->ToString()
<< " but the slice size is " << slice_size.ToString();
return std::nullopt;
const int64_t dimension_size = input_shape.dimensions(dynamic_index.value());
// We keep a boolean per possible index of the dynamic dimension, initially
// false.
std::vector<bool> indices_covered(dimension_size);
const int64_t slice_size = slice_shape->dimensions(dynamic_index.value());

// Here, we simulate the loop based on the xla::Range that we have computed
// to represent the input to the DS/DUS.
for (int64_t start_index_value = dynamic_index_range->min().GetSignedValue();
start_index_value <= dynamic_index_range->max()->GetSignedValue();
start_index_value += dynamic_index_range->step()->GetSignedValue()) {
// DS and DUS clamp start indices so that the entire region is in-bounds.
int64_t clamped_start_index_value = std::min(
std::max<int64_t>(start_index_value, 0), dimension_size - slice_size);
// The DS/DUS covers `slice_size` many indices.
for (int64_t index = clamped_start_index_value;
index < clamped_start_index_value + slice_size; ++index) {
indices_covered[index] = true;
}
}

CHECK_EQ(slice_shape->dimensions_size(), operand_shape.dimensions_size());
for (int64_t i = 0; i < slice_shape->dimensions_size(); ++i) {
if (i != dynamic_index &&
slice_shape->dimensions(i) != operand_shape.dimensions(i)) {
VLOG(3) << "The slice sizes must match the operand-shape on "
"non-dynamic-index dimensions.";
for (int index = 0; index < indices_covered.size(); ++index) {
if (!indices_covered[index]) {
VLOG(3) << "Index " << index << " was not covered.";
return std::nullopt;
}
}

return dynamic_index;
}

Expand Down
35 changes: 34 additions & 1 deletion xla/service/while_loop_unroller_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ TEST_F(WhileLoopUnrollerTest, GetUnrollableLoops) {
EXPECT_EQ(unrollable_loops.size(), 2);
}

TEST_F(WhileLoopUnrollerTest, UnrollMutipleLoops) {
TEST_F(WhileLoopUnrollerTest, UnrollMultipleLoops) {
std::string hlo_string = R"(
HloModule SimpleLoop
SimpleLoop.body {
Expand Down Expand Up @@ -1178,6 +1178,25 @@ TEST_F(WhileLoopUnrollerTest,
.has_value());
}

TEST_F(WhileLoopUnrollerTest, AdvancedMatchShapeCoveringDSClamp) {
// In this version of the test, our dimension of interest gets incremented by
// three at time to that it takes on values {0, 3}. The DS has slice size
// two. However, because the dimension size is only 4, the second write gets
// clamped to have start index 2 and all index values {0, 1, 2, 3} are
// retrieved by the DS.
auto module = MakeModuleWithDS(/*start=*/0, /*stop=*/6, /*step=*/3,
/*slice_size=*/2, /*dim_size=*/4);
HloInstruction* loop = module->entry_computation()->root_instruction();
auto config = WhileLoopUnroller::IsLoopUnrollable(loop);
EXPECT_TRUE(config.has_value());
HloComputation* body = module->GetComputationWithName("SimpleLoop.body");
HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2");
HloInstruction* instr = body->GetInstructionWithName("slice");
EXPECT_TRUE(AdvancedMatchShapeCoveringDynamicIndexInstruction(
instr, input, HloOpcode::kDynamicSlice, config.value())
.has_value());
}

TEST_F(WhileLoopUnrollerTest, AdvancedMatchShapeCoveringDUS) {
auto module = MakeModuleWithDUS(/*start=*/0, /*stop=*/3, /*step=*/1,
/*slice_size=*/1, /*dim_size=*/3);
Expand Down Expand Up @@ -1221,6 +1240,20 @@ TEST_F(WhileLoopUnrollerTest,
.has_value());
}

TEST_F(WhileLoopUnrollerTest, AdvancedMatchShapeCoveringDUSClamp) {
auto module = MakeModuleWithDUS(/*start=*/0, /*stop=*/6, /*step=*/3,
/*slice_size=*/2, /*dim_size=*/4);
HloInstruction* loop = module->entry_computation()->root_instruction();
auto config = WhileLoopUnroller::IsLoopUnrollable(loop);
EXPECT_TRUE(config.has_value());
HloComputation* body = module->GetComputationWithName("SimpleLoop.body");
HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2");
HloInstruction* instr = body->GetInstructionWithName("slice");
EXPECT_TRUE(AdvancedMatchShapeCoveringDynamicIndexInstruction(
instr, input, HloOpcode::kDynamicUpdateSlice, config.value())
.has_value());
}

// Unroller pass must remove all the DynamicGte custom-calls.
TEST_F(WhileLoopUnrollerTest, UnrollLoopWithDynamicGte) {
std::string hlo_string = R"(
Expand Down

0 comments on commit 803886b

Please sign in to comment.