From 8fddcafc2e9880e0d6d5ea5f82948020b5349447 Mon Sep 17 00:00:00 2001 From: xla authors Date: Fri, 14 Feb 2025 13:38:41 -0800 Subject: [PATCH] Adds support for string and binary data processing in Colocated Python. PiperOrigin-RevId: 727048049 --- xla/python/pjrt_ifrt/pjrt_client.cc | 13 ++++++++----- xla/python/types.cc | 4 ++++ xla/python/xla_client.py | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index 464cca7ac0834..488e43409cab0 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -675,11 +675,14 @@ AssembleStringArrayFromSingleDeviceStringArrays( return absl::InvalidArgumentError( "All single device arrays must be BasicStringArrays"); } - if (!llvm::isa(basic_string_array->sharding())) { - return absl::InvalidArgumentError(absl::StrFormat( - "All single device arrays must have single device sharding. got: %s " - "for shard index: %d", - basic_string_array->sharding().DebugString(), i)); + + if (!llvm::isa(basic_string_array->sharding()) && + (basic_string_array->sharding().devices()->size() != 1)) { + return absl::InvalidArgumentError( + absl::StrFormat("All single device arrays must have single device " + "sharding. got: %s " + "for shard index: %d", + basic_string_array->sharding().DebugString(), i)); } basic_string_array->buffers().OnReady( diff --git a/xla/python/types.cc b/xla/python/types.cc index 5f71a0121a7a5..4a184f0090351 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -378,6 +378,10 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { } absl::StatusOr DtypeToIfRtDType(nb_dtype dtype) { + // String does not have a corresponding XLA primitive type. + if (dtype.kind() == 'T') { + return ifrt::DType(ifrt::DType::kString); + } TF_ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype)); return ifrt::ToDType(primitive_type); } diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index b7341007d04cc..71c15643e4cfa 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 314 +_version = 315 # Version number for MLIR:Python components. mlir_api_version = 57