diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index 464cca7ac0834..488e43409cab0 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -675,11 +675,14 @@ AssembleStringArrayFromSingleDeviceStringArrays( return absl::InvalidArgumentError( "All single device arrays must be BasicStringArrays"); } - if (!llvm::isa(basic_string_array->sharding())) { - return absl::InvalidArgumentError(absl::StrFormat( - "All single device arrays must have single device sharding. got: %s " - "for shard index: %d", - basic_string_array->sharding().DebugString(), i)); + + if (!llvm::isa(basic_string_array->sharding()) && + (basic_string_array->sharding().devices()->size() != 1)) { + return absl::InvalidArgumentError( + absl::StrFormat("All single device arrays must have single device " + "sharding. got: %s " + "for shard index: %d", + basic_string_array->sharding().DebugString(), i)); } basic_string_array->buffers().OnReady( diff --git a/xla/python/types.cc b/xla/python/types.cc index 5f71a0121a7a5..4a184f0090351 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -378,6 +378,10 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { } absl::StatusOr DtypeToIfRtDType(nb_dtype dtype) { + // String does not have a corresponding XLA primitive type. + if (dtype.kind() == 'T') { + return ifrt::DType(ifrt::DType::kString); + } TF_ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype)); return ifrt::ToDType(primitive_type); } diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index b7341007d04cc..71c15643e4cfa 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 314 +_version = 315 # Version number for MLIR:Python components. mlir_api_version = 57