jit(mutable_array) should raise an error #26349

Open
ayaka14732 opened this issue Feb 6, 2025 · 1 comment
Labels: bug

@ayaka14732
Member

Description

import jax
import jax.numpy as jnp
from jax._src.core import mutable_array

a = jnp.int32(0)
a_ref = mutable_array(a)
print(type(a_ref))

a_ref = jax.jit(mutable_array)(a)
print(type(a_ref))

Output:

<class 'jax._src.core.MutableArray'>
<class 'jaxlib.xla_extension.ArrayImpl'>

However, the second call should raise an error, because returning a mutable array from a jitted function is not allowed. Instead it silently returns a plain array.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.13.0rc3 (main, Oct  2 2024, 17:18:08) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.10.11-1rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.11-1rodete2 (2024-10-16)', machine='x86_64')
@mattjj
Collaborator

mattjj commented Feb 6, 2025

This raises an error if you set the JAX_MUTABLE_ARRAY_CHECKS option:

name='jax_mutable_array_checks',

$ JAX_MUTABLE_ARRAY_CHECKS=1 python 26349.py
<class 'jax._src.core.MutableArray'>
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/26349.py", line 9, in <module>
    a_ref = jax.jit(mutable_array)(a)
            ^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: function mutable_array at /usr/local/google/home/mattjj/packages/jax/jax/_src/core.py:2076 traced for jit returned a mutable array reference of type Ref{int32[]} at output tree path , but mutable array references cannot be returned.

The returned mutable array was created on line /usr/local/google/home/mattjj/packages/jax/26349.py:9:8 (<module>).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

(I just noticed the "at output tree path " part is busted!)

We should turn this flag on by default. I suppose we can leave this issue open until we do.
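
For reference, the same check can presumably be switched on from inside Python rather than through the environment variable. The sketch below is not from the thread. It assumes the config name is `jax_mutable_array_checks` (as quoted above) and that `jax.config.update` accepts it in the installed version, and it imports `mutable_array` from the private `jax._src.core` module as the report does. The helper name `jit_mutable_array_outcome` is made up for illustration. Since private APIs and flags can move between releases, each step is wrapped so the script reports what happened instead of crashing.

```python
import jax
import jax.numpy as jnp


def jit_mutable_array_outcome():
    """Describe what jax.jit(mutable_array) does once the checks flag is on."""
    try:
        # Private import, same as in the report; may not exist in every release.
        from jax._src.core import mutable_array
    except ImportError:
        return "mutable_array is not importable from jax._src.core here"

    try:
        # Assumption: this flag name matches the one quoted in the comment above.
        # Note that this changes the setting for the whole process.
        jax.config.update("jax_mutable_array_checks", True)
    except Exception as err:
        return f"could not set the flag ({type(err).__name__})"

    try:
        out = jax.jit(mutable_array)(jnp.int32(0))
    except Exception as err:
        # With the checks enabled, the thread shows a ValueError at this point.
        return f"raised {type(err).__name__}: {err}"
    return f"no error, returned {type(out).__name__}"


print(jit_mutable_array_outcome())
```

If the flag behaves as described in the comment above, the printed line should report an exception rather than a returned array, but that depends on the installed JAX version.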
