Skip to content

Commit

Permalink
Adds support for string and binary data processing in Colocated Pytho…
Browse files Browse the repository at this point in the history
…n with Pathways backend.

PiperOrigin-RevId: 726705362
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
1 parent 6da089d commit 5bee1be
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
13 changes: 8 additions & 5 deletions xla/python/pjrt_ifrt/pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -675,11 +675,14 @@ AssembleStringArrayFromSingleDeviceStringArrays(
return absl::InvalidArgumentError(
"All single device arrays must be BasicStringArrays");
}
if (!llvm::isa<SingleDeviceSharding>(basic_string_array->sharding())) {
return absl::InvalidArgumentError(absl::StrFormat(
"All single device arrays must have single device sharding. got: %s "
"for shard index: %d",
basic_string_array->sharding().DebugString(), i));

if (!llvm::isa<SingleDeviceSharding>(basic_string_array->sharding()) &&
(basic_string_array->sharding().devices()->size() != 1)) {
return absl::InvalidArgumentError(
absl::StrFormat("All single device arrays must have single device "
"sharding. got: %s "
"for shard index: %d",
basic_string_array->sharding().DebugString(), i));
}

basic_string_array->buffers().OnReady(
Expand Down
4 changes: 4 additions & 0 deletions xla/python/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,10 @@ absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
}

absl::StatusOr<ifrt::DType> DtypeToIfRtDType(nb_dtype dtype) {
// String does not have a corresponding XLA primitive type.
if (dtype.kind() == 'T') {
return ifrt::DType(ifrt::DType::kString);
}
TF_ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype));
return ifrt::ToDType(primitive_type);
}
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 314
_version = 315

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down

0 comments on commit 5bee1be

Please sign in to comment.