Skip to content

Commit

Permalink
[Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
Browse files Browse the repository at this point in the history
Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
  • Loading branch information
bythew3i authored and Google-ML-Automation committed Feb 4, 2025
1 parent 654a2f6 commit 124e123
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11649,7 +11649,8 @@ def replace(tup, val):
j = 0
for i in range(rank):
if i == axis_int:
indices = _normalize_index(indices, axis_size)
if mode != 'promise_in_bounds':
indices = _normalize_index(indices, axis_size)
gather_indices.append(lax.reshape(indices, gather_index_shape))
slice_sizes.append(1)
start_index_map.append(i)
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2129,7 +2129,11 @@ def _gather_lowering_rule(
slice_sizes == (1, 1)
and not unique_indices
and not indices_are_sorted
and mode == lax.GatherScatterMode.FILL_OR_DROP
and mode
in (
lax.GatherScatterMode.FILL_OR_DROP,
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
)
):
if dimension_numbers == lax.GatherDimensionNumbers(
offset_dims=(),
Expand Down
10 changes: 5 additions & 5 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,10 @@ def kernel(x_ref, o_ref):
ref = jax.jit(lambda x: round_fn(x).astype(target))(x)
np.testing.assert_array_equal(out, ref)

@parameterized.product(axis=[0, 1])
def test_dynamic_gather_along_axis(self, axis):
if not jtu.if_cloud_tpu_at_least(2025, 2, 3):
self.skipTest("Requires libtpu built after 2025-02-03")
@parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None])
def test_dynamic_gather_along_axis(self, axis, mode):
if not jtu.if_cloud_tpu_at_least(2025, 2, 5):
self.skipTest("Requires libtpu built after 2025-02-05")
if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or (
axis == 1 and not jtu.is_device_tpu_at_least(version=4)
):
Expand All @@ -401,7 +401,7 @@ def test_dynamic_gather_along_axis(self, axis):
shape = (8, 128)

def kernel(x, indices, out):
out[...] = jnp.take_along_axis(x[...], indices[...], axis)
out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode)

x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
idx = jax.random.randint(
Expand Down

0 comments on commit 124e123

Please sign in to comment.