Skip to content

Commit

Permalink
[XLA:GPU] Move ConvertRangeVariablesToDimensions from scatter.cc to i…
Browse files Browse the repository at this point in the history
…ndexing_map.h.

PiperOrigin-RevId: 726001415
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Feb 12, 2025
1 parent 5268b93 commit 8f9f056
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 50 deletions.
51 changes: 1 addition & 50 deletions xla/backends/gpu/codegen/emitters/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,55 +213,6 @@ SmallVector<Value, 4> PadWithZeros(ValueRange values, int64_t size,
return padded_values;
}

// Creates a new indexing map that is the same as `map` but with the range
// variables at `range_var_indices` converted to the new dimensions variables at
// and added to the end of dimension variables list. Potentially, it can be
// moved to indexing_map.h.
IndexingMap ConvertRangeVariableToDimension(
const IndexingMap& map, ArrayRef<int64_t> range_var_indices) {
CHECK(std::is_sorted(range_var_indices.begin(), range_var_indices.end()));
auto* mlir_context = map.GetMLIRContext();

AffineMap affine_map = map.GetAffineMap();
// Update the affine map and the variables.
std::vector<IndexingMap::Variable> dims = map.GetDimVars();
std::vector<IndexingMap::Variable> range_vars;
std::vector<IndexingMap::Variable> rt_vars = map.GetRTVars();
SmallVector<AffineExpr, 4> symbol_replacements;
symbol_replacements.reserve(affine_map.getNumSymbols());
int64_t range_var_count = 0;
int64_t range_var_indices_count = range_var_indices.size();
for (int i = 0; i < affine_map.getNumSymbols(); ++i) {
auto range_var = map.GetRangeVar(i);
if (range_var_count < range_var_indices_count &&
i == range_var_indices[range_var_count]) {
symbol_replacements.push_back(
getAffineDimExpr(affine_map.getNumDims(), mlir_context));
dims.push_back(range_var);
range_var_count++;
} else {
symbol_replacements.push_back(
getAffineSymbolExpr(i - range_var_count, mlir_context));
range_vars.push_back(range_var);
}
}

AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols(
{}, symbol_replacements,
affine_map.getNumDims() + range_var_indices_count,
affine_map.getNumSymbols() - range_var_indices_count);

// Update the constraints.
std::vector<std::pair<AffineExpr, Interval>> constraints;
constraints.reserve(map.GetConstraintsCount());
for (auto constraint : map.GetConstraints()) {
constraints.push_back({constraint.first.replaceSymbols(symbol_replacements),
constraint.second});
}
return IndexingMap{converted_affine_map, std::move(dims),
std::move(range_vars), std::move(rt_vars), constraints};
}

} // namespace

class EmitterHelper {
Expand Down Expand Up @@ -738,7 +689,7 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(

// Convert index_id_loop and index_vector_id to dimension variables.
IndexingMap slice_indexing =
ConvertRangeVariableToDimension(updates_map, {0});
ConvertRangeVariablesToDimensions(updates_map, {0});

// Prepare loop initial values. Inits are packed as
// [index_changed, is_inbounds, index_0, ..., accumulator].
Expand Down
44 changes: 44 additions & 0 deletions xla/hlo/analysis/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1945,4 +1945,48 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const {
return new_indexing_map;
}

IndexingMap ConvertRangeVariablesToDimensions(
const IndexingMap& map, ArrayRef<int64_t> range_var_indices) {
CHECK(std::is_sorted(range_var_indices.begin(), range_var_indices.end()));
auto* mlir_context = map.GetMLIRContext();

AffineMap affine_map = map.GetAffineMap();
// Update the affine map and the variables.
std::vector<IndexingMap::Variable> dims = map.GetDimVars();
std::vector<IndexingMap::Variable> range_vars;
std::vector<IndexingMap::Variable> rt_vars = map.GetRTVars();
SmallVector<AffineExpr, 4> symbol_replacements;
symbol_replacements.reserve(affine_map.getNumSymbols());
int64_t range_var_count = 0;
int64_t range_var_indices_count = range_var_indices.size();
for (int i = 0; i < affine_map.getNumSymbols(); ++i) {
auto range_var = map.GetRangeVar(i);
if (range_var_count < range_var_indices_count &&
i == range_var_indices[range_var_count]) {
symbol_replacements.push_back(getAffineDimExpr(
affine_map.getNumDims() + range_var_count, mlir_context));
dims.push_back(range_var);
range_var_count++;
} else {
symbol_replacements.push_back(
getAffineSymbolExpr(i - range_var_count, mlir_context));
range_vars.push_back(range_var);
}
}
AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols(
{}, symbol_replacements,
affine_map.getNumDims() + range_var_indices_count,
affine_map.getNumSymbols() - range_var_indices_count);

// Update the constraints.
std::vector<std::pair<AffineExpr, Interval>> constraints;
constraints.reserve(map.GetConstraintsCount());
for (auto constraint : map.GetConstraints()) {
constraints.push_back({constraint.first.replaceSymbols(symbol_replacements),
constraint.second});
}
return IndexingMap{converted_affine_map, std::move(dims),
std::move(range_vars), std::move(rt_vars), constraints};
}

} // namespace xla
6 changes: 6 additions & 0 deletions xla/hlo/analysis/indexing_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,12 @@ std::vector<IndexingMap::Variable> DimVarsFromGPUGrid(
std::vector<IndexingMap::Variable> RangeVarsFromTensorSizes(
absl::Span<const int64_t> tensor_sizes);

// Creates a new indexing map that is the same as `map` but with the range
// variables at `range_var_indices` converted to the new dimensions variables at
// and added to the end of dimension variables list.
IndexingMap ConvertRangeVariablesToDimensions(
const IndexingMap& map, llvm::ArrayRef<int64_t> range_var_indices);

} // namespace xla

#endif // XLA_HLO_ANALYSIS_INDEXING_MAP_H_
24 changes: 24 additions & 0 deletions xla/hlo/analysis/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1669,5 +1669,29 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) {
IndexingMap::Variable{Interval{0, 5}}})});
}

TEST_F(IndexingMapTest, ConvertRangeVariablesToDimensions) {
IndexingMap indexing_map = Parse(R"(
(d0, d1)[to_convert_0, range, to_convert_1]
-> (d1, d0, range + to_convert_1, to_convert_0),
domain:
d0 in [0, 3],
d1 in [0, 3],
to_convert_0 in [0, 2],
range in [0, 1],
to_convert_1 in [0, 3]
)");
EXPECT_THAT(ConvertRangeVariablesToDimensions(indexing_map, {0, 2}),
MatchIndexingMap(R"(
(d0, d1, to_convert_0, to_convert_1)[range]
-> (d1, d0, to_convert_1 + range, to_convert_0),
domain:
d0 in [0, 3],
d1 in [0, 3],
to_convert_0 in [0, 2],
to_convert_1 in [0, 3],
range in [0, 1]
)"));
}

} // namespace
} // namespace xla

0 comments on commit 8f9f056

Please sign in to comment.