Skip to content

Commit

Permalink
[xla:cpu] Implement dynamic versions of parallel loops
Browse files Browse the repository at this point in the history
Implement new parallel loop APIs to be compatible with latest XNNPACK.

PiperOrigin-RevId: 724408576
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Feb 10, 2025
1 parent b39e89e commit d3e9664
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 13 deletions.
12 changes: 6 additions & 6 deletions tsl_workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def _tf_repositories():
# LINT.IfChange
# XNNPACK pinned to a commit that provides the dynamic parallel loop APIs.
# NOTE: the scraped span contained both the pre- and post-update kwargs
# (duplicate `sha256`/`strip_prefix`/`urls`, invalid Starlark); keep the
# updated values only.
tf_http_archive(
    name = "XNNPACK",
    sha256 = "face04056299ca22e2dbbf122a27aef289443dc7b1ad7e33a52714d6acc084eb",
    strip_prefix = "XNNPACK-e55b3998cadb320188759aaada27328cbacc3253",
    urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/e55b3998cadb320188759aaada27328cbacc3253.zip"),
)
# LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)

Expand All @@ -126,9 +126,9 @@ def _tf_repositories():

# pthreadpool pinned to a commit whose headers declare the
# `pthreadpool_parallelize_*_dynamic` entry points implemented in
# xnn_threadpool.cc. Duplicate old/new kwargs from the rendered diff removed;
# only the updated pin is kept.
tf_http_archive(
    name = "pthreadpool",
    sha256 = "cb668c32d6e05099492cc7ea19168e2dad0d1dcc4cbaa0e34fd4b38d39f0e03e",
    strip_prefix = "pthreadpool-f94ab76fe99754960035d520dce28e15b647e8cf",
    urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/f94ab76fe99754960035d520dce28e15b647e8cf.zip"),
)

tf_http_archive(
Expand Down
12 changes: 6 additions & 6 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def _tf_repositories():
# LINT.IfChange
# XNNPACK pinned to a commit that provides the dynamic parallel loop APIs.
# The rendered diff carried duplicate old/new `sha256`/`strip_prefix`/`urls`
# kwargs (invalid Starlark); only the updated values are kept.
tf_http_archive(
    name = "XNNPACK",
    sha256 = "face04056299ca22e2dbbf122a27aef289443dc7b1ad7e33a52714d6acc084eb",
    strip_prefix = "XNNPACK-e55b3998cadb320188759aaada27328cbacc3253",
    urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/e55b3998cadb320188759aaada27328cbacc3253.zip"),
)
# LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)

Expand All @@ -74,9 +74,9 @@ def _tf_repositories():

# pthreadpool pinned to the same commit as tsl_workspace2.bzl so both
# workspaces agree on the dynamic-loop-capable version. Duplicate old/new
# kwargs from the rendered diff removed.
tf_http_archive(
    name = "pthreadpool",
    sha256 = "cb668c32d6e05099492cc7ea19168e2dad0d1dcc4cbaa0e34fd4b38d39f0e03e",
    strip_prefix = "pthreadpool-f94ab76fe99754960035d520dce28e15b647e8cf",
    urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/f94ab76fe99754960035d520dce28e15b647e8cf.zip"),
)

tf_http_archive(
Expand Down
18 changes: 18 additions & 0 deletions xla/backends/cpu/runtime/parallel_loop_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
ScheduleAll(num_tasks, ParallelTask1DTile1D{range, tile, std::move(task)});
}

// Dynamic variant of the tiled 1D parallel loop. Per the header contract the
// runner may adjust the tile `count` for dynamic tasks; today we do not take
// advantage of that and delegate to the static-count implementation,
// launching the same number of tasks.
void ParallelLoopRunner::ParallelizeDynamic(size_t range, size_t tile,
                                            Task1DTile1DDynamic task) {
  // Task1DTile1DDynamic has the same call signature as Task1DTile1D
  // (offset, count/extent), so the task is forwarded unchanged.
  Parallelize(range, tile, std::move(task));
}

struct ParallelLoopRunner::ParallelTask2DTile1D {
ABSL_ATTRIBUTE_ALWAYS_INLINE void operator()(size_t task_index) const {
auto x = Delinearize(task_index, range_i, range_j, tile_j);
Expand Down Expand Up @@ -312,6 +317,12 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
ParallelTask2DTile1D{range_i, range_j, tile_j, std::move(task)});
}

// Dynamic variant of the 2D parallel loop with a 1D-tiled inner dimension.
// The runner is free to adjust `count_j` for dynamic tasks (see header note),
// but currently we delegate to the static-count implementation unchanged.
void ParallelLoopRunner::ParallelizeDynamic(size_t range_i, size_t range_j,
                                            size_t tile_j,
                                            Task2DTile1DDynamic task) {
  // Task2DTile1DDynamic and Task2DTile1D share the same call signature, so
  // forwarding the task is a no-op conversion.
  Parallelize(range_i, range_j, tile_j, std::move(task));
}

struct ParallelLoopRunner::ParallelTask3DTile2D {
ABSL_ATTRIBUTE_ALWAYS_INLINE void operator()(size_t task_index) const {
auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
Expand Down Expand Up @@ -351,4 +362,11 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
tile_k, std::move(task)});
}

