diff --git a/xla/backends/gpu/codegen/triton/BUILD b/xla/backends/gpu/codegen/triton/BUILD index f9690e7e5dc27..471235623824e 100644 --- a/xla/backends/gpu/codegen/triton/BUILD +++ b/xla/backends/gpu/codegen/triton/BUILD @@ -1,4 +1,3 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//xla:xla.bzl", "xla_cc_test") @@ -143,7 +142,7 @@ cc_library( "@com_google_absl//absl/status", "@llvm-project//mlir:Pass", ] + if_gpu_is_configured([ - ":xla_triton_passes", + "//xla/backends/gpu/codegen/triton/transforms:passes", "@com_google_absl//absl/strings:str_format", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ControlFlowToLLVM", @@ -184,10 +183,7 @@ cc_library( ":compilation_pipeline", ":emitter_helpers", ":fusion_emitter_legacy_matmul", - ":passes", ":support", - ":xla_triton", - ":xla_triton_passes", "//xla:autotuning_proto_cc", "//xla:permutation_util", "//xla:shape_util", @@ -197,6 +193,8 @@ cc_library( "//xla:xla_proto_cc", "//xla/backends/gpu/codegen/emitters/ir:xla_gpu", "//xla/backends/gpu/codegen/emitters/transforms:passes", + "//xla/backends/gpu/codegen/triton/ir:triton_xla", + "//xla/backends/gpu/codegen/triton/transforms:passes", "//xla/codegen:emitter_loc_op_builder", "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/emitters/ir:xla", @@ -222,6 +220,7 @@ cc_library( "//xla/stream_executor:launch_dim", "//xla/stream_executor/gpu:tma_metadata", "//xla/tools:hlo_decomposer_lib", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -288,13 +287,13 @@ cc_library( hdrs = ["fusion_emitter_legacy_matmul.h"], deps = [ ":emitter_helpers", - ":xla_triton", "//xla:comparison_util", "//xla:literal", "//xla:shape_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + 
"//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", @@ -397,224 +396,6 @@ xla_cc_test( ], ) -gentbl_cc_library( - name = "passes_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TritonFusionTransforms", - ], - "passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "passes.td", - visibility = ["//visibility:private"], - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "passes", - srcs = [ - "generalize_kernel_signature.cc", - ], - hdrs = ["passes.h"], - deps = [ - ":passes_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], -) - -gentbl_cc_library( - name = "xla_triton_passes_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TritonFusionTransforms", - "-attrdefs-dialect=triton_xla", - ], - "xla_triton_passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "xla_triton_passes.td", - visibility = ["//visibility:private"], - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "xla_triton_passes", - srcs = [ - "xla_triton_int4_passes.cc", - "xla_triton_prevent_mmav3_loop_unrolling_pass.cc", - "xla_triton_sparse_passes.cc", - ], - hdrs = [ - "xla_triton_passes.h", - ], - deps = [ - ":xla_triton", - ":xla_triton_passes_inc_gen", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - 
"@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@triton//:TritonDialects", - "@triton//:TritonGPUToLLVM", - "@triton//:TritonGPUTransforms", - "@triton//:TritonToTritonGPU", - "@triton//third_party/nvidia:NVGPUDialect", - "@triton//third_party/nvidia:NVGPUToLLVM", - "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", - ], -) - -td_library( - name = "xla_triton_td_files", - srcs = glob(["*.td"]), - includes = ["."], - deps = [ - "@llvm-project//mlir:BuiltinDialectTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "xla_triton_dialect_inc_gen", - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-dialect-decls"], - "xla_triton_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "xla_triton_dialect.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "xla_triton_dialect.td", - deps = [":xla_triton_td_files"], -) - -gentbl_cc_library( - name = "xla_triton_ops_inc_gen", - strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], - "xla_triton_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "xla_triton_ops.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "xla_triton_ops.td", - deps = [ - ":xla_triton_td_files", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - "@triton//:td_files", - ], -) - -gentbl_cc_library( - name = "xla_triton_types_inc_gen", - strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - "-typedefs-dialect=triton_xla", - ], - "xla_triton_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "-typedefs-dialect=triton_xla", - ], - "xla_triton_types.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "xla_triton_types.td", - deps = [":xla_triton_td_files"], -) - -gentbl_cc_library( - name = 
"xla_triton_attrs_inc_gen", - strip_include_prefix = ".", - tbl_outs = [ - ( - [ - "-gen-attrdef-decls", - "-attrdefs-dialect=triton_xla", - ], - "xla_triton_attrs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "-attrdefs-dialect=triton_xla", - ], - "xla_triton_attrs.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "xla_triton_attrs.td", - deps = [ - ":xla_triton_td_files", - "@triton//:td_files", - ], -) - -cc_library( - name = "xla_triton", - srcs = [ - "xla_triton_attrs.cc", - "xla_triton_dialect.cc", - "xla_triton_ops.cc", - "xla_triton_types.cc", - ], - hdrs = ["xla_triton_ops.h"], - deps = [ - "xla_triton_types_inc_gen", - ":xla_triton_attrs_inc_gen", - ":xla_triton_dialect_inc_gen", - ":xla_triton_ops_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:Support", - "@triton//:TritonDialects", - ], -) - xla_test( name = "fusion_emitter_deviceless_test", srcs = ["fusion_emitter_deviceless_test.cc"], diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc index 83598d4ca102f..2074aad316123 100644 --- a/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc +++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index fe4d0f7a13a87..47128affab133 100644 --- a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/hlo_module_config.h" diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 8f24b88ce60ba..042bbab62dd7f 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -86,8 +86,8 @@ limitations under the License. 
#include "xla/backends/gpu/codegen/triton/compilation_pipeline.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" #include "xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h" -#include "xla/backends/gpu/codegen/triton/passes.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/codegen/emitters/ir/xla_ops.h" @@ -120,12 +120,11 @@ limitations under the License. #include "xla/stream_executor/gpu/tma_metadata.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/path.h" -#include "tsl/platform/statusor.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -1313,7 +1312,7 @@ absl::StatusOr CompileTritonToLLVM( } // Triton generates pointers to the global address space, while XLA needs a // kernel signature with pointers to the generic address space. - pm.addPass(CreateGeneralizeKernelSignaturePass()); + pm.addPass(mlir::triton::xla::CreateGeneralizeKernelSignaturePass()); // llvm::Linker::linkModules() segfaults if we don't strip locations. pm.addPass(mlir::createStripDebugInfoPass()); diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc index 1c44f401eda1b..d997af8068644 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc @@ -54,7 +54,7 @@ limitations under the License. 
#include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" diff --git a/xla/backends/gpu/codegen/triton/ir/BUILD b/xla/backends/gpu/codegen/triton/ir/BUILD new file mode 100644 index 0000000000000..c2c8a470f2de8 --- /dev/null +++ b/xla/backends/gpu/codegen/triton/ir/BUILD @@ -0,0 +1,141 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +td_library( + name = "triton_xla_td_files", + srcs = glob(["*.td"]), + includes = ["."], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_xla_dialect_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-dialect-decls"], + "triton_xla_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "triton_xla_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "triton_xla_dialect.td", + deps = [":triton_xla_td_files"], +) + +gentbl_cc_library( + name = "triton_xla_ops_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-op-decls"], + "triton_xla_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "triton_xla_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "triton_xla_ops.td", + deps = [ + ":triton_xla_td_files", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + 
"@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@triton//:td_files", + ], +) + +gentbl_cc_library( + name = "triton_xla_types_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + "-typedefs-dialect=triton_xla", + ], + "triton_xla_types.h.inc", + ), + ( + [ + "-gen-typedef-defs", + "-typedefs-dialect=triton_xla", + ], + "triton_xla_types.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "triton_xla_types.td", + deps = [":triton_xla_td_files"], +) + +gentbl_cc_library( + name = "triton_xla_attrs_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-attrdef-decls", + "-attrdefs-dialect=triton_xla", + ], + "triton_xla_attrs.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + "-attrdefs-dialect=triton_xla", + ], + "triton_xla_attrs.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "triton_xla_attrs.td", + deps = [ + ":triton_xla_td_files", + "@triton//:td_files", + ], +) + +cc_library( + name = "triton_xla", + srcs = [ + "triton_xla_attrs.cc", + "triton_xla_dialect.cc", + "triton_xla_ops.cc", + "triton_xla_types.cc", + ], + hdrs = ["triton_xla_ops.h"], + deps = [ + "triton_xla_types_inc_gen", + ":triton_xla_attrs_inc_gen", + ":triton_xla_dialect_inc_gen", + ":triton_xla_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@triton//:TritonDialects", + ], +) diff --git a/xla/backends/gpu/codegen/triton/xla_triton_attrs.cc b/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc similarity index 98% rename from xla/backends/gpu/codegen/triton/xla_triton_attrs.cc rename to xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc index 66d658c7407d3..23e82b9575740 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_attrs.cc +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc @@ -20,7 +20,7 @@ limitations under 
the License. #include "mlir/IR/OpDefinition.h" // IWYU pragma: keep #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" diff --git a/xla/backends/gpu/codegen/triton/xla_triton_attrs.td b/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.td similarity index 81% rename from xla/backends/gpu/codegen/triton/xla_triton_attrs.td rename to xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.td index 969ec66631fd5..4cb168bfc48e9 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_attrs.td +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.td @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_ATTRS_TD_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_ATTRS_TD_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_ATTRS_TD_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_ATTRS_TD_ include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" -include "xla/backends/gpu/codegen/triton/xla_triton_dialect.td" +include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td" def TTXLA_SparseDotMetaEncodingAttr : DistributedEncoding<"SparseDotMetaEncoding", "sparse_dot_meta_encoding", [], XlaTritonDialect> { @@ -33,4 +33,4 @@ def TTXLA_SparseDotMetaEncodingAttr : DistributedEncoding<"SparseDotMetaEncoding } -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_ATTRS_TD_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_ATTRS_TD_ diff --git a/xla/backends/gpu/codegen/triton/xla_triton_dialect.cc b/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.cc similarity index 73% rename 
from xla/backends/gpu/codegen/triton/xla_triton_dialect.cc rename to xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.cc index e76a868929f95..c9253ca9e5832 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_dialect.cc +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.cc @@ -16,27 +16,27 @@ limitations under the License. #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #define GET_ATTRDEF_CLASSES -#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc.inc" #define GET_TYPEDEF_CLASSES -#include "xla/backends/gpu/codegen/triton/xla_triton_types.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_types.cc.inc" namespace mlir::triton::xla { void XlaTritonDialect::initialize() { addOperations< #define GET_OP_LIST -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc.inc" >(); addAttributes< #define GET_ATTRDEF_LIST -#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.cc.inc" >(); addTypes< #define GET_TYPEDEF_LIST -#include "xla/backends/gpu/codegen/triton/xla_triton_types.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_types.cc.inc" >(); } diff --git a/xla/backends/gpu/codegen/triton/xla_triton_dialect.td b/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td similarity index 86% rename from xla/backends/gpu/codegen/triton/xla_triton_dialect.td rename to xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td index f3283a134e76f..08a513ed0ea26 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_dialect.td +++ 
b/xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_DIALECT_TD_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_DIALECT_TD_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_DIALECT_TD_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_DIALECT_TD_ include "mlir/IR/DialectBase.td" @@ -39,4 +39,4 @@ def XlaTritonDialect : Dialect { let useDefaultTypePrinterParser = 1; } -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_DIALECT_TD_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_DIALECT_TD_ diff --git a/xla/backends/gpu/codegen/triton/xla_triton_ops.cc b/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc similarity index 96% rename from xla/backends/gpu/codegen/triton/xla_triton_ops.cc rename to xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc index a6428e9bb1159..f12aa7c484b71 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_ops.cc +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include @@ -30,7 +30,7 @@ limitations under the License. 
#include "mlir/IR/Region.h" #include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep #include "mlir/IR/ValueRange.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_dialect.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.cc.inc" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Types.h" @@ -137,4 +137,4 @@ LogicalResult SparseDotOp::verify() { } // namespace mlir::triton::xla #define GET_OP_CLASSES -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.cc.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.cc.inc" diff --git a/xla/backends/gpu/codegen/triton/xla_triton_ops.h b/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h similarity index 80% rename from xla/backends/gpu/codegen/triton/xla_triton_ops.h rename to xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h index a6d71bcd8dc30..a540cf57e52b8 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_ops.h +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_H_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_H_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_H_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_H_ #include "mlir/IR/Attributes.h" // IWYU pragma: keep #include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep @@ -23,7 +23,7 @@ limitations under the License. 
#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep -#include "xla/backends/gpu/codegen/triton/xla_triton_dialect.h.inc" // IWYU pragma: keep +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.h.inc" // IWYU pragma: keep #include "triton/Dialect/Triton/IR/Dialect.h" // IWYU pragma: keep #include "triton/Dialect/Triton/IR/Traits.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" // IWYU pragma: keep @@ -45,10 +45,10 @@ class DotLike } // namespace mlir::OpTrait #define GET_ATTRDEF_CLASSES -#include "xla/backends/gpu/codegen/triton/xla_triton_attrs.h.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_attrs.h.inc" #define GET_TYPEDEF_CLASSES -#include "xla/backends/gpu/codegen/triton/xla_triton_types.h.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_types.h.inc" #define GET_OP_CLASSES -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h.inc" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h.inc" -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_H_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_H_ diff --git a/xla/backends/gpu/codegen/triton/xla_triton_ops.td b/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td similarity index 93% rename from xla/backends/gpu/codegen/triton/xla_triton_ops.td rename to xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td index 8006761b50b4c..2c4edc20fa1df 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_ops.td +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_ops.td @@ -14,15 +14,15 @@ limitations under the License. 
==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_TD_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_TD_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_TD_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_TD_ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType include "mlir/Interfaces/ViewLikeInterface.td" // OffsetSizeAndStrideOpInterface -include "xla/backends/gpu/codegen/triton/xla_triton_dialect.td" -include "xla/backends/gpu/codegen/triton/xla_triton_types.td" +include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td" +include "xla/backends/gpu/codegen/triton/ir/triton_xla_types.td" include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" include "triton/Dialect/Triton/IR/TritonTypes.td" @@ -148,4 +148,4 @@ def TTXLA_InsertOp : TTXLA_Op<"insert", [Pure]> { }]; } -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_OPS_TD_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_TD_ diff --git a/xla/backends/gpu/codegen/triton/xla_triton_types.cc b/xla/backends/gpu/codegen/triton/ir/triton_xla_types.cc similarity index 95% rename from xla/backends/gpu/codegen/triton/xla_triton_types.cc rename to xla/backends/gpu/codegen/triton/ir/triton_xla_types.cc index 1b46e1f814421..1c7506b6a0f67 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_types.cc +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_types.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/IR/Types.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" namespace mlir::triton::xla { diff --git a/xla/backends/gpu/codegen/triton/xla_triton_types.td b/xla/backends/gpu/codegen/triton/ir/triton_xla_types.td similarity index 88% rename from xla/backends/gpu/codegen/triton/xla_triton_types.td rename to xla/backends/gpu/codegen/triton/ir/triton_xla_types.td index e553d081d965d..cbf4a0c6ea8d9 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_types.td +++ b/xla/backends/gpu/codegen/triton/ir/triton_xla_types.td @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_TYPES_TD_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_TYPES_TD_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_TYPES_TD_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_TYPES_TD_ -include "xla/backends/gpu/codegen/triton/xla_triton_dialect.td" +include "xla/backends/gpu/codegen/triton/ir/triton_xla_dialect.td" include "mlir/IR/BuiltinTypes.td" // ValueSemantics include "mlir/IR/BuiltinTypeInterfaces.td" @@ -56,4 +56,4 @@ def TTXLA_TiledTensorType : TTXLA_Type<"TiledTensor", "tiled_tensor", [ }]; } -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_TYPES_TD_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_TYPES_TD_ diff --git a/xla/backends/gpu/codegen/triton/passes.h b/xla/backends/gpu/codegen/triton/passes.h deleted file mode 100644 index c2c450e344e24..0000000000000 --- a/xla/backends/gpu/codegen/triton/passes.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_H_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_H_ - -#include - -#include "mlir/Pass/Pass.h" - -namespace xla::gpu { - -#define GEN_PASS_DECL -#include "xla/backends/gpu/codegen/triton/passes.h.inc" - -std::unique_ptr CreateGeneralizeKernelSignaturePass(); - -#define GEN_PASS_REGISTRATION -#include "xla/backends/gpu/codegen/triton/passes.h.inc" - -} // namespace xla::gpu - -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_H_ diff --git a/xla/backends/gpu/codegen/triton/passes.td b/xla/backends/gpu/codegen/triton/passes.td deleted file mode 100644 index bd63c403e0c71..0000000000000 --- a/xla/backends/gpu/codegen/triton/passes.td +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_TD_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_TD_ - -include "mlir/Pass/PassBase.td" - -def GeneralizeKernelSignaturePass - : Pass<"generalize-kernel-signature"> { - let summary = "Rewrite kernels to use generic data pointer arguments."; - let description = [{ - Rewrite signatures of kernel functions from global pointers to generic - pointers and cast them to global ones within the kernel. - }]; - let constructor = "CreateGeneralizeKernelSignaturePass()"; -} - -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_TD_ diff --git a/xla/backends/gpu/codegen/triton/transforms/BUILD b/xla/backends/gpu/codegen/triton/transforms/BUILD new file mode 100644 index 0000000000000..940d255323d80 --- /dev/null +++ b/xla/backends/gpu/codegen/triton/transforms/BUILD @@ -0,0 +1,71 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TritonXlaTransforms", + "-attrdefs-dialect=triton_xla", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + visibility = ["//visibility:private"], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "passes", + srcs = [ + "generalize_kernel_signature.cc", + "int4_passes.cc", + "prevent_mmav3_loop_unrolling_pass.cc", + "sparse_passes.cc", + ], + hdrs = ["passes.h"], + deps = [ + ":passes_inc_gen", + "//xla/backends/gpu/codegen/triton/ir:triton_xla", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + 
"@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@triton//:TritonDialects", + "@triton//:TritonGPUToLLVM", + "@triton//:TritonGPUTransforms", + "@triton//:TritonToTritonGPU", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/xla/backends/gpu/codegen/triton/generalize_kernel_signature.cc b/xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc similarity index 67% rename from xla/backends/gpu/codegen/triton/generalize_kernel_signature.cc rename to xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc index 90c56013c7133..e05ee6b4ed673 100644 --- a/xla/backends/gpu/codegen/triton/generalize_kernel_signature.cc +++ b/xla/backends/gpu/codegen/triton/transforms/generalize_kernel_signature.cc @@ -30,15 +30,14 @@ limitations under the License. #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/backends/gpu/codegen/triton/passes.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" -namespace xla::gpu { +namespace mlir::triton::xla { namespace { // Extract additional attributes from an LLVM function that are not passed // to the builder directly. 
-mlir::SmallVector GetExtraAttrs( - mlir::LLVM::LLVMFuncOp func) { +SmallVector GetExtraAttrs(LLVM::LLVMFuncOp func) { llvm::StringSet<> registered_attr_names{ func.getSymNameAttrName().getValue(), func.getFunctionTypeAttrName().getValue(), @@ -48,48 +47,46 @@ mlir::SmallVector GetExtraAttrs( func.getArgAttrsAttrName().getValue(), func.getFunctionEntryCountAttrName().getValue()}; return llvm::to_vector( - llvm::make_filter_range(func->getAttrs(), [&](mlir::NamedAttribute attr) { + llvm::make_filter_range(func->getAttrs(), [&](NamedAttribute attr) { return !registered_attr_names.contains(attr.getName().getValue()); })); } // Strip address spaces from function parameters. -void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, - mlir::LLVM::LLVMFuncOp func) { +void StripParameterAddressSpaces(RewriterBase& rewriter, + LLVM::LLVMFuncOp func) { // Figure out what the new signature should be. - mlir::LLVM::LLVMFunctionType func_ty = func.getFunctionType(); - mlir::SmallVector generic_func_params( - llvm::map_range(func_ty.getParams(), [](mlir::Type type) -> mlir::Type { - auto ptr_ty = mlir::dyn_cast(type); + LLVM::LLVMFunctionType func_ty = func.getFunctionType(); + SmallVector generic_func_params( + llvm::map_range(func_ty.getParams(), [](Type type) -> Type { + auto ptr_ty = dyn_cast(type); if (!ptr_ty) return type; - if (ptr_ty.getAddressSpace() != mlir::NVVM::kGlobalMemorySpace) - return type; - return mlir::LLVM::LLVMPointerType::get(ptr_ty.getContext()); + if (ptr_ty.getAddressSpace() != NVVM::kGlobalMemorySpace) return type; + return LLVM::LLVMPointerType::get(ptr_ty.getContext()); })); - mlir::LLVM::LLVMFunctionType generic_func_ty = + LLVM::LLVMFunctionType generic_func_ty = func_ty.clone(generic_func_params, func_ty.getReturnTypes()); // Create a function with the new signature. 
- mlir::SmallVector arg_attrs(llvm::map_range( - func.getArgAttrsAttr().getValue(), [](mlir::Attribute attr) { - return mlir::cast(attr); - })); - auto generic_func = rewriter.create( + SmallVector arg_attrs(llvm::map_range( + func.getArgAttrsAttr().getValue(), + [](Attribute attr) { return cast(attr); })); + auto generic_func = rewriter.create( func.getLoc(), func.getSymName(), generic_func_ty, func.getLinkage(), func.getDsoLocal(), func.getCConv(), /*comdat=*/nullptr, GetExtraAttrs(func), arg_attrs, func.getFunctionEntryCount()); // Convert generic address spaces back to original ones within the function // body. - mlir::Block* entry = generic_func.addEntryBlock(rewriter); + Block* entry = generic_func.addEntryBlock(rewriter); rewriter.setInsertionPointToEnd(entry); - mlir::SmallVector converted_args; + SmallVector converted_args; for (auto [arg, type] : llvm::zip(generic_func.getArguments(), func_ty.getParams())) { - mlir::Value converted = arg; + Value converted = arg; if (arg.getType() != type) { converted = - rewriter.create(arg.getLoc(), type, arg); + rewriter.create(arg.getLoc(), type, arg); } converted_args.push_back(converted); } @@ -102,7 +99,7 @@ void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, } #define GEN_PASS_DEF_GENERALIZEKERNELSIGNATUREPASS -#include "xla/backends/gpu/codegen/triton/passes.h.inc" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" // Rewrite signatures of kernel functions to use generic data pointers and // cast them to global ones within the kernel. 
@@ -110,9 +107,9 @@ struct GeneralizeKernelSignaturePass : public impl::GeneralizeKernelSignaturePassBase< GeneralizeKernelSignaturePass> { void runOnOperation() override { - mlir::IRRewriter rewriter(&getContext()); - getOperation()->walk([&](mlir::LLVM::LLVMFuncOp func) { - if (!func->hasAttr(mlir::NVVM::NVVMDialect::getKernelFuncAttrName())) { + IRRewriter rewriter(&getContext()); + getOperation()->walk([&](LLVM::LLVMFuncOp func) { + if (!func->hasAttr(NVVM::NVVMDialect::getKernelFuncAttrName())) { return; } rewriter.setInsertionPointAfter(func); @@ -123,8 +120,8 @@ struct GeneralizeKernelSignaturePass } // namespace -std::unique_ptr CreateGeneralizeKernelSignaturePass() { +std::unique_ptr CreateGeneralizeKernelSignaturePass() { return std::make_unique(); } -} // namespace xla::gpu +} // namespace mlir::triton::xla diff --git a/xla/backends/gpu/codegen/triton/xla_triton_int4_passes.cc b/xla/backends/gpu/codegen/triton/transforms/int4_passes.cc similarity index 99% rename from xla/backends/gpu/codegen/triton/xla_triton_int4_passes.cc rename to xla/backends/gpu/codegen/triton/transforms/int4_passes.cc index e2a938af18219..36897a26c4b0e 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_int4_passes.cc +++ b/xla/backends/gpu/codegen/triton/transforms/int4_passes.cc @@ -50,7 +50,7 @@ namespace mt = ::mlir::triton; namespace ma = ::mlir::arith; #define GEN_PASS_DEF_LOADINT4REWRITEPASS -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h.inc" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" class I4ToI8Converter : public TypeConverter { public: diff --git a/xla/backends/gpu/codegen/triton/xla_triton_passes.h b/xla/backends/gpu/codegen/triton/transforms/passes.h similarity index 81% rename from xla/backends/gpu/codegen/triton/xla_triton_passes.h rename to xla/backends/gpu/codegen/triton/transforms/passes.h index 8fcf16901a63b..10a7a7273d38e 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_passes.h +++ 
b/xla/backends/gpu/codegen/triton/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_PASSES_H_ -#define XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_PASSES_H_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_TRANSFORMS_PASSES_H_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_TRANSFORMS_PASSES_H_ #include #include @@ -26,8 +26,9 @@ limitations under the License. namespace mlir::triton::xla { #define GEN_PASS_DECL -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h.inc" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" +std::unique_ptr CreateGeneralizeKernelSignaturePass(); std::unique_ptr CreateSparseAddEncodingPass( int32_t num_warps = 4, int32_t threads_per_warp = 32, int32_t num_ctas = 1); std::unique_ptr CreateSparseBlockedToMMAPass(); @@ -44,8 +45,8 @@ bool ContainsOp(mlir::Operation* op, llvm::function_ref fn); #define GEN_PASS_REGISTRATION -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h.inc" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" } // namespace mlir::triton::xla -#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_XLA_TRITON_PASSES_H_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_TRANSFORMS_PASSES_H_ diff --git a/xla/backends/gpu/codegen/triton/xla_triton_passes.td b/xla/backends/gpu/codegen/triton/transforms/passes.td similarity index 87% rename from xla/backends/gpu/codegen/triton/xla_triton_passes.td rename to xla/backends/gpu/codegen/triton/transforms/passes.td index 21db540475b39..d8e881b5a1d74 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_passes.td +++ b/xla/backends/gpu/codegen/triton/transforms/passes.td @@ -13,11 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_PASSES_TD_ -#define XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_PASSES_TD_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRITON_TRANSFORMS_PASSES_TD_ +#define XLA_BACKENDS_GPU_CODEGEN_TRITON_TRANSFORMS_PASSES_TD_ include "mlir/Pass/PassBase.td" +def GeneralizeKernelSignaturePass + : Pass<"generalize-kernel-signature"> { + let summary = "Rewrite kernels to use generic data pointer arguments."; + let description = [{ + Rewrite signatures of kernel functions from global pointers to generic + pointers and cast them to global ones within the kernel. + }]; + let constructor = "CreateGeneralizeKernelSignaturePass()"; +} + def SparseAddEncodingPass : Pass<"sparse-add-encoding", "mlir::ModuleOp"> { let summary = "Add sparse encoding for all the arguments of a SparseDotOp."; let options = [ @@ -106,4 +116,4 @@ def LoadInt4RewritePass let constructor = "CreateInt4ToPackedInt4RewritePass()"; } -#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_PASSES_TD_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_TRANSFORMS_PASSES_TD_ diff --git a/xla/backends/gpu/codegen/triton/xla_triton_prevent_mmav3_loop_unrolling_pass.cc b/xla/backends/gpu/codegen/triton/transforms/prevent_mmav3_loop_unrolling_pass.cc similarity index 95% rename from xla/backends/gpu/codegen/triton/xla_triton_prevent_mmav3_loop_unrolling_pass.cc rename to xla/backends/gpu/codegen/triton/transforms/prevent_mmav3_loop_unrolling_pass.cc index 57f49ad7c5a09..19aa852ead6b0 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_prevent_mmav3_loop_unrolling_pass.cc +++ b/xla/backends/gpu/codegen/triton/transforms/prevent_mmav3_loop_unrolling_pass.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -29,7 +29,7 @@ namespace mlir::triton::xla { namespace { #define GEN_PASS_DEF_PREVENTMMAV3LOOPUNROLLINGPASS -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h.inc" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" struct PreventMmaV3LoopUnrollingPass : public impl::PreventMmaV3LoopUnrollingPassBase< diff --git a/xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc b/xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc similarity index 99% rename from xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc rename to xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc index 9aa1f43cc5958..9a228742d0acd 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc +++ b/xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc @@ -54,8 +54,8 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" @@ -122,7 +122,7 @@ bool ContainsOp(mlir::Operation *op, #define GEN_PASS_DEF_SPARSELOCALLOADTOLLVMPASS #define GEN_PASS_DEF_SPARSEREMOVELAYOUTCONVERSIONPASS #define GEN_PASS_DEF_SPARSEWGMMAOPTOLLVMPASS -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h.inc" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" constexpr int kThreadsPerWarp = 32; // Each 16x16 original sparse matrix tile requires 16 metadata values of diff --git a/xla/backends/gpu/codegen/triton/tests/BUILD b/xla/backends/gpu/codegen/triton/transforms/tests/BUILD similarity index 100% rename from xla/backends/gpu/codegen/triton/tests/BUILD rename to xla/backends/gpu/codegen/triton/transforms/tests/BUILD diff --git a/xla/backends/gpu/codegen/triton/tests/int4_packed_dim_major_1d.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_major_1d.mlir similarity index 100% rename from xla/backends/gpu/codegen/triton/tests/int4_packed_dim_major_1d.mlir rename to xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_major_1d.mlir diff --git a/xla/backends/gpu/codegen/triton/tests/int4_packed_dim_major_2d.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_major_2d.mlir similarity index 100% rename from xla/backends/gpu/codegen/triton/tests/int4_packed_dim_major_2d.mlir rename to xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_major_2d.mlir diff --git 
a/xla/backends/gpu/codegen/triton/tests/int4_packed_dim_minor_1d.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_minor_1d.mlir similarity index 100% rename from xla/backends/gpu/codegen/triton/tests/int4_packed_dim_minor_1d.mlir rename to xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_minor_1d.mlir diff --git a/xla/backends/gpu/codegen/triton/tests/int4_packed_dim_minor_2d.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_minor_2d.mlir similarity index 100% rename from xla/backends/gpu/codegen/triton/tests/int4_packed_dim_minor_2d.mlir rename to xla/backends/gpu/codegen/triton/transforms/tests/int4_packed_dim_minor_2d.mlir diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 216e5025fc1f5..1b24e973795d1 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -613,8 +613,8 @@ lit_test_suite( # srcs = ["xla-opt.cc"], # deps = [ # "//xla/backends/gpu/codegen/emitters/transforms:passes", -# "//xla/backends/gpu/codegen/triton:xla_triton", -# "//xla/backends/gpu/codegen/triton:xla_triton_passes", +# "//xla/backends/gpu/codegen/triton/ir:triton_xla", +# "//xla/backends/gpu/codegen/triton/transforms:passes", # "//xla/codegen/emitters/transforms:passes", # "@llvm-project//mlir:AllExtensions", # "@llvm-project//mlir:MlirOptLib", diff --git a/xla/service/gpu/tests/xla-opt.cc b/xla/service/gpu/tests/xla-opt.cc index 42ad140f4325c..af5545773f08b 100644 --- a/xla/service/gpu/tests/xla-opt.cc +++ b/xla/service/gpu/tests/xla-opt.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include "mlir/InitAllExtensions.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "xla/backends/gpu/codegen/emitters/transforms/passes.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h" -#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h" +#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" #include "xla/codegen/emitters/transforms/passes.h" #include "third_party/triton/bin/RegisterTritonDialects.h" @@ -26,7 +26,7 @@ int main(int argc, char **argv) { mlir::registerAllExtensions(registry); registerTritonDialects(registry); // This registers all passes as well. registry.insert(); - mlir::triton::xla::registerTritonFusionTransformsPasses(); + mlir::triton::xla::registerTritonXlaTransformsPasses(); xla::emitters::registerTransformsPasses(); xla::gpu::registerGpuFusionTransformsPasses();