Deserialization of executables fails on non-zero ranks when deserializing single-device executable #18286

Open
jaro-sevcik opened this issue Oct 14, 2024 · 1 comment · Fixed by #18663 or tensorflow/tensorflow#78608


@jaro-sevcik
Contributor

When executables are deserialized through the PJRT C API wrapper client, the compile options passed to deserialization are ignored. In practice, this breaks the JAX persistent compilation cache on non-zero ranks: they fail to deserialize executables that were compiled for rank zero.

JAX repro:

```python
import argparse
import logging
import socket

import jax
import jax.numpy as jnp

parser = argparse.ArgumentParser(
    prog='mock-test',
    description='Tests mocking',
    epilog='...')
parser.add_argument('-r', '--rank', type=int, required=True)
args = parser.parse_args()

# Only rank 1 logs compiler activity, so its cache hits and misses are visible.
if args.rank == 1:
    logging.basicConfig(format='%(asctime)s %(message)s')
    logging.getLogger("jax._src.compiler").setLevel(logging.DEBUG)

# Persist every compiled executable, no matter how quickly it compiled.
jax.config.update("jax_compilation_cache_dir", "/tmp/compilation_cache")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# Two processes on one host; each process owns only the GPU matching its rank.
jax.distributed.initialize(socket.gethostname() + ":1234", 2, args.rank,
                           local_device_ids=[args.rank])

# A trivial single-device computation.
f = jax.jit(lambda x: x)
jax.block_until_ready(f(jnp.zeros((1,))))
```

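Save the script as `test.py` and run `(python test.py -r0 &) && (python test.py -r1)` twice on a machine with at least two GPUs: the first run populates the compilation cache, and the second run reads from it.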
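The two-rank launch can also be scripted. The helpers below are a hypothetical sketch (`rank_commands` and `run_once` are illustrative names, not part of JAX); they assume the repro is saved as `test.py` in the current directory and, like the shell one-liner, start one process per rank at the same time so that `jax.distributed.initialize` can rendezvous.

```python
import subprocess
import sys


def rank_commands(script="test.py", num_ranks=2):
    """One command line per rank, equivalent to `python test.py -r<rank>`."""
    return [[sys.executable, script, f"-r{rank}"] for rank in range(num_ranks)]


def run_once(script="test.py", num_ranks=2):
    """Start all ranks concurrently (they must rendezvous) and return their exit codes."""
    procs = [subprocess.Popen(cmd) for cmd in rank_commands(script, num_ranks)]
    return [proc.wait() for proc in procs]
```

Calling `run_once()` twice reproduces the two runs described above.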

On the second run, the output from rank 1 contains the following warning, followed by a cache miss:

```
/opt/jax/jax/_src/compiler.py:691: UserWarning: Error reading persistent compilation cache entry for 'jit__lambda_':
XlaRuntimeError: INVALID_ARGUMENT: Device assignment (Computations: 1 Replicas: 1
Computation 0: 0
) does not have any local devices.
  warnings.warn(
2024-10-14 09:16:30,222 PERSISTENT COMPILATION CACHE MISS for 'jit__lambda_' with key 'jit__lambda_-502ff86f0064419e429f73e9641f94cc3ab91a275910dec17b3ba6186556a297'
```
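The error can be reasoned about with a deliberately simplified model. In the sketch below, `ToyExecutable`, `deserialize_ignoring_options`, and `check_has_local_device` are hypothetical names, not PJRT or JAX APIs. The model assumes only what the issue states: the cached entry records device 0 (it was compiled for rank 0), rank 1's only local device is device 1, and the deserialization path drops the caller's compile options.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyExecutable:
    """Hypothetical stand-in for a cached single-device executable."""
    name: str
    device_assignment: Tuple[int, ...]  # global device id per computation


def deserialize_ignoring_options(exe: ToyExecutable,
                                 compile_options: Optional[dict] = None) -> ToyExecutable:
    """Models the buggy path: compile_options is accepted but never consulted."""
    return exe


def check_has_local_device(exe: ToyExecutable, local_device_ids) -> None:
    """Mimics the failing check: some assigned device must be local to this process."""
    if not set(exe.device_assignment) & set(local_device_ids):
        raise ValueError(
            f"Device assignment {exe.device_assignment} does not have any local devices.")


# The cache entry was compiled for rank 0, whose only device has global id 0.
cached = ToyExecutable("jit__lambda_", device_assignment=(0,))

# Rank 1 (local device 1) asks for its own device, but the request is dropped.
loaded = deserialize_ignoring_options(cached, {"device_assignment": (1,)})
try:
    check_has_local_device(loaded, local_device_ids=[1])
except ValueError as err:
    print("rank 1:", err)
```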
copybara-service bot pushed a commit that referenced this issue Oct 23, 2024
Imported from GitHub PR #18287

This enables JAX to supply a different device assignment when deserializing single-device executables from compilation cache.

Fixes #18286.
Copybara import of the project:

--
35b505b by Jaroslav Sevcik <[email protected]>:

Pass the compile options for deserialization via PJRT C API

--
c983f61 by Jaroslav Sevcik <[email protected]>:

Add compile options comment, reorder fields

--
182b4a6 by Jaroslav Sevcik <[email protected]>:

Fix a little use-after-free

--
cee9982 by Jaroslav Sevcik <[email protected]>:

Rename field, improve comments

--
2b9fbd9 by Jaroslav Sevcik <[email protected]>:

Bump minor version, changelog update

Merging this change closes #18287

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18287 from jaro-sevcik:deserialize-compile-options 2b9fbd9
PiperOrigin-RevId: 686893447
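Conceptually, this change lets the caller's compile options reach deserialization, so a non-zero rank can substitute its own device for the one recorded in the cache entry. The sketch below models only that idea in plain Python; `ToyExecutable`, `ToyCompileOptions`, and `deserialize` are hypothetical names and do not reflect the actual PJRT C API structs or fields touched by the PR.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyExecutable:
    """Hypothetical stand-in for a cached single-device executable."""
    name: str
    device_assignment: Tuple[int, ...]


@dataclass(frozen=True)
class ToyCompileOptions:
    """Hypothetical subset of compile options: just the device assignment."""
    device_assignment: Optional[Tuple[int, ...]] = None


def deserialize(exe: ToyExecutable,
                options: Optional[ToyCompileOptions] = None) -> ToyExecutable:
    """Models the intended behavior: a caller-supplied assignment overrides the cached one."""
    if options is None or options.device_assignment is None:
        return exe  # no override requested; keep the serialized assignment
    # Assumption for this sketch: only remap devices, never change their count.
    if len(options.device_assignment) != len(exe.device_assignment):
        raise ValueError("device assignment size does not match the executable")
    return replace(exe, device_assignment=options.device_assignment)


cached = ToyExecutable("jit__lambda_", device_assignment=(0,))
on_rank1 = deserialize(cached, ToyCompileOptions(device_assignment=(1,)))
print(on_rank1.device_assignment)
```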
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Oct 23, 2024
Imported from GitHub PR openxla/xla#18287

This enables JAX to supply a different device assignment when deserializing single-device executables from compilation cache.

Fixes openxla/xla#18286.
Copybara import of the project:

--
35b505bf81e90fa6762c6d45d5abb99028032769 by Jaroslav Sevcik <[email protected]>:

Pass the compile options for deserialization via PJRT C API

--
c983f61f9c61e9e1ed52f3676d590d83a2455e63 by Jaroslav Sevcik <[email protected]>:

Add compile options comment, reorder fields

--
182b4a6da6f9e584bc84ec84f88b88d9add508db by Jaroslav Sevcik <[email protected]>:

Fix a little use-after-free

--
cee998286ae98cc29e2096fd52bd22caac67a5fe by Jaroslav Sevcik <[email protected]>:

Rename field, improve comments

--
2b9fbd992458795c1d5bf4da5c41e69c2b55df55 by Jaroslav Sevcik <[email protected]>:

Bump minor version, changelog update

Merging this change closes #18287

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18287 from jaro-sevcik:deserialize-compile-options 2b9fbd992458795c1d5bf4da5c41e69c2b55df55
PiperOrigin-RevId: 686893447
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Oct 23, 2024
PiperOrigin-RevId: 688990985
@loislo reopened this Nov 5, 2024
@loislo
Member

loislo commented Nov 5, 2024

The PR was reverted due to crashes in production.
