-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[xla:cpu:onednn] Enable oneDNN thread pool targets
Add build rules that have empty srcs/hdrs and deps when building without Graph API. PiperOrigin-RevId: 726011881
- Loading branch information
1 parent
8f9f056
commit 0fc06fb
Showing
3 changed files
with
99 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,64 @@ | ||
# copybara:uncomment_begin(google-only) | ||
# load("//third_party/intel_dnnl:common.bzl", "if_graph_api") | ||
# load("//xla:xla.bzl", "xla_cc_test") | ||
# load("//xla/tsl/platform:rules_cc.bzl", "cc_library") | ||
# | ||
# package( | ||
# # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], | ||
# default_visibility = [":friends"], | ||
# licenses = ["notice"], | ||
# ) | ||
# | ||
# package_group( | ||
# name = "friends", | ||
# includes = [ | ||
# "//xla:friends", | ||
# ], | ||
# ) | ||
# | ||
# cc_library( | ||
# name = "onednn_interop", | ||
# hdrs = if_graph_api(["onednn_interop.h"]), | ||
# deps = if_graph_api([ | ||
# "@com_google_absl//absl/base:core_headers", | ||
# "@com_google_absl//absl/status", | ||
# "@onednn//:mkl_dnn", | ||
# "//xla:util", | ||
# "//xla/tsl/platform:logging", | ||
# ]), | ||
# ) | ||
# | ||
# cc_library( | ||
# name = "onednn_threadpool", | ||
# hdrs = if_graph_api(["onednn_threadpool.h"]), | ||
# deps = if_graph_api([ | ||
# ":onednn_interop", | ||
# "@onednn//:mkl_dnn", | ||
# "//xla/backends/cpu/runtime:parallel_loop_runner", | ||
# ]), | ||
# ) | ||
# | ||
# xla_cc_test( | ||
# name = "onednn_threadpool_test", | ||
# srcs = if_graph_api(["onednn_threadpool_test.cc"]), | ||
# deps = if_graph_api([ | ||
# ":onednn_interop", | ||
# ":onednn_threadpool", | ||
# "@com_google_absl//absl/algorithm:container", | ||
# "@com_google_absl//absl/status", | ||
# "@com_google_absl//absl/status:statusor", | ||
# "@com_google_absl//absl/synchronization", | ||
# "@eigen_archive//:eigen3", | ||
# "@onednn//:mkl_dnn", | ||
# "@pthreadpool", | ||
# "//xla/backends/cpu/runtime:parallel_loop_runner", | ||
# "//xla/tsl/concurrency:async_value", | ||
# "//xla/tsl/lib/core:status_test_util", | ||
# "//xla/tsl/platform:env", | ||
# "//xla/tsl/platform:statusor", | ||
# "//xla/tsl/platform:test", | ||
# "//xla/tsl/platform:test_benchmark", | ||
# "//xla/tsl/platform:test_main", | ||
# ]) + ["@com_google_googletest//:gtest_main"], | ||
# ) | ||
# copybara:uncomment_end | ||
load( | ||
"//xla/tsl/mkl:graph.bzl", | ||
"onednn_graph_cc_library", | ||
"onednn_graph_cc_test", | ||
) | ||
|
||
package( | ||
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], | ||
default_visibility = [":friends"], | ||
licenses = ["notice"], | ||
) | ||
|
||
package_group( | ||
name = "friends", | ||
includes = [ | ||
"//xla:friends", | ||
], | ||
) | ||
|
||
onednn_graph_cc_library( | ||
name = "onednn_interop", | ||
hdrs = ["onednn_interop.h"], | ||
deps = [ | ||
"//xla:util", | ||
"//xla/tsl/platform:logging", | ||
"@com_google_absl//absl/base:core_headers", | ||
"@com_google_absl//absl/status", | ||
"@onednn//:mkl_dnn", | ||
], | ||
) | ||
|
||
onednn_graph_cc_library( | ||
name = "onednn_threadpool", | ||
hdrs = ["onednn_threadpool.h"], | ||
deps = [ | ||
":onednn_interop", | ||
"//xla/backends/cpu/runtime:parallel_loop_runner", | ||
"@onednn//:mkl_dnn", | ||
], | ||
) | ||
|
||
onednn_graph_cc_test( | ||
name = "onednn_threadpool_test", | ||
srcs = ["onednn_threadpool_test.cc"], | ||
deps = [ | ||
":onednn_interop", | ||
":onednn_threadpool", | ||
"//xla/backends/cpu/runtime:parallel_loop_runner", | ||
"//xla/tsl/concurrency:async_value", | ||
"//xla/tsl/lib/core:status_test_util", | ||
"//xla/tsl/platform:env", | ||
"//xla/tsl/platform:statusor", | ||
"//xla/tsl/platform:test", | ||
"//xla/tsl/platform:test_benchmark", | ||
"//xla/tsl/platform:test_main", | ||
"@com_google_absl//absl/algorithm:container", | ||
"@com_google_absl//absl/status", | ||
"@com_google_absl//absl/status:statusor", | ||
"@com_google_absl//absl/synchronization", | ||
"@eigen_archive//:eigen3", | ||
"@onednn//:mkl_dnn", | ||
"@pthreadpool", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
"""Starlark macros for oneDNN Graph API. | ||
Contains build rules that builds with empty srcs, hdrs, and deps if not build with Graph API. | ||
These rules have to be outside of mkl/build_defs.bzl, otherwise we would have cyclic dependency | ||
(xla.bzl depends on tsl which depends on mkl/build_defs.bzl). | ||
""" | ||
|
||
load("//xla:xla.bzl", "xla_cc_test") | ||
load("//xla/tsl/mkl:build_defs.bzl", "if_graph_api") | ||
|
||
def onednn_graph_cc_library(srcs = [], hdrs = [], deps = [], **kwargs): | ||
"""cc_library rule that has empty src, hdrs and deps if not building with Graph API.""" | ||
native.cc_library( | ||
srcs = if_graph_api(srcs), | ||
hdrs = if_graph_api(hdrs), | ||
deps = if_graph_api(deps), | ||
**kwargs | ||
) | ||
|
||
def onednn_graph_cc_test(srcs = [], deps = [], **kwargs): | ||
"""xla_cc_test rule that has empty src and deps if not building with Graph API.""" | ||
xla_cc_test( | ||
srcs = if_graph_api(srcs), | ||
deps = if_graph_api(deps) + ["@com_google_googletest//:gtest_main"], | ||
**kwargs | ||
) |