Skip to content

Commit

Permalink
[XLA:GPU][NFC] Clean-up directory structure. Make it like in emitters/.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726480722
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Feb 13, 2025
1 parent 96a9745 commit 83081f9
Show file tree
Hide file tree
Showing 31 changed files with 315 additions and 381 deletions.
229 changes: 5 additions & 224 deletions xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
Expand Down Expand Up @@ -143,7 +142,7 @@ cc_library(
"@com_google_absl//absl/status",
"@llvm-project//mlir:Pass",
] + if_gpu_is_configured([
":xla_triton_passes",
"//xla/backends/gpu/codegen/triton/transforms:passes",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ControlFlowToLLVM",
Expand Down Expand Up @@ -184,10 +183,7 @@ cc_library(
":compilation_pipeline",
":emitter_helpers",
":fusion_emitter_legacy_matmul",
":passes",
":support",
":xla_triton",
":xla_triton_passes",
"//xla:autotuning_proto_cc",
"//xla:permutation_util",
"//xla:shape_util",
Expand All @@ -197,6 +193,8 @@ cc_library(
"//xla:xla_proto_cc",
"//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
"//xla/backends/gpu/codegen/emitters/transforms:passes",
"//xla/backends/gpu/codegen/triton/ir:triton_xla",
"//xla/backends/gpu/codegen/triton/transforms:passes",
"//xla/codegen:emitter_loc_op_builder",
"//xla/codegen/emitters:elemental_hlo_to_mlir",
"//xla/codegen/emitters/ir:xla",
Expand All @@ -222,6 +220,7 @@ cc_library(
"//xla/stream_executor:launch_dim",
"//xla/stream_executor/gpu:tma_metadata",
"//xla/tools:hlo_decomposer_lib",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -288,13 +287,13 @@ cc_library(
hdrs = ["fusion_emitter_legacy_matmul.h"],
deps = [
":emitter_helpers",
":xla_triton",
"//xla:comparison_util",
"//xla:literal",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/codegen/triton/ir:triton_xla",
"//xla/codegen:emitter_loc_op_builder",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
Expand Down Expand Up @@ -397,224 +396,6 @@ xla_cc_test(
],
)

gentbl_cc_library(
name = "passes_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=TritonFusionTransforms",
],
"passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "passes.td",
visibility = ["//visibility:private"],
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "passes",
srcs = [
"generalize_kernel_signature.cc",
],
hdrs = ["passes.h"],
deps = [
":passes_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

gentbl_cc_library(
name = "xla_triton_passes_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=TritonFusionTransforms",
"-attrdefs-dialect=triton_xla",
],
"xla_triton_passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_passes.td",
visibility = ["//visibility:private"],
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "xla_triton_passes",
srcs = [
"xla_triton_int4_passes.cc",
"xla_triton_prevent_mmav3_loop_unrolling_pass.cc",
"xla_triton_sparse_passes.cc",
],
hdrs = [
"xla_triton_passes.h",
],
deps = [
":xla_triton",
":xla_triton_passes_inc_gen",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@triton//:TritonDialects",
"@triton//:TritonGPUToLLVM",
"@triton//:TritonGPUTransforms",
"@triton//:TritonToTritonGPU",
"@triton//third_party/nvidia:NVGPUDialect",
"@triton//third_party/nvidia:NVGPUToLLVM",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
],
)

td_library(
name = "xla_triton_td_files",
srcs = glob(["*.td"]),
includes = ["."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)

gentbl_cc_library(
name = "xla_triton_dialect_inc_gen",
strip_include_prefix = ".",
tbl_outs = [
(
["-gen-dialect-decls"],
"xla_triton_dialect.h.inc",
),
(
["-gen-dialect-defs"],
"xla_triton_dialect.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_dialect.td",
deps = [":xla_triton_td_files"],
)

gentbl_cc_library(
name = "xla_triton_ops_inc_gen",
strip_include_prefix = ".",
tbl_outs = [
(
["-gen-op-decls"],
"xla_triton_ops.h.inc",
),
(
["-gen-op-defs"],
"xla_triton_ops.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_ops.td",
deps = [
":xla_triton_td_files",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@triton//:td_files",
],
)

gentbl_cc_library(
name = "xla_triton_types_inc_gen",
strip_include_prefix = ".",
tbl_outs = [
(
[
"-gen-typedef-decls",
"-typedefs-dialect=triton_xla",
],
"xla_triton_types.h.inc",
),
(
[
"-gen-typedef-defs",
"-typedefs-dialect=triton_xla",
],
"xla_triton_types.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_types.td",
deps = [":xla_triton_td_files"],
)

gentbl_cc_library(
name = "xla_triton_attrs_inc_gen",
strip_include_prefix = ".",
tbl_outs = [
(
[
"-gen-attrdef-decls",
"-attrdefs-dialect=triton_xla",
],
"xla_triton_attrs.h.inc",
),
(
[
"-gen-attrdef-defs",
"-attrdefs-dialect=triton_xla",
],
"xla_triton_attrs.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "xla_triton_attrs.td",
deps = [
":xla_triton_td_files",
"@triton//:td_files",
],
)

cc_library(
name = "xla_triton",
srcs = [
"xla_triton_attrs.cc",
"xla_triton_dialect.cc",
"xla_triton_ops.cc",
"xla_triton_types.cc",
],
hdrs = ["xla_triton_ops.h"],
deps = [
"xla_triton_types_inc_gen",
":xla_triton_attrs_inc_gen",
":xla_triton_dialect_inc_gen",
":xla_triton_ops_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@triton//:TritonDialects",
],
)

xla_test(
name = "fusion_emitter_deviceless_test",
srcs = ["fusion_emitter_deviceless_test.cc"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ limitations under the License.
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h"
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
#include "xla/service/hlo_module_config.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h"
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
#include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/hlo_module_config.h"
Expand Down
9 changes: 4 additions & 5 deletions xla/backends/gpu/codegen/triton/fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ limitations under the License.
#include "xla/backends/gpu/codegen/triton/compilation_pipeline.h"
#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
#include "xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.h"
#include "xla/backends/gpu/codegen/triton/passes.h"
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
#include "xla/codegen/emitter_loc_op_builder.h"
#include "xla/codegen/emitters/elemental_hlo_to_mlir.h"
#include "xla/codegen/emitters/ir/xla_ops.h"
Expand Down Expand Up @@ -120,12 +120,11 @@ limitations under the License.
#include "xla/stream_executor/gpu/tma_metadata.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/tools/hlo_decomposer.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/path.h"
#include "tsl/platform/statusor.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
Expand Down Expand Up @@ -1313,7 +1312,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
}
// Triton generates pointers to the global address space, while XLA needs a
// kernel signature with pointers to the generic address space.
pm.addPass(CreateGeneralizeKernelSignaturePass());
pm.addPass(mlir::triton::xla::CreateGeneralizeKernelSignaturePass());
// llvm::Linker::linkModules() segfaults if we don't strip locations.
pm.addPass(mlir::createStripDebugInfoPass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ limitations under the License.
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
#include "xla/backends/gpu/codegen/triton/xla_triton_ops.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"
#include "xla/codegen/emitter_loc_op_builder.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
Expand Down
Loading

0 comments on commit 83081f9

Please sign in to comment.