[xla:cpu] Add OneDnnThreadPool based on parallel loop runner
PiperOrigin-RevId: 724703087
1 parent a3ede7c · commit 8bac4a2
Showing 6 changed files with 334 additions and 0 deletions.
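The new OneDnnThreadPool (see onednn_threadpool.h below) adapts XLA's ParallelLoopRunner to oneDNN's dnnl::threadpool_interop::threadpool_iface, so a oneDNN stream can schedule its parallel loops onto the same Eigen thread pool device that runs XLA:CPU kernels. A minimal sketch of the intended wiring, assuming an Eigen::ThreadPoolDevice named `device` already exists (the test in this commit builds one from tsl::thread::ThreadPool):

// Sketch only: hook the new adapter into a oneDNN CPU stream.
xla::cpu::ParallelLoopRunner runner(&device);    // `device` is assumed to exist
xla::cpu::OneDnnThreadPool threadpool(&runner);  // adapter added by this commit

dnnl::engine engine(dnnl::engine::kind::cpu, 0);
dnnl::stream stream =
    dnnl::stream(dnnl::threadpool_interop::make_stream(engine, &threadpool));
// Compiled oneDNN graph partitions executed on `stream` now run their parallel
// loops through `runner`.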
64 changes: 64 additions & 0 deletions
xla/backends/cpu/runtime/onednn/BUILD
@@ -0,0 +1,64 @@
# copybara:uncomment_begin(google-only)
# load("//xla:xla.bzl", "xla_cc_test")
# load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
#
# package(
#     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
#     default_visibility = [":friends"],
#     licenses = ["notice"],
# )
#
# package_group(
#     name = "friends",
#     includes = [
#         "//xla:friends",
#     ],
# )
#
# cc_library(
#     name = "onednn_interop",
#     hdrs = ["onednn_interop.h"],
#     deps = [
#         "@com_google_absl//absl/base:core_headers",
#         "@com_google_absl//absl/status",
#         "@onednn//:mkl_dnn",
#         "//xla:util",
#         "//xla/tsl/platform:logging",
#     ],
# )
#
# cc_library(
#     name = "onednn_threadpool",
#     hdrs = ["onednn_threadpool.h"],
#     deps = [
#         ":onednn_interop",
#         "@onednn//:mkl_dnn",
#         "//xla/backends/cpu/runtime:parallel_loop_runner",
#     ],
# )
#
# xla_cc_test(
#     name = "onednn_threadpool_test",
#     srcs = ["onednn_threadpool_test.cc"],
#     deps = [
#         ":onednn_interop",
#         ":onednn_threadpool",
#         "@com_google_googletest//:gtest_main",
#         "@com_google_absl//absl/algorithm:container",
#         "@com_google_absl//absl/status",
#         "@com_google_absl//absl/status:statusor",
#         "@com_google_absl//absl/synchronization",
#         "@eigen_archive//:eigen3",
#         "@onednn//:mkl_dnn",
#         "@pthreadpool",
#         "//xla/backends/cpu/runtime:parallel_loop_runner",
#         "//xla/tsl/concurrency:async_value",
#         "//xla/tsl/lib/core:status_test_util",
#         "//xla/tsl/platform:env",
#         "//xla/tsl/platform:statusor",
#         "//xla/tsl/platform:test",
#         "//xla/tsl/platform:test_benchmark",
#         "//xla/tsl/platform:test_main",
#     ],
# )
# copybara:uncomment_end
84 changes: 84 additions & 0 deletions
xla/backends/cpu/runtime/onednn/onednn_interop.h
@@ -0,0 +1,84 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_INTEROP_H_
#define XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_INTEROP_H_

#include "oneapi/dnnl/dnnl_graph.hpp"
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "xla/tsl/platform/logging.h"
#include "xla/util.h"

namespace xla::cpu {

#define ONEDNN_RETURN_IF_ERROR(expr)              \
  do {                                            \
    absl::Status s = OneDnnStatusToStatus(expr);  \
    if (!s.ok()) {                                \
      return s;                                   \
    }                                             \
  } while (0)

#define ONEDNN_LOG_IF_ERROR(expr)                     \
  do {                                                \
    absl::Status s = OneDnnStatusToStatus(expr);      \
    if (!s.ok()) {                                    \
      LOG(ERROR) << "DNNL operation failed: " << s;   \
    }                                                 \
  } while (0)

// Converts oneDNN status to absl::Status.
inline absl::Status OneDnnStatusToStatus(dnnl::graph::status status) {
  if (ABSL_PREDICT_TRUE(status == dnnl::graph::status::success)) {
    return absl::OkStatus();
  }

  auto error_message = [](dnnl::graph::status status) {
    switch (status) {
      case dnnl::graph::status::success:
        return "";
      case dnnl::graph::status::out_of_memory:
        return "out of memory";
      case dnnl::graph::status::invalid_arguments:
        return "invalid arguments";
      case dnnl::graph::status::unimplemented:
        return "unimplemented";
      case dnnl::graph::status::last_impl_reached:
        return "last implementation reached";
      case dnnl::graph::status::runtime_error:
        return "runtime error";
      case dnnl::graph::status::not_required:
        return "not required";
      case dnnl::graph::status::invalid_graph:
        return "invalid graph";
      case dnnl::graph::status::invalid_graph_op:
        return "invalid graph op";
      case dnnl::graph::status::invalid_shape:
        return "invalid shape";
      case dnnl::graph::status::invalid_data_type:
        return "invalid data type";
    }
  };

  return Internal("DNNL operation failed: %s", error_message(status));
}

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_INTEROP_H_
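The ONEDNN_RETURN_IF_ERROR and ONEDNN_LOG_IF_ERROR macros above wrap oneDNN graph API calls that return dnnl::graph::status. A short sketch of the intended usage, mirroring the CreateExpGraph helper in the test added by this commit (the Exp op and the helper's name are illustrative):

// Sketch: propagate a oneDNN failure out of a builder function as absl::Status.
absl::StatusOr<dnnl::graph::graph> BuildExpGraph(
    const dnnl::graph::logical_tensor& src,
    const dnnl::graph::logical_tensor& dst) {
  dnnl::graph::op exp_op(0, dnnl::graph::op::kind::Exp, {src}, {dst});

  dnnl::graph::graph g(dnnl::engine::kind::cpu);
  ONEDNN_RETURN_IF_ERROR(g.add_op(exp_op));  // returns early on a non-success status
  g.finalize();
  return g;
}

// Where no status can be returned, ONEDNN_LOG_IF_ERROR(expr) logs the failure instead.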
60 changes: 60 additions & 0 deletions
xla/backends/cpu/runtime/onednn/onednn_threadpool.h
@@ -0,0 +1,60 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_THREADPOOL_H_
#define XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_THREADPOOL_H_

#include <cstddef>
#include <cstdint>
#include <functional>

#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
#include "xla/backends/cpu/runtime/parallel_loop_runner.h"

namespace xla::cpu {

class OneDnnThreadPool final
    : public dnnl::threadpool_interop::threadpool_iface {
 public:
  explicit OneDnnThreadPool(ParallelLoopRunner* runner) : runner_(runner) {}

  int get_num_threads() const final;
  bool get_in_parallel() const final;
  uint64_t get_flags() const final;

  void parallel_for(int n, const std::function<void(int, int)>& fn) final;

 private:
  ParallelLoopRunner* runner_;
};

inline int OneDnnThreadPool::get_num_threads() const {
  return runner_->num_threads();
}

inline bool OneDnnThreadPool::get_in_parallel() const {
  return runner_->is_in_runner();
}

inline uint64_t OneDnnThreadPool::get_flags() const { return 0; }

inline void OneDnnThreadPool::parallel_for(
    int n, const std::function<void(int, int)>& fn) {
  runner_->Parallelize(n, [fn, n](size_t task_index) { fn(task_index, n); });
}

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_THREADPOOL_H_
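Note that ParallelLoopRunner schedules loops asynchronously, so a oneDNN call executed through OneDnnThreadPool may return before its parallel work has finished. The test below waits on the runner's completion event before checking results; callers of the adapter are expected to do the same, roughly:

// Sketch: after executing oneDNN work on a stream backed by OneDnnThreadPool,
// block until all loops scheduled into the runner have completed.
tsl::BlockUntilReady(runner.done_event());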
118 changes: 118 additions & 0 deletions
xla/backends/cpu/runtime/onednn/onednn_threadpool_test.cc
@@ -0,0 +1,118 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/onednn/onednn_threadpool.h"

#include <cmath>
#include <cstdint>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_common.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"
#include "oneapi/dnnl/dnnl_threadpool.hpp"
#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/onednn/onednn_interop.h"
#include "xla/backends/cpu/runtime/parallel_loop_runner.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {

// Creates a graph with a single Exp operation.
static absl::StatusOr<dnnl::graph::graph> CreateExpGraph(
    const dnnl::graph::logical_tensor& src_tensor,
    const dnnl::graph::logical_tensor& dst_tensor) {
  dnnl::graph::op exp_op(0, dnnl::graph::op::kind::Exp, {src_tensor},
                         {dst_tensor});

  dnnl::graph::graph g(dnnl::engine::kind::cpu);
  ONEDNN_RETURN_IF_ERROR(g.add_op(exp_op));
  g.finalize();

  return g;
}

TEST(OneDnnThreadPoolTest, Exp) {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 32);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());

  ParallelLoopRunner runner(&device);
  OneDnnThreadPool threadpool(&runner);

  int64_t d0 = 100;
  int64_t d1 = 1000;
  int64_t num_elements = d0 * d1;

  // We use row-major layout for both source and destination tensors.
  dnnl::graph::logical_tensor::dims src_dims = {d0, d1};
  dnnl::graph::logical_tensor::dims dst_dims = {d0, d1};

  dnnl::graph::logical_tensor::dims src_strides = {d1, 1};
  dnnl::graph::logical_tensor::dims dst_strides = {d1, 1};

  dnnl::graph::logical_tensor src_tensor(
      0, dnnl::graph::logical_tensor::data_type::f32, src_dims, src_strides);
  dnnl::graph::logical_tensor dst_tensor(
      1, dnnl::graph::logical_tensor::data_type::f32, dst_dims, dst_strides);

  // Compile oneDNN graph with a single Exp operation.
  TF_ASSERT_OK_AND_ASSIGN(dnnl::graph::graph g,
                          CreateExpGraph(src_tensor, dst_tensor));
  std::vector<dnnl::graph::partition> partitions = g.get_partitions();

  // Create oneDNN engine for running the graph on CPU.
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);

  // Create oneDNN stream backed by parallel loop runner.
  dnnl::stream stream =
      dnnl::stream(dnnl::threadpool_interop::make_stream(engine, &threadpool));

  // Compile graph partitions for given engine.
  std::vector<dnnl::graph::compiled_partition> compiled_partitions;
  for (const auto& partition : partitions) {
    compiled_partitions.push_back(
        partition.compile({src_tensor}, {dst_tensor}, engine));
  }

  // Create tensors for source and destination.
  std::vector<float> src_data(num_elements, 1.0f);
  std::vector<float> dst_data(num_elements, 0.0f);

  dnnl::graph::tensor src(src_tensor, engine, src_data.data());
  dnnl::graph::tensor dst(dst_tensor, engine, dst_data.data());

  // Execute compiled oneDNN graph on the CPU stream.
  compiled_partitions[0].execute(stream, {src}, {dst});

  // Wait for the completion of parallel loops scheduled into the runner.
  tsl::BlockUntilReady(runner.done_event());

  for (int i = 0; i < num_elements; ++i) {
    EXPECT_NEAR(dst_data[i], std::exp(1.0f), 1e-5);
  }
}

}  // namespace
}  // namespace xla::cpu