// Dynamic variant of the 3D parallel loop with 2D tiling over j/k. The runner
// may adjust `count_j`/`count_k` for dynamic tasks (see header note); today we
// simply delegate to the static-count implementation, so the task count is
// identical to the regular parallel loop.
void ParallelLoopRunner::ParallelizeDynamic(size_t range_i, size_t range_j,
                                            size_t range_k, size_t tile_j,
                                            size_t tile_k,
                                            Task3DTile2DDynamic task) {
  // Task3DTile2DDynamic shares Task3DTile2D's call signature, so the task is
  // forwarded directly.
  Parallelize(range_i, range_j, range_k, tile_j, tile_k, std::move(task));
}

} // namespace xla::cpu
24 changes: 23 additions & 1 deletion xla/backends/cpu/runtime/parallel_loop_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License.
#include <atomic>
#include <cstddef>
#include <functional>
#include <optional>

#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"
Expand Down Expand Up @@ -69,13 +68,24 @@ class ParallelLoopRunner {
using Task1D = std::function<void(size_t offset)>;

using Task1DTile1D = std::function<void(size_t offset, size_t extent)>;
using Task1DTile1DDynamic = std::function<void(size_t offset, size_t count)>;

using Task2DTile1D =
std::function<void(size_t offset_i, size_t offset_j, size_t extent_j)>;
using Task2DTile1DDynamic =
std::function<void(size_t offset_i, size_t offset_j, size_t count_j)>;

using Task3DTile2D =
std::function<void(size_t offset_i, size_t offset_j, size_t offset_k,
size_t extent_j, size_t extent_k)>;
using Task3DTile2DDynamic =
std::function<void(size_t offset_i, size_t offset_j, size_t offset_k,
size_t count_j, size_t count_k)>;

// IMPORTANT: For `dynamic` versions of the parallel loops, the runner is free
// to adjust `count` for tiled dimensions to minimize the number of launched
// tasks. Today we don't take advantage of this feature, and always launch the
// same number of tasks as in regular parallel loops.

// This function implements a parallel version of the following loop:
//
Expand All @@ -89,6 +99,9 @@ class ParallelLoopRunner {
// task(i, std::min(range - i, tile));
void Parallelize(size_t range, size_t tile, Task1DTile1D task);

// Implements a parallel version of 1D loop with dynamic task count.
void ParallelizeDynamic(size_t range, size_t tile, Task1DTile1DDynamic task);

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range_i; i++)
Expand All @@ -97,6 +110,10 @@ class ParallelLoopRunner {
void Parallelize(size_t range_i, size_t range_j, size_t tile_j,
Task2DTile1D task);

// Implements a parallel version of 2D loop with dynamic task count.
void ParallelizeDynamic(size_t range_i, size_t range_j, size_t tile_j,
Task2DTile1DDynamic task);

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range_i; i++)
Expand All @@ -106,6 +123,11 @@ class ParallelLoopRunner {
void Parallelize(size_t range_i, size_t range_j, size_t range_k,
size_t tile_j, size_t tile_k, Task3DTile2D task);

// Implements a parallel version of 3D loop with dynamic task count.
void ParallelizeDynamic(size_t range_i, size_t range_j, size_t range_k,
size_t tile_j, size_t tile_k,
Task3DTile2DDynamic task);

// Resets the parallel loop runner `done_event` and returns the previous one
// to the caller.
tsl::AsyncValueRef<tsl::Chain> ResetDoneEvent();
Expand Down
87 changes: 87 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,21 @@ static void Parallelize1DTile1D( // NOLINT
Cast(threadpool)->runner()->Parallelize(range, tile, std::move(task));
}

// Custom-pthreadpool backend for `pthreadpool_parallelize_1d_tile_1d_dynamic`:
// bridges XNNPACK's dynamic tiled 1D loop onto the XLA ParallelLoopRunner.
//
// `function` is invoked as function(context, offset, count).
static void Parallelize1DTile1DDynamic(  // NOLINT
    pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_dynamic_t function,
    void* context, size_t range, size_t tile, uint32_t flags) {
  // No threadpool attached: dynamic tasks accept an arbitrary count, so one
  // sequential invocation covering the whole range suffices.
  if (ABSL_PREDICT_FALSE(threadpool == nullptr)) {
    function(context, 0, range);
    return;
  }

  ParallelLoopRunner::Task1DTile1DDynamic task =
      [function, context](size_t offset, size_t count) {
        // Call through the function pointer directly, matching the fallback
        // path above (previously written as `(*function)(...)`).
        function(context, offset, count);
      };
  Cast(threadpool)->runner()->ParallelizeDynamic(range, tile, std::move(task));
}

static void Parallelize2DTile1D(pthreadpool_t threadpool, // NOLINT
pthreadpool_task_2d_tile_1d_t function,
void* context, size_t range_i, size_t range_j,
Expand All @@ -215,6 +230,26 @@ static void Parallelize2DTile1D(pthreadpool_t threadpool, // NOLINT
->Parallelize(range_i, range_j, tile_j, std::move(task));
}

// Custom-pthreadpool backend for `pthreadpool_parallelize_2d_tile_1d_dynamic`:
// a 2D loop whose inner dimension is tiled with a runner-adjustable count.
//
// `function` is invoked as function(context, i, offset_j, count_j).
static void Parallelize2DTile1DDynamic(  // NOLINT
    pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_dynamic_t function,
    void* context, size_t range_i, size_t range_j, size_t tile_j,
    uint32_t flags) {
  // No threadpool attached: run sequentially, one call per outer index `i`
  // that covers the entire inner range in a single dynamic tile.
  if (ABSL_PREDICT_FALSE(threadpool == nullptr)) {
    for (size_t i = 0; i < range_i; i++) {
      function(context, i, 0, range_j);
    }
    return;
  }

  // Name the tile size `count_j` (not `extent_j`) to match the
  // Task2DTile1DDynamic alias and the 3D dynamic adapter below.
  ParallelLoopRunner::Task2DTile1DDynamic task =
      [function, context](size_t offset_i, size_t offset_j, size_t count_j) {
        function(context, offset_i, offset_j, count_j);
      };
  Cast(threadpool)
      ->runner()
      ->ParallelizeDynamic(range_i, range_j, tile_j, std::move(task));
}

static void Parallelize3DTile2D(pthreadpool_t threadpool, // NOLINT
pthreadpool_task_3d_tile_2d_t function,
void* context, size_t range_i, size_t range_j,
Expand Down Expand Up @@ -242,6 +277,28 @@ static void Parallelize3DTile2D(pthreadpool_t threadpool, // NOLINT
->Parallelize(range_i, range_j, range_k, tile_j, tile_k, std::move(task));
}

// Adapter from XNNPACK's dynamic 3D loop (2D-tiled over j/k) onto the XLA
// ParallelLoopRunner owned by the custom threadpool.
//
// `function` is invoked as
// function(context, i, offset_j, offset_k, count_j, count_k).
static void Parallelize3DTile2DDynamic(  // NOLINT
    pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_dynamic_t function,
    void* context, size_t range_i, size_t range_j, size_t range_k,
    size_t tile_j, size_t tile_k, uint32_t flags) {
  // Sequential fallback when no threadpool is attached: one call per outer
  // index `i` spans the whole j/k space as a single dynamic tile.
  if (ABSL_PREDICT_FALSE(threadpool == nullptr)) {
    for (size_t i = 0; i < range_i; ++i) {
      (*function)(context, i, 0, 0, range_j, range_k);
    }
    return;
  }

  ParallelLoopRunner::Task3DTile2DDynamic task =
      [function, context](size_t offset_i, size_t offset_j, size_t offset_k,
                          size_t count_j, size_t count_k) {
        (*function)(context, offset_i, offset_j, offset_k, count_j, count_k);
      };
  Cast(threadpool)->runner()->ParallelizeDynamic(range_i, range_j, range_k,
                                                 tile_j, tile_k,
                                                 std::move(task));
}

} // namespace xla::cpu

#if defined(XLA_CPU_USE_CUSTOM_PTHREADPOOL)
Expand Down Expand Up @@ -285,6 +342,13 @@ extern "C" void pthreadpool_parallelize_1d_tile_1d(
flags);
}

// C ABI shim: forwards pthreadpool's dynamic tiled 1D loop to the XLA custom
// threadpool implementation (Parallelize1DTile1DDynamic).
extern "C" void pthreadpool_parallelize_1d_tile_1d_dynamic(
    pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_dynamic_t function,
    void* context, size_t range, size_t tile, uint32_t flags) {
  xla::cpu::Parallelize1DTile1DDynamic(threadpool, function, context, range,
                                       tile, flags);
}

extern "C" void pthreadpool_parallelize_2d(pthreadpool_t threadpool,
pthreadpool_task_2d_t function,
void* context, size_t range_i,
Expand All @@ -306,6 +370,14 @@ extern "C" void pthreadpool_parallelize_2d_tile_1d(
tile_j, flags);
}

// C ABI shim: forwards pthreadpool's dynamic 2D (1D-tiled) loop to the XLA
// custom threadpool implementation (Parallelize2DTile1DDynamic).
extern "C" void pthreadpool_parallelize_2d_tile_1d_dynamic(
    pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_dynamic_t function,
    void* context, size_t range_i, size_t range_j, size_t tile_j,
    uint32_t flags) {
  xla::cpu::Parallelize2DTile1DDynamic(threadpool, function, context, range_i,
                                       range_j, tile_j, flags);
}

extern "C" void pthreadpool_parallelize_2d_tile_1d_with_uarch(
pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_with_id_t function,
void* context, uint32_t default_uarch_index, uint32_t max_uarch_index,
Expand All @@ -328,6 +400,13 @@ extern "C" void pthreadpool_parallelize_2d_tile_2d(
LOG(FATAL) << "Not implemented";
}

// C ABI shim for the dynamic 2D loop with 2D tiling. Intentionally aborts:
// this variant has no ParallelLoopRunner backend in this file.
// NOTE(review): presumably the XNNPACK code paths XLA exercises never call
// this entry point — confirm before depending on it.
extern "C" void pthreadpool_parallelize_2d_tile_2d_dynamic(
    pthreadpool_t threadpool, pthreadpool_task_2d_tile_2d_dynamic_t function,
    void* context, size_t range_i, size_t range_j, size_t tile_i, size_t tile_j,
    uint32_t flags) {
  LOG(FATAL) << "Not implemented";
}

extern "C" void pthreadpool_parallelize_2d_tile_2d_with_uarch(
pthreadpool_t threadpool, pthreadpool_task_2d_tile_2d_with_id_t function,
void* context, uint32_t default_uarch_index, uint32_t max_uarch_index,
Expand Down Expand Up @@ -383,6 +462,14 @@ extern "C" void pthreadpool_parallelize_3d_tile_2d(
range_k, tile_j, tile_k, flags);
}

// C ABI shim: forwards pthreadpool's dynamic 3D (2D-tiled) loop to the XLA
// custom threadpool implementation (Parallelize3DTile2DDynamic).
extern "C" void pthreadpool_parallelize_3d_tile_2d_dynamic(
    pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_dynamic_t function,
    void* context, size_t range_i, size_t range_j, size_t range_k,
    size_t tile_j, size_t tile_k, uint32_t flags) {
  xla::cpu::Parallelize3DTile2DDynamic(threadpool, function, context, range_i,
                                       range_j, range_k, tile_j, tile_k, flags);
}

extern "C" void pthreadpool_parallelize_3d_tile_2d_with_uarch(
pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_with_id_t function,
void* context, uint32_t default_uarch_index, uint32_t max_uarch_index,
Expand Down

0 comments on commit d3e9664

Please sign in to comment.