Skip to content

Commit

Permalink
Integrate Triton up to [c5036b9b](https://github.com/openai/triton/co…
Browse files Browse the repository at this point in the history
  • Loading branch information
mooskagh authored and Google-ML-Automation committed Feb 14, 2025
1 parent e9063a9 commit 84f798c
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 332 deletions.
27 changes: 0 additions & 27 deletions third_party/triton/temporary/addition_to_sparsity.patch

This file was deleted.

81 changes: 0 additions & 81 deletions third_party/triton/temporary/enable_peer_access.patch

This file was deleted.

13 changes: 0 additions & 13 deletions third_party/triton/temporary/fix_assert.patch

This file was deleted.

63 changes: 0 additions & 63 deletions third_party/triton/temporary/fix_fence_insertion_race.patch

This file was deleted.

46 changes: 0 additions & 46 deletions third_party/triton/temporary/mlir_types.patch

This file was deleted.

2 changes: 0 additions & 2 deletions third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ those to this list.
"""

temporary_patch_list = [
# Temporary local patches applied on top of the pinned Triton revision.
# Each entry should be removed once the corresponding change lands upstream.
"//third_party/triton:temporary/fix_fence_insertion_race.patch",
"//third_party/triton:temporary/enable_peer_access.patch",
"//third_party/triton:temporary/sm120.patch",
# Add new patches just above this line
]
11 changes: 10 additions & 1 deletion xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tsl/platform/statusor.h"
#include "third_party/triton/include/triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

Expand Down Expand Up @@ -91,7 +91,11 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
pm->addPass(mlir::createCSEPass());

if (cc.IsAtLeastBlackwell()) {
pm->addPass(mt::gpu::createTritonGPUFuseNestedLoops());
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mlir::createLoopInvariantCodeMotionPass());
pm->addPass(mt::gpu::createTritonGPUOptimizeAccumulatorInit());
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mt::gpu::createTritonGPULoopScheduling({num_stages}));
pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages}));
pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf());
Expand All @@ -101,10 +105,15 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
} else if (cc.IsAtLeastAmpere()) {
// Even though we don't run on pre-Ampere architectures anymore, we keep
// this check for consistency with the upstream pipeline.
pm->addPass(mt::gpu::createTritonGPUFuseNestedLoops());
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mlir::createLoopInvariantCodeMotionPass());
pm->addPass(mt::gpu::createTritonGPUOptimizeAccumulatorInit());
pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf());
pm->addPass(mt::gpu::createTritonGPULoopScheduling({num_stages}));
pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages}));
} else {
pm->addPass(mlir::createLoopInvariantCodeMotionPass());
}
pm->addPass(mt::gpu::createTritonGPUPrefetch());
pm->addPass(
Expand Down
4 changes: 2 additions & 2 deletions xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
num_stages, /*stream_prefetch=*/true));
pm->addPass(mlir::createCanonicalizerPass());
}
pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass());
pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass("default"));
pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication());
Expand Down Expand Up @@ -134,7 +134,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
pm->addPass(mlir::createCSEPass());
pm->addPass(mlir::createSymbolDCEPass());
pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
cc.gfx_version(), num_stages, "default"));
cc.gfx_version(), num_stages));
pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true));
// There are no clusters in ROCm for now.
out_cluster_info.clusterDimX = 1;
Expand Down
18 changes: 0 additions & 18 deletions xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@ limitations under the License.
==============================================================================*/

#include <cstdint>
#include <optional>

#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
Expand All @@ -31,20 +27,6 @@ limitations under the License.
namespace mlir::triton::xla {

//--- SparseDotMetaEncodingAttr ---
// Returns the total number of 2:4 sparsity metadata elements each thread owns
// for a tensor of `shape` with this encoding.
// NOTE(review): assumes the parent encoding is always an NvidiaMmaEncodingAttr
// (the cast below asserts on anything else) and that metadata is distributed
// 16 elements per warp along the first warpsPerCTA dimension — confirm against
// the sparse dot lowering.
unsigned SparseDotMetaEncodingAttr::getTotalElemsPerThread(
ArrayRef<int64_t> shape, Type eltTy) const {
// Each warp holds a fixed number of metadata elements, per the 2:4 sparse
// MMA metadata layout.
constexpr int kMetadataElementsPerWarp = 16;
auto mmaLayout = mlir::cast<gpu::NvidiaMmaEncodingAttr>(getParent());
return product<int64_t>(shape) /
(mmaLayout.getWarpsPerCTA()[0] * kMetadataElementsPerWarp);
}

// Per-dimension element counts are intentionally unsupported for the sparse
// dot metadata encoding; only the aggregate count (getTotalElemsPerThread)
// is meaningful for this layout. Reaching this is a programming error.
SmallVector<unsigned> SparseDotMetaEncodingAttr::getElemsPerThread(
ArrayRef<int64_t> shape, Type eltTy) const {
llvm_unreachable("getElemsPerThread is not supported for sparse dot meta");
// Unreachable; present only to satisfy the return type.
return SmallVector<unsigned>();
}

// The sparse-dot metadata encoding wraps another layout attribute; CTA-level
// queries are answered by forwarding to that parent encoding.
SmallVector<unsigned> SparseDotMetaEncodingAttr::getCTAsPerCGA() const {
  auto parent_encoding = getParent();
  return gpu::getCTAsPerCGA(parent_encoding);
}
Expand Down
5 changes: 5 additions & 0 deletions xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"

#include <cassert>
#include <optional>

#include "llvm/ADT/SmallVector.h"
Expand All @@ -30,6 +31,8 @@ limitations under the License.
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.cc.inc"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
Expand Down Expand Up @@ -77,6 +80,8 @@ LogicalResult SparseDotOp::inferReturnTypes(
return success();
}

bool SparseDotOp::verifyDims() { return true; }

LogicalResult SparseDotOp::verify() {
// Implied properties of 2:4 sparse dots.
constexpr int kContractingFactor = 2;
Expand Down
21 changes: 3 additions & 18 deletions xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,11 @@ limitations under the License.
#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep
#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.h.inc" // IWYU pragma: keep
#include "triton/Dialect/Triton/IR/Dialect.h" // IWYU pragma: keep
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h" // IWYU pragma: keep
#include "triton/Dialect/Triton/IR/Dialect.h" // IWYU pragma: keep
#include "triton/Dialect/Triton/IR/OpInterfaces.h" // IWYU pragma: keep
#include "triton/Dialect/TritonGPU/IR/Dialect.h" // IWYU pragma: keep
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" // IWYU pragma: keep

namespace mlir::triton::xla {
class SparseDotOp;
}
namespace mlir::OpTrait {
// Template specialization for DotLike<SparseDotOp> to skip verification, which
// would fail because the sparse dot has different shapes and operands.
template <>
class DotLike<triton::xla::SparseDotOp>
: public TraitBase<triton::xla::SparseDotOp, DotLike> {
public:
// TODO (b/350928208) : Add a proper verifier for SparseDotOp.
static LogicalResult verifyTrait(Operation *op) { return success(); }
};
} // namespace mlir::OpTrait

#define GET_ATTRDEF_CLASSES
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.h.inc"
#define GET_TYPEDEF_CLASSES
Expand Down
Loading

0 comments on commit 84f798c

Please sign in to comment.