Vectorize group_sizes by including more lhs dimensions.
PiperOrigin-RevId: 726690023
pravnar authored and Google-ML-Automation committed Feb 14, 2025
1 parent be4f5c8 commit 51ba9b7
Showing 12 changed files with 955 additions and 151 deletions.
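For a concrete picture of the change, here is a minimal sketch (not part of the commit; the function name is hypothetical) of a mode-1 ragged dot whose group_sizes now carries the lhs batch dimension, i.e. shape [b, g] = [19, 3] instead of just [g] = [3]. The shapes mirror the BatchedRaggedDotNonContractingWithPreferredElementType test added below:

// Mode 1: lhs [b=19, m=11, k=5], rhs [g=3, b=19, k=5, n=7],
// group_sizes [b=19, g=3] -> result [b=19, m=11, n=7].
func.func @mode1_batched_group_sizes(%lhs : tensor<19x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<19x11x7xf32> {
  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [1],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [2],
      lhs_ragged_dimensions = [1],
      rhs_group_dimensions = [0]
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<19x11x5xf32>, tensor<3x19x5x7xf32>, tensor<19x3xi64>) -> tensor<19x11x7xf32>
  func.return %0 : tensor<19x11x7xf32>
}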
334 changes: 334 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -317,6 +317,246 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/B
def HLO_ComplexTensor : RankedTensorOf<[HLO_Complex]>;

def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_PerAxisQuantizedIntTensor, HLO_Token]>;
diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.cpp b/stablehlo/stablehlo/dialect/ChloOps.cpp
--- stablehlo/stablehlo/dialect/ChloOps.cpp
+++ stablehlo/stablehlo/dialect/ChloOps.cpp
@@ -16,9 +16,13 @@

#include "stablehlo/dialect/ChloOps.h"

+#include <algorithm>
#include <cassert>
#include <cstdint>
+#include <iostream>
+#include <iterator>
#include <optional>
+#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -426,12 +430,12 @@
// Mode 1, where the ragged dimension is an lhs non-contracting dim (m).
// lhs : [b, m, k]
// rhs : [g, b, k, n]
-// group_sizes : [g]
+// group_sizes : [b, g]
// result : [b, m, n]
// Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k).
// lhs : [b, m, k]
// rhs : [b, k, n]
-// group_sizes : [g]
+// group_sizes : [b, g]
// result : [g, b, m, n]
// Mode 3, where the ragged dimension is an lhs/rhs batch dim (b).
// lhs : [b, m, k]
@@ -440,9 +444,18 @@
// result : [b, m, n]
// As with dot_general, the lhs and rhs can have arbitrary batching,
// contracting and non-contracting dimensions.
+// The group_sizes arg has the shape [b...,x...,g], where:
+// - b... are all the lhs batch dims before (outer-to) the lhs ragged dim,
+// - x... are,
+// - in mode 1, all the lhs non-contracting dims before the lhs ragged dim,
+// - in mode 2, all the lhs contracting dims before the lhs ragged dim, and
+// - in mode 3, empty;
+// - g is the number of groups in the lhs ragged dim.
// Additionally:
// - In all modes, the lhs must have exactly one ragged dimension.
// - In mode 1, the rhs must have exactly one group dimension.
+// - If a group_sizes of shape [g] is passed, it is broadcasted according to
+// the rules above.
LogicalResult checkRaggedDotConstraints(
std::optional<Location> location, RankedTensorType rankedLhsType,
RankedTensorType rankedRhsType, RankedTensorType rankedGroupSizesType,
@@ -452,14 +465,6 @@
ArrayRef<int64_t> rhsContractingDimensions,
ArrayRef<int64_t> lhsRaggedDimensions,
ArrayRef<int64_t> rhsGroupDimensions) {
- // Check that the group sizes has rank=1.
- if (rankedGroupSizesType.getRank() != 1) {
- return emitOptionalError(
- location, "expected rank of group_sizes of ragged dot to be 1, got ",
- rankedGroupSizesType.getRank());
- }
- auto numGroups = rankedGroupSizesType.getDimSize(0);
-
// Check that there is exactly one lhs ragged dimension.
if (lhsRaggedDimensions.size() != 1) {
return emitOptionalError(
@@ -473,6 +478,81 @@
"lhs_rank"))) {
return failure();
}
+
+ enum Mode {
+ // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [b,g] -> [b,m,n].
+ kNonContracting,
+ // Ragged contracting (k): [b,m,k], [b,k,n], [b,g] -> [g,b,m,n].
+ kContracting,
+ // Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
+ kBatch
+ };
+ Mode mode;
+ if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim)) {
+ mode = kBatch;
+ } else if (llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
+ mode = kContracting;
+ } else {
+ mode = kNonContracting;
+ }
+
+ // Validate the shape of group_sizes.
+ {
+ // Construct the expected shape [b...,x...,g] of group_sizes.
+ SmallVector<int64_t> prefixDims;
+ prefixDims.reserve(rankedLhsType.getRank() - 1);
+ prefixDims.insert(prefixDims.end(), lhsBatchingDimensions.begin(),
+ lhsBatchingDimensions.end());
+ switch (mode) {
+ case kBatch:
+ prefixDims.resize(
+ std::distance(lhsBatchingDimensions.begin(),
+ llvm::find(lhsBatchingDimensions, lhsRaggedDim)));
+ break;
+ case kContracting:
+ prefixDims.insert(prefixDims.end(), lhsContractingDimensions.begin(),
+ llvm::find(lhsContractingDimensions, lhsRaggedDim));
+ break;
+ case kNonContracting:
+ for (int64_t i = 0; i < lhsRaggedDim; ++i) {
+ if (!llvm::is_contained(lhsBatchingDimensions, i) &&
+ !llvm::is_contained(lhsContractingDimensions, i)) {
+ prefixDims.push_back(i);
+ }
+ }
+ break;
+ }
+ SmallVector<int64_t> expectedPrefix;
+ expectedPrefix.reserve(prefixDims.size());
+ for (const int64_t dim : prefixDims) {
+ expectedPrefix.push_back(rankedLhsType.getDimSize(dim));
+ }
+
+ // Validate the actual shape, if it was passed as something other than [g].
+ if (rankedGroupSizesType.getRank() != 1) {
+ if (rankedGroupSizesType.getRank() != expectedPrefix.size() + 1) {
+ return emitOptionalError(location, "expected group_sizes to have rank ",
+ expectedPrefix.size() + 1, ", got ",
+ rankedGroupSizesType.getRank());
+ }
+ auto groupSizesShape = rankedGroupSizesType.getShape();
+ if (!std::equal(expectedPrefix.begin(), expectedPrefix.end(),
+ groupSizesShape.begin())) {
+ auto nonEmptyShapeStr = [](ArrayRef<int64_t> shape) {
+ std::string s = "";
+ for (int64_t i = 0; i < shape.size() - 1; ++i) {
+ s += std::to_string(shape[i]) + ", ";
+ }
+ return s + std::to_string(shape.back());
+ };
+ return emitOptionalError(
+ location, "group_sizes is expected to have shape [",
+ nonEmptyShapeStr(expectedPrefix), ", ", groupSizesShape.back(),
+ "], got [", nonEmptyShapeStr(groupSizesShape), "]");
+ }
+ }
+ }
+ const int64_t numGroups = rankedGroupSizesType.getShape().back();

// Validate basic properties of the rhs group dimension(s).
for (auto rhsGroupDim : rhsGroupDimensions) {
@@ -491,32 +571,34 @@
return failure();
}

- if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim) ||
- llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
- // Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
- // Ragged contracting (k): [b,m,k], [b,k,n], [g] -> [g,b,m,n].
- if (!rhsGroupDimensions.empty()) {
- return emitOptionalError(
- location,
- "There must be zero group dimensions in the rhs when the "
- "ragged dimension is batch or contracting.");
- }
- } else {
- // Ragged non-contracting (m): [b,m,k], [g,b,k,n], [g] -> [b,m,n].
- if (rhsGroupDimensions.size() != 1) {
- return emitOptionalError(
- location,
- "There must be exactly one group dimension in the rhs when the lhs "
- "ragged dimension is non-contracting.");
- }
- // Compare the group dimension size with the number of groups.
- const int64_t rhsGroupDim = rhsGroupDimensions[0];
- if (!hlo::verifyCompatibleDims(numGroups,
- rankedRhsType.getDimSize(rhsGroupDim))) {
- return emitOptionalError(
- location, "group_sizes is expected to have shape=[",
- rankedRhsType.getDimSize(rhsGroupDim), "], got [", numGroups, "]");
- }
+ switch (mode) {
+ case kBatch:
+ [[fallthrough]];
+ case kContracting:
+ if (!rhsGroupDimensions.empty()) {
+ return emitOptionalError(
+ location,
+ "There must be zero group dimensions in the rhs when the "
+ "ragged dimension is batch or contracting.");
+ }
+ break;
+ case kNonContracting:
+ if (rhsGroupDimensions.size() != 1) {
+ return emitOptionalError(
+ location,
+ "There must be exactly one group dimension in the rhs when the lhs "
+ "ragged dimension is non-contracting.");
+ }
+ // Compare the group dimension size with the number of groups.
+ const int64_t rhsGroupDim = rhsGroupDimensions[0];
+ if (!hlo::verifyCompatibleDims(numGroups,
+ rankedRhsType.getDimSize(rhsGroupDim))) {
+ return emitOptionalError(
+ location,
+ "rhs group dimension is expected to have size=", numGroups,
+ ", got ", rankedRhsType.getDimSize(rhsGroupDim));
+ }
+ break;
}
return success();
}
@@ -530,10 +612,10 @@
ArrayRef<int64_t> rhsContractingDimensions,
ArrayRef<int64_t> lhsRaggedDimensions,
ArrayRef<int64_t> rhsGroupDimensions) {
- // Must have already checked that group_sizes is 1-D.
- const int64_t numGroups = rankedGroupSizesType.getDimSize(0);
// Must have already checked that there is exactly one lhs ragged dim.
const int64_t lhsRaggedDim = lhsRaggedDimensions[0];
+ // Must have already checked the shape of group_sizes.
+ const int64_t numGroups = rankedGroupSizesType.getShape().back();

SmallVector<int64_t> dimensions;
// Add the group dimension to the result shape in case of ragged contracting.
diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.td b/stablehlo/stablehlo/dialect/ChloOps.td
--- stablehlo/stablehlo/dialect/ChloOps.td
+++ stablehlo/stablehlo/dialect/ChloOps.td
@@ -869,12 +869,12 @@
most one group dimension. The op has three modes, depending on the kind of
the lhs ragged dimension.

- In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`.
+ In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [b,g] -> [b,m,n]`.
Here the ragged dimension is an lhs non-contracting dimension (`m`). The
dimensions `b` and `k` represent batch and contracting dimensions
respectively. The rhs is required to have a group dimension (`g`).

- In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`.
+ In mode 2, the shape-signature is `[b,m,k], [b,k,n], [b,g] -> [g,b,m,n]`.
Here the ragged dimension is an lhs/rhs contracting dimension (`k`).

In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td
--- stablehlo/stablehlo/dialect/StablehloOps.td
+++ stablehlo/stablehlo/dialect/StablehloOps.td
@@ -340,6 +580,100 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d
HLO_FpTensor> { /*uniform_dequantize_c1, uniform_dequantize_c2*/
let summary = "UniformDequantize operation";
let description = [{
diff --ruN a/stablehlo/stablehlo/tests/ops_chlo.mlir b/stablehlo/stablehlo/tests/ops_chlo.mlir
--- stablehlo/stablehlo/tests/ops_chlo.mlir
+++ stablehlo/stablehlo/tests/ops_chlo.mlir
@@ -146,7 +146,7 @@
// -----

func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> {
- // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
+ // @expected-error@+1 {{expected group_sizes to have rank 1, got 2}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
@@ -163,8 +163,79 @@

// -----

-func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
- // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
+func.func @ragged_dot_mode1_group_sizes_broadcasted(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<19x17x11x7xf32> {
+ %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ ragged_dot_dimension_numbers = #chlo.ragged_dot<
+ lhs_batching_dimensions = [0],
+ rhs_batching_dimensions = [1],
+ lhs_contracting_dimensions = [3],
+ rhs_contracting_dimensions = [2],
+ lhs_ragged_dimensions = [2],
+ rhs_group_dimensions = [0]
+ >,
+ precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+ } : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<3xi64>) -> tensor<19x17x11x7xf32>
+ func.return %0 : tensor<19x17x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_mode1_group_sizes_incorrect_shape(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32> {
+ // @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}}
+ %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ ragged_dot_dimension_numbers = #chlo.ragged_dot<
+ lhs_batching_dimensions = [0],
+ rhs_batching_dimensions = [1],
+ lhs_contracting_dimensions = [3],
+ rhs_contracting_dimensions = [2],
+ lhs_ragged_dimensions = [2],
+ rhs_group_dimensions = [0]
+ >,
+ precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+ } : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32>
+ func.return %0 : tensor<19x17x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_mode2_group_sizes_incorrect_shape(%lhs : tensor<19x11x17x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32> {
+ // @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}}
+ %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ ragged_dot_dimension_numbers = #chlo.ragged_dot<
+ lhs_batching_dimensions = [0],
+ rhs_batching_dimensions = [0],
+ lhs_contracting_dimensions = [2,3],
+ rhs_contracting_dimensions = [1,2],
+ lhs_ragged_dimensions = [3],
+ rhs_group_dimensions = []
+ >,
+ precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+ } : (tensor<19x11x17x5xf32>, tensor<19x17x5x7xf32>, tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32>
+ func.return %0 : tensor<3x19x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_mode3_group_sizes_incorrect_shape(%lhs : tensor<17x19x11x5xf32>, %rhs : tensor<17x19x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<17x19x11x7xf32> {
+ // @expected-error@+1 {{group_sizes is expected to have shape [17, 3], got [19, 3]}}
+ %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ ragged_dot_dimension_numbers = #chlo.ragged_dot<
+ lhs_batching_dimensions = [0,1],
+ rhs_batching_dimensions = [0,1],
+ lhs_contracting_dimensions = [3],
+ rhs_contracting_dimensions = [2],
+ lhs_ragged_dimensions = [1],
+ rhs_group_dimensions = []
+ >,
+ precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+ } : (tensor<17x19x11x5xf32>, tensor<17x19x5x7xf32>, tensor<19x3xi64>) -> tensor<17x19x11x7xf32>
+ func.return %0 : tensor<17x19x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_incorrect_group_dim_size(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
+ // @expected-error@+1 {{rhs group dimension is expected to have size=2, got 3}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir
--- stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir
+++ stablehlo/stablehlo/tests/ops_stablehlo_quantized.mlir
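A valid counterpart to the mode-2 negative test above, as a sketch (not part of the commit; the function name is hypothetical): the expected group_sizes prefix is [19, 17], the lhs batch dim (19) followed by the one contracting dim that precedes the ragged dim (17), giving shape [19, 17, 3]:

// Mode 2: lhs [b=19, m=11, k1=17, k2=5 (ragged)], rhs [b=19, k1=17, k2=5, n=7],
// group_sizes [19, 17, 3] -> result [g=3, b=19, m=11, n=7].
func.func @mode2_valid_group_sizes(%lhs : tensor<19x11x17x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x17x3xi64>) -> tensor<3x19x11x7xf32> {
  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [2,3],
      rhs_contracting_dimensions = [1,2],
      lhs_ragged_dimensions = [3],
      rhs_group_dimensions = []
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<19x11x17x5xf32>, tensor<19x17x5x7xf32>, tensor<19x17x3xi64>) -> tensor<3x19x11x7xf32>
  func.return %0 : tensor<3x19x11x7xf32>
}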
25 changes: 25 additions & 0 deletions xla/hlo/builder/xla_builder_test.cc
@@ -1532,6 +1532,31 @@ TEST(XlaBuilderTest, RaggedDotContractingWithPreferredElementType) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, BatchedRaggedDotNonContractingWithPreferredElementType) {
XlaBuilder b(TestName());
auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(S8, {19, 11, 5}), "lhs");
auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(S8, {3, 19, 5, 7}), "rhs");
auto group_sizes =
Parameter(&b, 2, ShapeUtil::MakeShape(U32, {19, 3}), "group_sizes");

DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_batch_dimensions(0);
dot_dnums.add_lhs_contracting_dimensions(2);
dot_dnums.add_rhs_batch_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(2);
RaggedDotDimensionNumbers ragged_dot_dnums;
*ragged_dot_dnums.mutable_dot_dimension_numbers() = dot_dnums;
ragged_dot_dnums.add_lhs_ragged_dimensions(1);
ragged_dot_dnums.add_rhs_group_dimensions(0);

RaggedDot(lhs, rhs, group_sizes, ragged_dot_dnums,
/*precision_config=*/nullptr, /*preferred_element_type=*/S32);
TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("s32[19, 11, 7]"));
EXPECT_THAT(GetRoot(*module),
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, ConvolutionWithPreferredElementType) {
XlaBuilder b(TestName());
const Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
6 changes: 3 additions & 3 deletions xla/hlo/ir/hlo_instructions.h
@@ -2613,12 +2613,12 @@ class HloRaggedDotInstruction : public HloInstruction {
// Creates a ragged dot op with operands 'lhs', 'rhs', and 'group_sizes'.
// The `dimension_numbers` are for specifying:
// - batch and contracting dims for 'lhs'/'rhs' (as in HloDotInstruction),
// - exactly one 'lhs' ragged dimension, and
// - exactly one 'lhs' ragged dimension,
// - up to one 'rhs' group dimension.
// The op takes on one of three modes, based on the kind of the ragged dim:
// 1. [b,m,k], [g,b,k,n], [g] -> [b,m,n], where the ragged dimension is the
// 1. [b,m,k], [g,b,k,n], [b,g] -> [b,m,n], where the ragged dimension is the
// non-contracting dimension (m) of the 'lhs'.
// 2. [b,m,k], [b,k,n], [g] -> [g,b,m,n], where the ragged dimension is the
// 2. [b,m,k], [b,k,n], [b,g] -> [g,b,m,n], where the ragged dimension is the
// contracting dimension (k) of the 'lhs' and 'rhs'.
// 3. [b,m,k], [b,k,n], [g] -> [b,m,n], where the ragged dimension is the
// batch dimension (b) of the 'lhs' and 'rhs'.
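Rounding out the three modes, a hypothetical sketch (not in the commit; shapes taken from the mode-3 negative test above) of a valid mode-3 call: the ragged dim is the second lhs batch dim, so only the batch dim outer to it (17) prefixes the group count, giving group_sizes shape [17, 3]:

// Mode 3: lhs [b1=17, b2=19 (ragged), m=11, k=5], rhs [b1=17, b2=19, k=5, n=7],
// group_sizes [17, 3] -> result [b1=17, b2=19, m=11, n=7].
func.func @mode3_valid_group_sizes(%lhs : tensor<17x19x11x5xf32>, %rhs : tensor<17x19x5x7xf32>, %group_sizes : tensor<17x3xi64>) -> tensor<17x19x11x7xf32> {
  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [0,1],
      rhs_batching_dimensions = [0,1],
      lhs_contracting_dimensions = [3],
      rhs_contracting_dimensions = [2],
      lhs_ragged_dimensions = [1],
      rhs_group_dimensions = []
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<17x19x11x5xf32>, tensor<17x19x5x7xf32>, tensor<17x3xi64>) -> tensor<17x19x11x7xf32>
  func.return %0 : tensor<17x19x11x7xf32>
}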