diff --git a/xla/backends/gpu/codegen/emitters/scatter.cc b/xla/backends/gpu/codegen/emitters/scatter.cc index bb105655ccc87..7adea7229bb7c 100644 --- a/xla/backends/gpu/codegen/emitters/scatter.cc +++ b/xla/backends/gpu/codegen/emitters/scatter.cc @@ -213,55 +213,6 @@ SmallVector PadWithZeros(ValueRange values, int64_t size, return padded_values; } -// Creates a new indexing map that is the same as `map` but with the range -// variables at `range_var_indices` converted to the new dimensions variables at -// and added to the end of dimension variables list. Potentially, it can be -// moved to indexing_map.h. -IndexingMap ConvertRangeVariableToDimension( - const IndexingMap& map, ArrayRef range_var_indices) { - CHECK(std::is_sorted(range_var_indices.begin(), range_var_indices.end())); - auto* mlir_context = map.GetMLIRContext(); - - AffineMap affine_map = map.GetAffineMap(); - // Update the affine map and the variables. - std::vector dims = map.GetDimVars(); - std::vector range_vars; - std::vector rt_vars = map.GetRTVars(); - SmallVector symbol_replacements; - symbol_replacements.reserve(affine_map.getNumSymbols()); - int64_t range_var_count = 0; - int64_t range_var_indices_count = range_var_indices.size(); - for (int i = 0; i < affine_map.getNumSymbols(); ++i) { - auto range_var = map.GetRangeVar(i); - if (range_var_count < range_var_indices_count && - i == range_var_indices[range_var_count]) { - symbol_replacements.push_back( - getAffineDimExpr(affine_map.getNumDims(), mlir_context)); - dims.push_back(range_var); - range_var_count++; - } else { - symbol_replacements.push_back( - getAffineSymbolExpr(i - range_var_count, mlir_context)); - range_vars.push_back(range_var); - } - } - - AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols( - {}, symbol_replacements, - affine_map.getNumDims() + range_var_indices_count, - affine_map.getNumSymbols() - range_var_indices_count); - - // Update the constraints. - std::vector> constraints; - constraints.reserve(map.GetConstraintsCount()); - for (auto constraint : map.GetConstraints()) { - constraints.push_back({constraint.first.replaceSymbols(symbol_replacements), - constraint.second}); - } - return IndexingMap{converted_affine_map, std::move(dims), - std::move(range_vars), std::move(rt_vars), constraints}; -} - } // namespace class EmitterHelper { @@ -738,7 +689,7 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( // Convert index_id_loop and index_vector_id to dimension variables. IndexingMap slice_indexing = - ConvertRangeVariableToDimension(updates_map, {0}); + ConvertRangeVariablesToDimensions(updates_map, {0}); // Prepare loop initial values. Inits are packed as // [index_changed, is_inbounds, index_0, ..., accumulator]. diff --git a/xla/hlo/analysis/indexing_map.cc b/xla/hlo/analysis/indexing_map.cc index dd18b1db6b41e..74819f6118022 100644 --- a/xla/hlo/analysis/indexing_map.cc +++ b/xla/hlo/analysis/indexing_map.cc @@ -1945,4 +1945,48 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const { return new_indexing_map; } +IndexingMap ConvertRangeVariablesToDimensions( + const IndexingMap& map, ArrayRef range_var_indices) { + CHECK(std::is_sorted(range_var_indices.begin(), range_var_indices.end())); + auto* mlir_context = map.GetMLIRContext(); + + AffineMap affine_map = map.GetAffineMap(); + // Update the affine map and the variables. + std::vector dims = map.GetDimVars(); + std::vector range_vars; + std::vector rt_vars = map.GetRTVars(); + SmallVector symbol_replacements; + symbol_replacements.reserve(affine_map.getNumSymbols()); + int64_t range_var_count = 0; + int64_t range_var_indices_count = range_var_indices.size(); + for (int i = 0; i < affine_map.getNumSymbols(); ++i) { + auto range_var = map.GetRangeVar(i); + if (range_var_count < range_var_indices_count && + i == range_var_indices[range_var_count]) { + symbol_replacements.push_back(getAffineDimExpr( + affine_map.getNumDims() + range_var_count, mlir_context)); + dims.push_back(range_var); + range_var_count++; + } else { + symbol_replacements.push_back( + getAffineSymbolExpr(i - range_var_count, mlir_context)); + range_vars.push_back(range_var); + } + } + AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols( + {}, symbol_replacements, + affine_map.getNumDims() + range_var_indices_count, + affine_map.getNumSymbols() - range_var_indices_count); + + // Update the constraints. + std::vector> constraints; + constraints.reserve(map.GetConstraintsCount()); + for (auto constraint : map.GetConstraints()) { + constraints.push_back({constraint.first.replaceSymbols(symbol_replacements), + constraint.second}); + } + return IndexingMap{converted_affine_map, std::move(dims), + std::move(range_vars), std::move(rt_vars), constraints}; +} + } // namespace xla diff --git a/xla/hlo/analysis/indexing_map.h b/xla/hlo/analysis/indexing_map.h index 087b4a3de6590..8059bdb36112f 100644 --- a/xla/hlo/analysis/indexing_map.h +++ b/xla/hlo/analysis/indexing_map.h @@ -486,6 +486,12 @@ std::vector DimVarsFromGPUGrid( std::vector RangeVarsFromTensorSizes( absl::Span tensor_sizes); +// Creates a new indexing map that is the same as `map` but with the range +// variables at `range_var_indices` converted to the new dimensions variables at +// and added to the end of dimension variables list. +IndexingMap ConvertRangeVariablesToDimensions( + const IndexingMap& map, llvm::ArrayRef range_var_indices); + } // namespace xla #endif // XLA_HLO_ANALYSIS_INDEXING_MAP_H_ diff --git a/xla/hlo/analysis/indexing_map_test.cc b/xla/hlo/analysis/indexing_map_test.cc index 9459b77a8a282..add0bb3acf460 100644 --- a/xla/hlo/analysis/indexing_map_test.cc +++ b/xla/hlo/analysis/indexing_map_test.cc @@ -1669,5 +1669,29 @@ TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { IndexingMap::Variable{Interval{0, 5}}})}); } +TEST_F(IndexingMapTest, ConvertRangeVariablesToDimensions) { + IndexingMap indexing_map = Parse(R"( + (d0, d1)[to_convert_0, range, to_convert_1] + -> (d1, d0, range + to_convert_1, to_convert_0), + domain: + d0 in [0, 3], + d1 in [0, 3], + to_convert_0 in [0, 2], + range in [0, 1], + to_convert_1 in [0, 3] + )"); + EXPECT_THAT(ConvertRangeVariablesToDimensions(indexing_map, {0, 2}), + MatchIndexingMap(R"( + (d0, d1, to_convert_0, to_convert_1)[range] + -> (d1, d0, to_convert_1 + range, to_convert_0), + domain: + d0 in [0, 3], + d1 in [0, 3], + to_convert_0 in [0, 2], + to_convert_1 in [0, 3], + range in [0, 1] + )")); +} + } // namespace } // namespace xla