Skip to content

Commit

Permalink
[xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726185954
  • Loading branch information
danielsuo authored and Google-ML-Automation committed Feb 12, 2025
1 parent ae9d50a commit 8332b0c
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 2 deletions.
87 changes: 85 additions & 2 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ cc_library(
"py_device.cc",
"py_device_list.cc",
"py_executable.cc",
"py_host_callback.cc",
"py_memory_space.cc",
"py_program.cc",
"py_values.cc",
Expand All @@ -310,7 +309,6 @@ cc_library(
"py_device.h",
"py_device_list.h",
"py_executable.h",
"py_host_callback.h",
"py_memory_space.h",
"py_program.h",
"py_values.h",
Expand All @@ -333,6 +331,8 @@ cc_library(
":nb_helpers",
":nb_numpy",
":pprof_profile_builder",
":py_client_cpu",
":py_host_callback",
":py_host_callback_proto_cc",
":python_ref_manager",
":traceback",
Expand Down Expand Up @@ -427,6 +427,48 @@ cc_library(
] + if_google(["@com_google_protobuf//:any_cc_proto"]),
)

# Standalone library for py_host_callback.{cc,h}, so other targets in this
# package (e.g. :py_client_cpu) can depend on it directly.
cc_library(
name = "py_host_callback",
srcs = ["py_host_callback.cc"],
hdrs = ["py_host_callback.h"],
compatible_with = [],
# Built with C++ exceptions enabled and strict-aliasing optimizations off.
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
# Opts this target out of the "use_header_modules" toolchain feature.
features = ["-use_header_modules"],
deps = [
":callback",
":py_host_callback_proto_cc",
":python_ref_manager",
":types",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/pjrt:host_callback",
"//xla/pjrt:pjrt_compiler",
"//xla/python/ifrt",
"//xla/python/pjrt_ifrt",
"//xla/python/pjrt_ifrt:xla_host_callback_proto_cc",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
# The protobuf Any dependency is only added when if_google() selects it.
] + if_google([
"@com_google_protobuf//:any_cc_proto",
]),
)

cc_library(
name = "callback",
srcs = [
Expand All @@ -446,6 +488,7 @@ cc_library(
":python_ref_manager",
"//xla:comparison_util",
"//xla:xla_data_proto_cc",
"//xla/ffi",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
"//xla/service:custom_call_status",
Expand All @@ -462,6 +505,46 @@ cc_library(
],
)

# Library for py_client_cpu.{cc,h}: the XLA FFI handler that runs Python host
# callbacks on CPU (registered as "xla_ffi_python_cpu_callback").
#
# NOTE(review): py_client_cpu.cc registers the handler through
# XLA_FFI_REGISTER_HANDLER at namespace scope. Confirm whether this target
# needs `alwayslink = 1` so the registration is kept when nothing references
# the handler symbol directly.
cc_library(
name = "py_client_cpu",
srcs = ["py_client_cpu.cc"],
hdrs = ["py_client_cpu.h"],
compatible_with = [],
# Built with C++ exceptions enabled and strict-aliasing optimizations off.
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
# Opts this target out of the "use_header_modules" toolchain feature.
features = ["-use_header_modules"],
deps = [
":callback",
":nb_numpy",
":py_host_callback",
":types",
"//xla:comparison_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
"//xla/python/ifrt",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:platform_util",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
"@tsl//tsl/platform:errors",
],
)

cc_library(
name = "py_client_gpu",
srcs = if_google(
Expand Down
13 changes: 13 additions & 0 deletions xla/python/callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "xla/ffi/ffi.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/transpose.h"
#include "xla/primitive_util.h"
Expand Down Expand Up @@ -181,4 +182,16 @@ void XlaPythonCpuCallback(void* output, void** inputs,
}
}

// Invokes the wrapped Python callable with the elements of `args` unpacked as
// positional arguments and returns the callable's result as a tuple.
//
// Failures are reported through the returned status rather than by throwing:
//   * a Python exception raised by the callable becomes an InternalError
//     carrying the exception text;
//   * a result that is not a tuple becomes an InternalError naming the actual
//     Python type. (Previously `nb::cast<nb::tuple>` threw a cast exception in
//     this case, which is not an `nb::python_error` and therefore escaped the
//     catch block.)
//
// NOTE(review): this touches Python objects, so the caller is expected to hold
// the GIL; the FFI handler in py_client_cpu.cc does.
absl::StatusOr<nb::tuple> CpuCallback::FfiCall(nb::tuple args) {
  nb::object result_object;
  try {
    result_object = callable_(*nb::borrow<nb::args>(args));
  } catch (nb::python_error& e) {
    return absl::InternalError(
        absl::StrFormat("CpuCallback error calling callback: %s", e.what()));
  }
  // Validate the result type explicitly instead of relying on a throwing cast.
  if (!PyTuple_Check(result_object.ptr())) {
    return absl::InternalError(absl::StrFormat(
        "CpuCallback expected the callback to return a tuple, got %s",
        Py_TYPE(result_object.ptr())->tp_name));
  }
  return nb::borrow<nb::tuple>(result_object);
}

} // namespace xla
2 changes: 2 additions & 0 deletions xla/python/callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class CpuCallback {

absl::StatusOr<nanobind::tuple> Call(nanobind::tuple args);

// Calls the Python callable with the elements of `args` as positional
// arguments and returns its result as a tuple. A Python exception raised by
// the callable is returned as an InternalError status. Used by the XLA FFI
// CPU callback handler.
absl::StatusOr<nanobind::tuple> FfiCall(nanobind::tuple args);

private:
nanobind::callable callable_;
std::vector<Arg> args_;
Expand Down
136 changes: 136 additions & 0 deletions xla/python/py_client_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/py_client_cpu.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/transpose.h"
#include "xla/primitive_util.h"
#include "xla/python/callback.h"
#include "xla/python/ifrt/host_callback.h"
#include "xla/python/nb_numpy.h"
#include "xla/python/py_host_callback.h"
#include "xla/python/types.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"

namespace nb = nanobind;

namespace xla {

// XLA FFI handler that runs a Python host callback on CPU.
//
// Steps:
//   1. Looks up entry `index` of `callbacks`, which must be a
//      PyCpuLoadedHostCallback.
//   2. With the GIL held, wraps every input buffer in a numpy array marked
//      non-writeable (TOKEN inputs are passed as `None`).
//   3. Calls the Python callable via CpuCallback::FfiCall, bracketed by
//      EnterHostCallback()/LeaveHostCallback().
//   4. Copies each returned array into the matching result buffer. Arrays
//      whose byte strides match the default layout are copied with memcpy;
//      others go through a cached TransposePlan. TOKEN results are skipped.
//
// Returns a non-OK status (instead of throwing or reading invalid memory) when
// the index is out of range, the callback has the wrong type, an argument or
// result cannot be decoded as a buffer, the Python call fails, the result
// tuple is too short, a result is not convertible to a numpy array, or a
// result's shape differs from the expected buffer dimensions.
absl::Status XlaFfiPythonCpuCallback(
    std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks,
    uint64_t index, ffi::RemainingArgs args, ffi::RemainingRets rets) {
  if (callbacks == nullptr) {
    return absl::InternalError(
        "No host callbacks were provided to the CPU callback handler.");
  }
  // Bounds-check explicitly: vector::at() would throw std::out_of_range, but
  // this handler reports failures through absl::Status.
  if (index >= callbacks->size()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Host callback index ", index, " is out of range; ",
                     callbacks->size(), " callback(s) are available."));
  }
  auto* loaded_callback = llvm::dyn_cast_or_null<PyCpuLoadedHostCallback>(
      (*callbacks)[index].get());
  if (loaded_callback == nullptr) {
    return absl::InternalError(
        "Expected a PyCpuLoadedHostCallback, got something else.");
  }
  CpuCallback* callback = loaded_callback->cpu_callback();

  // Everything below manipulates Python objects, so hold the GIL until return.
  nb::gil_scoped_acquire gil;
  auto nb_args = nb::steal<nb::tuple>(PyTuple_New(args.size()));
  for (size_t i = 0; i < args.size(); ++i) {
    // Propagate decode failures instead of dereferencing an error StatusOr.
    TF_ASSIGN_OR_RETURN(auto arg, args.get<ffi::AnyBuffer>(i));
    auto ptype = arg.element_type();
    if (ptype == TOKEN) {
      PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr());
    } else {
      TF_ASSIGN_OR_RETURN(auto dtype, PrimitiveTypeToNbDtype(ptype));
      // We pass in data using default numpy layout i.e., std::nullopt.
      auto array = nb_numpy_ndarray(dtype, arg.dimensions(), std::nullopt,
                                    arg.untyped_data());
      array.attr("flags").attr("writeable") = nb::bool_(false);
      PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
    }
  }

  EnterHostCallback();
  // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows
  // you to avoid constructing a tuple for the arguments.
  absl::StatusOr<nb::tuple> maybe_result_tuple =
      callback->FfiCall(std::move(nb_args));
  LeaveHostCallback();
  TF_ASSIGN_OR_RETURN(auto result_tuple, std::move(maybe_result_tuple));

  const size_t num_results =
      static_cast<size_t>(PyTuple_GET_SIZE(result_tuple.ptr()));
  for (size_t i = 0; i < rets.size(); ++i) {
    TF_ASSIGN_OR_RETURN(auto ret, rets.get<ffi::AnyBuffer>(i));
    auto ptype = ret->element_type();
    if (ptype == TOKEN) continue;
    // PyTuple_GetItem returns NULL (with IndexError set) past the end of the
    // tuple, so reject short results up front.
    if (i >= num_results) {
      return absl::InternalError(absl::StrCat(
          "CPU callback returned ", num_results,
          " value(s), but result ", i, " is required."));
    }
    nb::object output =
        nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
    nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output));
    // NOTE(review): ensure() is assumed to yield an invalid (null) object when
    // conversion fails; guard before touching strides()/shape()/data().
    if (!array.is_valid()) {
      return absl::InternalError(absl::StrCat(
          "CPU callback result ", i, " could not be converted to an ndarray."));
    }
    absl::Span<int64_t const> strides(
        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    // The copy below transfers ret->size_bytes() bytes, so the returned array
    // must have exactly the expected dimensions.
    // NOTE(review): the array's dtype is not verified here; consider checking
    // it against `ptype` as well.
    if (dims != ret->dimensions()) {
      return absl::InternalError(absl::StrCat(
          "Mismatched shape for CPU callback result ", i,
          ": the returned array does not have the expected dimensions."));
    }
    // We expect the output to be in default numpy layout.
    TF_ASSIGN_OR_RETURN(auto expected_shape, ShapeUtil::MakeValidatedShape(
                                                 ptype, ret->dimensions()));
    auto expected_strides = ByteStridesForShape(expected_shape);
    if (strides == expected_strides) {
      std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes());
    } else {
      // Strides differ from the default layout: copy through a transpose plan.
      xla::TransposePlan::Options options;
      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
      options.dims = dims;
      absl::InlinedVector<int64_t, 4> reversed_layout;
      reversed_layout.resize(expected_shape.dimensions_size());
      absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
                           reversed_layout.begin());
      options.permutation = reversed_layout;
      options.input_layout = xla::TransposePlan::Striding{strides};
      TF_ASSIGN_OR_RETURN(auto plan,
                          callback->transpose_cache().GetOrCreate(options));
      plan->Execute(array.data(), ret->untyped_data());
    }
  }

  return absl::OkStatus();
}

// Defines the FFI handler symbol kXlaFfiPythonCpuCallback, implemented by
// XlaFfiPythonCpuCallback. The binding supplies the handler's parameters in
// order: the user-data vector of loaded host callbacks from the execution
// context, the "index" attribute, then all remaining arguments and all
// remaining results.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback,
ffi::Ffi::Bind()
.Ctx<ffi::UserData<
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>>>()
.Attr<uint64_t>("index")
.RemainingArgs()
.RemainingRets());

// Registers the handler with the XLA FFI API under the target name
// "xla_ffi_python_cpu_callback" for the "HOST" platform.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback",
"HOST", kXlaFfiPythonCpuCallback);

} // namespace xla
27 changes: 27 additions & 0 deletions xla/python/py_client_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_PY_CLIENT_CPU_H_
#define XLA_PYTHON_PY_CLIENT_CPU_H_

#include "xla/ffi/ffi.h"

namespace xla {

// Declares the XLA FFI handler symbol for the CPU Python host callback. The
// handler is defined and registered in py_client_cpu.cc.
XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback);

} // namespace xla

#endif // XLA_PYTHON_PY_CLIENT_CPU_H_
2 changes: 2 additions & 0 deletions xla/python/py_host_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class PyCpuLoadedHostCallback final
return absl::bit_cast<uint64_t>(cpu_callback_.get());
}

// Returns the raw CpuCallback pointer held by cpu_callback_.
CpuCallback* cpu_callback() { return cpu_callback_.get(); }

// LoadedHostCallback implementation.

~PyCpuLoadedHostCallback() override = default;
Expand Down

0 comments on commit 8332b0c

Please sign in to comment.