Skip to content

Commit

Permalink
[XLA:GPU][TRITON:XLA] Add ops and types to support TMA.
Browse files Browse the repository at this point in the history
Adds the tiled_tensor type and three ops — tile, insert, and extract — to the Triton_XLA dialect. These will be used to form common abstractions that eventually lower to normal loads/stores or to TMA variants.

PiperOrigin-RevId: 726062943
  • Loading branch information
Moerafaat authored and Google-ML-Automation committed Feb 12, 2025
1 parent 0c852dc commit cc0fcd7
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 21 deletions.
36 changes: 32 additions & 4 deletions xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,11 @@ cc_library(
)

td_library(
name = "xla_td_files",
name = "xla_triton_td_files",
srcs = glob(["*.td"]),
includes = ["."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
Expand All @@ -513,7 +514,7 @@ gentbl_cc_library(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_dialect.td",
deps = [":xla_td_files"],
deps = [":xla_triton_td_files"],
)

gentbl_cc_library(
Expand All @@ -532,14 +533,38 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_ops.td",
deps = [
":xla_td_files",
":xla_triton_td_files",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@triton//:td_files",
],
)

gentbl_cc_library(
name = "xla_triton_types_inc_gen",
strip_include_prefix = ".",
tbl_outs = [
(
[
"-gen-typedef-decls",
"-typedefs-dialect=triton_xla",
],
"xla_triton_types.h.inc",
),
(
[
"-gen-typedef-defs",
"-typedefs-dialect=triton_xla",
],
"xla_triton_types.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_types.td",
deps = [":xla_triton_td_files"],
)

gentbl_cc_library(
name = "xla_triton_attrs_inc_gen",
strip_include_prefix = ".",
Expand All @@ -562,7 +587,7 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_attrs.td",
deps = [
":xla_td_files",
":xla_triton_td_files",
"@triton//:td_files",
],
)
Expand All @@ -571,10 +596,13 @@ cc_library(
name = "xla_triton",
srcs = [
"xla_triton_attrs.cc",
"xla_triton_dialect.cc",
"xla_triton_ops.cc",
"xla_triton_types.cc",
],
hdrs = ["xla_triton_ops.h"],
deps = [
"xla_triton_types_inc_gen",
":xla_triton_attrs_inc_gen",
":xla_triton_dialect_inc_gen",
":xla_triton_ops_inc_gen",
Expand Down
43 changes: 43 additions & 0 deletions xla/backends/gpu/codegen/triton/xla_triton_dialect.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h"

#define GET_ATTRDEF_CLASSES
#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.cc.inc"
#define GET_TYPEDEF_CLASSES
#include "xla/backends/gpu/codegen/triton/xla_triton_types.cc.inc"

namespace mlir::triton::xla {

// Registers every op, attribute, and type of the triton_xla dialect with
// MLIR. The GET_*_LIST macros expand to the comma-separated class lists
// produced by mlir-tblgen from the corresponding .td files; MLIR calls this
// hook once when the dialect is loaded into a context.
void XlaTritonDialect::initialize() {
  // Ops generated from xla_triton_ops.td.
  addOperations<
#define GET_OP_LIST
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.cc.inc"
      >();
  // Attributes generated from xla_triton_attrs.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.cc.inc"
      >();
  // Types (e.g. tiled_tensor) generated from xla_triton_types.td.
  addTypes<
#define GET_TYPEDEF_LIST
#include "xla/backends/gpu/codegen/triton/xla_triton_types.cc.inc"
      >();
}

} // namespace mlir::triton::xla
1 change: 1 addition & 0 deletions xla/backends/gpu/codegen/triton/xla_triton_dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def XlaTritonDialect : Dialect {

let cppNamespace = "::mlir::triton::xla";
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_DIALECT_TD_
16 changes: 0 additions & 16 deletions xla/backends/gpu/codegen/triton/xla_triton_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ limitations under the License.
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"

#define GET_ATTRDEF_CLASSES
#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.cc.inc"

using mlir::Dialect;
using mlir::DictionaryAttr;
using mlir::Location;
Expand All @@ -51,19 +48,6 @@ using mlir::ValueRange;
using mlir::triton::gpu::TensorOrMemDesc;

namespace mlir::triton::xla {

// TODO (b/350928208): Move initialize to xla_triton_dialect.cc.
void XlaTritonDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.cc.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.cc.inc"
>();
}

LogicalResult SparseDotOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/gpu/codegen/triton/xla_triton_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class DotLike<triton::xla::SparseDotOp>

#define GET_ATTRDEF_CLASSES
#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.h.inc"
#define GET_TYPEDEF_CLASSES
#include "xla/backends/gpu/codegen/triton/xla_triton_types.h.inc"
#define GET_OP_CLASSES
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h.inc"

Expand Down
103 changes: 102 additions & 1 deletion xla/backends/gpu/codegen/triton/xla_triton_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/ViewLikeInterface.td" // OffsetSizeAndStrideOpInterface
include "xla/backends/gpu/codegen/triton/xla_triton_dialect.td"
include "xla/backends/gpu/codegen/triton/xla_triton_types.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
include "xla/backends/gpu/codegen/triton/xla_triton_dialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"

// -----------------------------------------------------------------------------
// Triton XLA Ops
// -----------------------------------------------------------------------------

// Base class for all triton_xla ops: attaches each op to XlaTritonDialect
// under the given mnemonic, forwarding any traits.
class TTXLA_Op<string mnemonic, list<Trait> traits = []> :
    Op<XlaTritonDialect, mnemonic, traits> {
}
Expand All @@ -47,4 +53,99 @@ def TTXLA_SparseDotOp : TTXLA_Op<"sparse_dot", [
let hasVerifier = 1;
}



def TTXLA_TileOp : TTXLA_Op<"tile", [Pure]> {
  let summary = "Capture the tiling information of a tensor.";
  let description = [{
    Usage:
      This op is used to capture the tiling information of a tensor. The tiling
      information can later be used with triton_xla.extract and
      triton_xla.insert.

    Example:
    ```
    %arg0: tensor<120x320xbf16>
    ...
    %tiled_arg0 = triton_xla.tile %arg0 [0, 0] [1, 1] [16, 64]
      : tensor<120x320xbf16> -> !triton_xla.tiled_tensor<16x64xbf16>
    ```
  }];

  // Note: the assembly order below is offsets, strides, shape.
  let arguments = (ins
    AnyRankedTensor:$tensor,
    DenseI64ArrayAttr:$shape,
    DenseI64ArrayAttr:$strides,
    DenseI64ArrayAttr:$offsets
  );

  let results = (outs TTXLA_TiledTensorType:$tiled_tensor);

  let assemblyFormat = [{
    $tensor $offsets $strides $shape attr-dict `:` type($tensor) `->` qualified(type($tiled_tensor))
  }];
}

def TTXLA_ExtractOp : TTXLA_Op<"extract", [Pure]> {
  let summary = "Extract a tile from a tensor.";
  let description = [{
    Usage:
      This op is used to extract a tile from a tensor. The tiling information
      can be captured using triton_xla.tile and passed to this op.

    Example:
    ```
    %tiled_arg0 = triton_xla.tile %arg0 [0, 0] [1, 1] [16, 64]
      : tensor<120x320xbf16> -> !triton_xla.tiled_tensor<16x64xbf16>
    ...
    %extracted_tensor = triton_xla.extract %tiled_arg0 [%cst, %cst]
      : !triton_xla.tiled_tensor<16x64xbf16> -> tensor<16x64xbf16>
    ```
  }];

  // Offsets are dynamic (runtime) index values into the source tensor.
  let arguments = (ins
    TTXLA_TiledTensorType:$tiled_tensor,
    Variadic<Index>:$offsets
  );

  let results = (outs AnyRankedTensor:$extracted_tensor);

  let assemblyFormat = [{
    $tiled_tensor `[` $offsets `]` attr-dict `:`
    qualified(type($tiled_tensor)) `->` type($extracted_tensor)
  }];
}

def TTXLA_InsertOp : TTXLA_Op<"insert", [Pure]> {
  let summary = "Insert a tile into a tensor.";
  let description = [{
    Usage:
      This op is used to insert a tile into a tensor. The tiling information
      can be captured using triton_xla.tile and passed to this op.

    Example:
    ```
    %tiled_arg2 = triton_xla.tile %arg2 [0, 0] [1, 1] [16, 64]
      : tensor<120x320xbf16> -> !triton_xla.tiled_tensor<16x64xbf16>
    ...
    %inserted_tensor = triton_xla.insert %arg0 into %tiled_arg2 [%cst, %cst]
      : tensor<16x64xbf16> into !triton_xla.tiled_tensor<16x64xbf16>
      -> tensor<120x320xbf16>
    ```
  }];

  // $source_tensor is the tile-shaped value being written into the
  // destination described by $dest_tiled_tensor, at the given offsets.
  let arguments = (ins
    AnyRankedTensor:$source_tensor,
    TTXLA_TiledTensorType:$dest_tiled_tensor,
    Variadic<Index>:$offsets
  );

  let results = (outs AnyRankedTensor:$dest_tensor);

  let assemblyFormat = [{
    $source_tensor `into` $dest_tiled_tensor `[` $offsets `]` attr-dict `:`
    type($source_tensor) `into` qualified(type($dest_tiled_tensor))
    `->` type($dest_tensor)
  }];
}

#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_TD_
42 changes: 42 additions & 0 deletions xla/backends/gpu/codegen/triton/xla_triton_types.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
#include "mlir/IR/Types.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h"

namespace mlir::triton::xla {

// Parses the custom assembly form `<DxD...xElemType>` of a tiled_tensor type.
// Returns a null Type on syntax error; the parser emits the diagnostic.
mlir::Type TiledTensorType::parse(mlir::AsmParser &parser) {
  mlir::SmallVector<int64_t, 4> dims;
  mlir::Type element_type;
  if (parser.parseLess()) return {};
  // Tiles always have a fully static shape, so dynamic dims are rejected.
  if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) return {};
  if (parser.parseType(element_type)) return {};
  if (parser.parseGreater()) return {};
  return TiledTensorType::get(parser.getContext(), dims, element_type);
}

// Prints the custom assembly form `<DxD...xElemType>`; the inverse of the
// parse method above. printDimensionList joins the static dims with `x`, so
// only the final separator before the element type is emitted explicitly.
void TiledTensorType::print(mlir::AsmPrinter &printer) const {
  printer << "<";
  printer.printDimensionList(getShape());
  printer << "x";
  printer << getElementType();
  printer << ">";
}

} // namespace mlir::triton::xla
59 changes: 59 additions & 0 deletions xla/backends/gpu/codegen/triton/xla_triton_types.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_TYPES_TD_
#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_TYPES_TD_

include "xla/backends/gpu/codegen/triton/xla_triton_dialect.td"
include "mlir/IR/BuiltinTypes.td" // ValueSemantics
include "mlir/IR/BuiltinTypeInterfaces.td"

// -----------------------------------------------------------------------------
// TiledTensorType
// -----------------------------------------------------------------------------

// Base class for triton_xla types: binds each type to XlaTritonDialect and
// sets the mnemonic used in the textual IR (e.g. !triton_xla.<mnemonic>).
class TTXLA_Type<string name, string typeMnemonic, list<Trait> traits = []> :
    TypeDef<XlaTritonDialect, name, traits> {
  let mnemonic = typeMnemonic;
}

def TTXLA_TiledTensorType : TTXLA_Type<"TiledTensor", "tiled_tensor", [
    ShapedTypeInterface, ValueSemantics]> {
  let summary = "A tile of a tensor.";
  let description = [{
    Usage:
      This type will typically be constructed via triton_xla.tile op. The intent
      is to capture tiling information and pass it along to other ops such as
      triton_xla.extract and triton_xla.insert. Refer to the ops for examples.
  }];

  // Static tile shape plus the element type of the underlying tensor.
  let parameters = (ins
    ArrayRefParameter<"int64_t">:$shape,
    "Type":$elementType
  );

  // Textual form `<DxD...xElemType>` is handled by hand-written
  // TiledTensorType::parse/print (xla_triton_types.cc).
  let hasCustomAssemblyFormat = 1;

  // cloneWith/hasRank satisfy ShapedTypeInterface; the shape is always
  // static, so hasRank is unconditionally true.
  let extraClassDeclaration = [{
    TiledTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
                              mlir::Type elementType) const {
      return TiledTensorType::get(getContext(), shape.value_or(getShape()),
                                  elementType);
    }
    bool hasRank() const { return true; }
  }];
}

#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_TYPES_TD_
Loading

0 comments on commit cc0fcd7

Please sign in to comment.