Skip to content

Commit

Permalink
[xla:cpu:onednn] Enable oneDNN thread pool targets
Browse files Browse the repository at this point in the history
Add build rules that have empty srcs/hdrs and deps when building without Graph API.

PiperOrigin-RevId: 726011881
  • Loading branch information
penpornk authored and Google-ML-Automation committed Feb 12, 2025
1 parent 8f9f056 commit 0fc06fb
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 64 deletions.
128 changes: 64 additions & 64 deletions xla/backends/cpu/runtime/onednn/BUILD
Original file line number Diff line number Diff line change
@@ -1,64 +1,64 @@
# copybara:uncomment_begin(google-only)
# load("//third_party/intel_dnnl:common.bzl", "if_graph_api")
# load("//xla:xla.bzl", "xla_cc_test")
# load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
#
# package(
# # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
# default_visibility = [":friends"],
# licenses = ["notice"],
# )
#
# package_group(
# name = "friends",
# includes = [
# "//xla:friends",
# ],
# )
#
# cc_library(
# name = "onednn_interop",
# hdrs = if_graph_api(["onednn_interop.h"]),
# deps = if_graph_api([
# "@com_google_absl//absl/base:core_headers",
# "@com_google_absl//absl/status",
# "@onednn//:mkl_dnn",
# "//xla:util",
# "//xla/tsl/platform:logging",
# ]),
# )
#
# cc_library(
# name = "onednn_threadpool",
# hdrs = if_graph_api(["onednn_threadpool.h"]),
# deps = if_graph_api([
# ":onednn_interop",
# "@onednn//:mkl_dnn",
# "//xla/backends/cpu/runtime:parallel_loop_runner",
# ]),
# )
#
# xla_cc_test(
# name = "onednn_threadpool_test",
# srcs = if_graph_api(["onednn_threadpool_test.cc"]),
# deps = if_graph_api([
# ":onednn_interop",
# ":onednn_threadpool",
# "@com_google_absl//absl/algorithm:container",
# "@com_google_absl//absl/status",
# "@com_google_absl//absl/status:statusor",
# "@com_google_absl//absl/synchronization",
# "@eigen_archive//:eigen3",
# "@onednn//:mkl_dnn",
# "@pthreadpool",
# "//xla/backends/cpu/runtime:parallel_loop_runner",
# "//xla/tsl/concurrency:async_value",
# "//xla/tsl/lib/core:status_test_util",
# "//xla/tsl/platform:env",
# "//xla/tsl/platform:statusor",
# "//xla/tsl/platform:test",
# "//xla/tsl/platform:test_benchmark",
# "//xla/tsl/platform:test_main",
# ]) + ["@com_google_googletest//:gtest_main"],
# )
# copybara:uncomment_end
load(
"//xla/tsl/mkl:graph.bzl",
"onednn_graph_cc_library",
"onednn_graph_cc_test",
)

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [":friends"],
licenses = ["notice"],
)

package_group(
name = "friends",
includes = [
"//xla:friends",
],
)

onednn_graph_cc_library(
name = "onednn_interop",
hdrs = ["onednn_interop.h"],
deps = [
"//xla:util",
"//xla/tsl/platform:logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@onednn//:mkl_dnn",
],
)

onednn_graph_cc_library(
name = "onednn_threadpool",
hdrs = ["onednn_threadpool.h"],
deps = [
":onednn_interop",
"//xla/backends/cpu/runtime:parallel_loop_runner",
"@onednn//:mkl_dnn",
],
)

onednn_graph_cc_test(
name = "onednn_threadpool_test",
srcs = ["onednn_threadpool_test.cc"],
deps = [
":onednn_interop",
":onednn_threadpool",
"//xla/backends/cpu/runtime:parallel_loop_runner",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"//xla/tsl/platform:test_benchmark",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3",
"@onednn//:mkl_dnn",
"@pthreadpool",
],
)
9 changes: 9 additions & 0 deletions xla/tsl/mkl/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ def if_mkldnn_aarch64_acl_openmp(if_true, if_false = []):
"//conditions:default": if_false,
})

# Temporarily disable Graph API on aarch64 until we change the aarch64 BUILD
# file to support Graph API.
def if_graph_api(if_true, if_false = []):
"""Returns `if_true` if Graph API is used with oneDNN."""
return select({
"@xla//xla/tsl:linux_x86_64": if_true,
"//conditions:default": if_false,
})

def _enable_local_mkl(repository_ctx):
return _TF_MKL_ROOT in repository_ctx.os.environ

Expand Down
26 changes: 26 additions & 0 deletions xla/tsl/mkl/graph.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Starlark macros for oneDNN Graph API.
Contains build rules that builds with empty srcs, hdrs, and deps if not build with Graph API.
These rules have to be outside of mkl/build_defs.bzl, otherwise we would have cyclic dependency
(xla.bzl depends on tsl which depends on mkl/build_defs.bzl).
"""

load("//xla:xla.bzl", "xla_cc_test")
load("//xla/tsl/mkl:build_defs.bzl", "if_graph_api")

def onednn_graph_cc_library(srcs = [], hdrs = [], deps = [], **kwargs):
"""cc_library rule that has empty src, hdrs and deps if not building with Graph API."""
native.cc_library(
srcs = if_graph_api(srcs),
hdrs = if_graph_api(hdrs),
deps = if_graph_api(deps),
**kwargs
)

def onednn_graph_cc_test(srcs = [], deps = [], **kwargs):
"""xla_cc_test rule that has empty src and deps if not building with Graph API."""
xla_cc_test(
srcs = if_graph_api(srcs),
deps = if_graph_api(deps) + ["@com_google_googletest//:gtest_main"],
**kwargs
)

0 comments on commit 0fc06fb

Please sign in to comment.