Skip to content

Commit

Permalink
Only cache jax.Array._npy_value when a copy is required.
Browse files Browse the repository at this point in the history
As discovered in jax-ml/jax#26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.

This fixes jax-ml/jax#26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.

PiperOrigin-RevId: 726522955
  • Loading branch information
dfm authored and Google-ML-Automation committed Feb 13, 2025
1 parent d7a8923 commit 632cb10
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
19 changes: 15 additions & 4 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/optional.h" // IWYU pragma: keep
#include "nanobind/stl/pair.h" // IWYU pragma: keep
#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
Expand Down Expand Up @@ -822,14 +823,20 @@ absl::Status PyArray::BlockUntilResultStatusIsReady() {
return result_status.Await();
}

absl::StatusOr<nb::object> PyArray::SingleDeviceArrayToNumpyArray() {
// Converts this array's fully replicated shard to a NumPy array.
//
// Returns the (object, bool) pair produced by PyHostValue::AsNumPyArray for
// that shard. A failure from FullyReplicatedShard() or from
// BlockUntilResultStatusIsReady() is returned as a non-OK status.
// NOTE(review): judging by the function name, the bool reports whether the
// returned NumPy array is a copy of the underlying data -- confirm against
// PyHostValue::AsNumPyArray.
absl::StatusOr<std::pair<nb::object, bool>>
PyArray::SingleDeviceArrayToNumpyArrayDidCopy() {
TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard());
// `result` holds the StatusOr from AsNumPyArray unexamined; it is handed back
// to the caller only once the wait below has succeeded.
auto result = arr.GetStorage().host_value.AsNumPyArray(
arr.GetStorage().dynamic_shape, arr.ifrt_array());
// An error from the result status takes precedence over `result`.
TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady());
return result;
}

// Thin wrapper over SingleDeviceArrayToNumpyArrayDidCopy(): runs the same
// conversion and returns only the first element of the resulting pair,
// dropping the accompanying bool. Any error status is propagated unchanged.
absl::StatusOr<nb::object> PyArray::SingleDeviceArrayToNumpyArray() {
  TF_ASSIGN_OR_RETURN(auto array_and_flag,
                      SingleDeviceArrayToNumpyArrayDidCopy());
  return array_and_flag.first;
}

absl::Status PyArray::CopySingleDeviceArrayToHostAsync() {
TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard());
return arr.GetStorage().host_value.CopyToHostAsync(
Expand Down Expand Up @@ -1679,7 +1686,7 @@ bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) {
PyHostValue::PyHostValue() = default;
PyHostValue::~PyHostValue() = default;

absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
absl::StatusOr<std::pair<nb::object, bool>> PyHostValue::AsNumPyArray(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
if (ifrt_array->IsDeleted()) {
return InvalidArgument("DeviceArray has been deleted.");
Expand Down Expand Up @@ -1724,7 +1731,7 @@ absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
nb_numpy_ndarray array(dtype, shape->dimensions(),
ByteStridesForShape(*shape), data, hold_capsule);
array.attr("flags").attr("writeable") = nb::bool_(false);
return array;
return std::make_pair(array, false);
}
}

Expand All @@ -1738,7 +1745,7 @@ absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
if (string_array_contents_ != nullptr) {
TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array));
}
return value_;
return std::make_pair(value_, true);
}

absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray(
Expand Down Expand Up @@ -1988,9 +1995,13 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) {
type.attr("on_device_size_in_bytes") = nb::cpp_function(
xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes),
nb::is_method());
// TODO(danfm): Remove this after JAX 0.5.1 release (or sooner!).
type.attr("_single_device_array_to_np_array") = nb::cpp_function(
xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArray),
nb::is_method());
type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function(
xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy),
nb::is_method());
type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function(
[](PyArray& self) {
xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync());
Expand Down
4 changes: 3 additions & 1 deletion xla/python/py_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class PyHostValue {
absl::Status CopyToHostAsync(std::optional<Shape>& dynamic_shape_holder,
ifrt::Array* ifrt_array);

absl::StatusOr<nanobind::object> AsNumPyArray(
absl::StatusOr<std::pair<nanobind::object, bool>> AsNumPyArray(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

private:
Expand Down Expand Up @@ -280,6 +280,8 @@ class PyArray : public nanobind::object {
absl::Status BlockUntilResultStatusIsReady();

absl::StatusOr<size_t> GetOnDeviceSizeInBytes();
absl::StatusOr<std::pair<nanobind::object, bool>>
SingleDeviceArrayToNumpyArrayDidCopy();
absl::StatusOr<nanobind::object> SingleDeviceArrayToNumpyArray();
absl::Status CopySingleDeviceArrayToHostAsync();
nanobind::dict CudaArrayInterface();
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 313
_version = 314

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down

0 comments on commit 632cb10

Please sign in to comment.