[xla:gpu] Implement XLA FFI handlers for GPU Jax callbacks.
PiperOrigin-RevId: 726451933
danielsuo authored and Google-ML-Automation committed Feb 13, 2025
1 parent a4c2a4b commit 7c8a422
Showing 3 changed files with 151 additions and 1 deletion.
12 changes: 12 additions & 0 deletions xla/python/BUILD
@@ -567,19 +567,31 @@ cc_library(
deps = [
":callback",
":nb_numpy",
":py_host_callback",
":types",
"//xla:comparison_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
"//xla/python/ifrt",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:platform_util",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
] + if_rocm(
["@local_config_rocm//rocm:rocm_headers"],
137 changes: 136 additions & 1 deletion xla/python/py_client_gpu.cc
@@ -16,32 +16,46 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/service/custom_call_status.h"
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#else
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/driver_types.h"
#endif
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/transpose.h"
#include "xla/primitive_util.h"
#include "xla/python/callback.h"
#include "xla/python/ifrt/host_callback.h"
#include "xla/python/nb_numpy.h"
#include "xla/python/py_host_callback.h"
#include "xla/python/types.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"
#if TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
#define gpuStreamHandle hipStream_t
@@ -173,4 +187,125 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"xla_python_gpu_callback", &XlaPythonGpuCallback,
absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()));

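// FFI variant of the GPU Python callback: stages operand buffers from device
// to host, invokes the Python callback under the GIL, and copies (transposing
// if necessary) the results back into the device output buffers.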
absl::Status XlaFfiPythonGpuCallback(
gpuStreamHandle stream,
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks,
uint64_t index, ffi::RemainingArgs args, ffi::RemainingRets rets) {
auto loaded_callback = llvm::dyn_cast_or_null<PyCpuLoadedHostCallback>(
callbacks->at(index).get());
if (loaded_callback == nullptr) {
return absl::InternalError(
"Expected a PyCpuLoadedHostCallback, got something else.");
}
CpuCallback* callback = loaded_callback->cpu_callback();
size_t arity = args.size();
std::vector<void*> host_input_buffers(arity);
// Copy input GPU buffers to host
for (size_t i = 0; i < arity; ++i) {
auto arg = args.get<ffi::AnyBuffer>(i);
if (arg->element_type() == TOKEN) {
host_input_buffers[i] = nullptr;
continue;
}
void* buf = new char[arg->size_bytes()];
host_input_buffers[i] = buf;
// TODO(b/238441608): Use pinned memory here to speed up the transfer.
auto gpu_res =
gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(),
gpuMemcpyDeviceToHost, stream);
CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
}
CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
<< "Failed to gpuStreamSynchronize";
nb::gil_scoped_acquire gil;
nb::tuple host_input_arrays = nb::steal<nb::tuple>(PyTuple_New(arity));
for (size_t i = 0; i < arity; ++i) {
auto arg = args.get<ffi::AnyBuffer>(i);
PrimitiveType ptype = arg->element_type();
if (ptype == TOKEN) {
PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr());
} else {
nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
delete[] static_cast<char*>(ptr);
});
TF_ASSIGN_OR_RETURN(auto dtype, PrimitiveTypeToNbDtype(ptype));
auto array = nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt,
host_input_buffers[i], base);
array.attr("flags").attr("writeable") = nb::bool_(false);
PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr());
}
}

EnterHostCallback();
// TODO(dsuo): Change this to use the Python vectorcall protocol, which allows
// you to avoid constructing a tuple for the arguments.
absl::StatusOr<nb::tuple> maybe_result_tuple =
callback->FfiCall(host_input_arrays);
LeaveHostCallback();
TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple);

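// Copy the results back to the device output buffers; if a result's numpy
// strides differ from the expected dense layout, transpose through a
// temporary host buffer first.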
std::vector<void*> temp_buffers;
for (size_t i = 0; i < rets.size(); ++i) {
auto ret = rets.get<ffi::AnyBuffer>(i).value();
auto ptype = ret->element_type();
if (ptype == TOKEN) continue;
nb::object output =
nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output));
absl::Span<int64_t const> strides(
reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
// We expect the output to be in default numpy layout.
TF_ASSIGN_OR_RETURN(auto expected_shape, ShapeUtil::MakeValidatedShape(
ptype, ret->dimensions()));
auto expected_strides = ByteStridesForShape(expected_shape);
if (strides == expected_strides) {
auto gpu_res =
gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(),
gpuMemcpyHostToDevice, stream);
CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
} else {
void* temp = new char[ret->size_bytes()];
temp_buffers.push_back(temp);
xla::TransposePlan::Options options;
options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
absl::Span<int64_t const> dims(
reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
options.dims = dims;
absl::InlinedVector<int64_t, 4> reversed_layout;
reversed_layout.resize(expected_shape.dimensions_size());
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
reversed_layout.begin());
options.permutation = reversed_layout;
options.input_layout = xla::TransposePlan::Striding{strides};
TF_ASSIGN_OR_RETURN(auto plan,
callback->transpose_cache().GetOrCreate(options));
plan->Execute(array.data(), temp);
auto gpu_res =
gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(),
gpuMemcpyHostToDevice, stream);
CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
}
}
nb::gil_scoped_release release;
CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
<< "Failed to gpuStreamSynchronize";
for (int i = 0; i < temp_buffers.size(); ++i) {
delete[] static_cast<char*>(temp_buffers[i]);
}
return absl::OkStatus();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStreamHandle>>()
.Ctx<ffi::UserData<
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>>>()
.Attr<uint64_t>("index")
.RemainingArgs()
.RemainingRets());
XLA_FFI_REGISTER_HANDLER(
ffi::GetXlaFfiApi(), "xla_ffi_python_gpu_callback",
absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()),
kXlaFfiPythonGpuCallback);
} // namespace xla
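
For readers unfamiliar with the binding DSL used above, here is a minimal sketch of the same macros applied to a fixed-arity handler. It is illustrative only: the handler ExampleScale, the target name "example_scale", and the "Host" platform string are assumptions, not part of this change.

// Sketch only: a fixed-arity FFI handler whose Bind() mirrors its C++
// signature (one argument, one result, one attribute), registered with the
// same macros as the callback handler above.
#include <cstddef>

#include "absl/status/status.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace example {
namespace ffi = xla::ffi;

// Multiplies an f32 buffer by a scalar attribute. Registered for "Host", so
// the buffer pointers are host memory and can be dereferenced directly.
static absl::Status ExampleScale(ffi::AnyBuffer x,
                                 ffi::Result<ffi::AnyBuffer> y,
                                 float scale) {
  if (x.element_type() != xla::F32 || y->element_type() != xla::F32) {
    return absl::InvalidArgumentError("example_scale expects f32 buffers");
  }
  const float* in = reinterpret_cast<const float*>(x.untyped_data());
  float* out = reinterpret_cast<float*>(y->untyped_data());
  size_t n = x.size_bytes() / sizeof(float);
  for (size_t i = 0; i < n; ++i) out[i] = in[i] * scale;
  return absl::OkStatus();
}

// The binding order (argument, result, attribute) must match the parameter
// order of ExampleScale, just as the Ctx/Attr/RemainingArgs/RemainingRets
// binding above matches XlaFfiPythonGpuCallback.
XLA_FFI_DEFINE_HANDLER_SYMBOL(kExampleScale, ExampleScale,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()  // x
                                  .Ret<ffi::AnyBuffer>()  // y
                                  .Attr<float>("scale"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "example_scale", "Host",
                         kExampleScale);

}  // namespace example

The callback handler in this commit instead uses RemainingArgs and RemainingRets because a Python callback's operand and result arity is only known at runtime, and it receives the callback registry through UserData rather than through attributes.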
3 changes: 3 additions & 0 deletions xla/python/py_client_gpu.h
@@ -21,6 +21,7 @@ limitations under the License.
#else
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#include "xla/ffi/ffi.h"
#include "xla/service/custom_call_status.h"

#if TENSORFLOW_USE_ROCM
@@ -35,6 +36,8 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status);

XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback);

} // namespace xla

#endif // XLA_PYTHON_PY_CLIENT_GPU_H_
