
jax.distributed with slurm/mpi not working on rocm #26298

Open
ksebaz opened this issue Feb 4, 2025 · 2 comments
Comments

ksebaz commented Feb 4, 2025

On the ROCm platform, jax.distributed with the cluster auto-detection mechanisms does not work the way it does with CUDA. Although jax._src.distributed.initialize() sets the visible devices properly for both 'cuda' and 'rocm', the xla_bridge only queries CUDA_VISIBLE_DEVICES when creating the client:

visible_devices = CUDA_VISIBLE_DEVICES.value

As a result, every node-local process sees all local devices, which can lead to OOM errors in case of oversubscription, or to hangs or errors in RCCL communication, depending on the specific setup. At the moment only setups with a single process per node are possible with ROCm, whereas with CUDA one process per device also works.
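For illustration, here is a minimal sketch of the affected setup under SLURM with one task per GPU (the srun flags and script name are just an example; only public JAX APIs are used):

```python
# Minimal reproduction sketch, launched with something like:
#   srun --nodes=2 --ntasks-per-node=8 --gpus-per-node=8 python repro.py
# so that there is one process per GPU and JAX's SLURM auto-detection kicks in.
import jax

# initialize() auto-detects the SLURM environment and, per process, sets the
# platform's visible-devices flag to the single local device it should own.
# On CUDA this is honored when the backend client is created; on ROCm the
# flag is currently ignored by xla_bridge, so jax.local_devices() below
# still lists every GPU on the node for every process.
jax.distributed.initialize()

print(f"process {jax.process_index()}/{jax.process_count()}: "
      f"local devices = {jax.local_devices()}")
```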

To my understanding and from my tests, this could easily be fixed by also querying 'jax_rocm_visible_devices' in xla_bridge, and I would be happy to provide a PR in case the current behavior is not the intended one. Based on a small set of simple tests, the whole GPU-mock setup enclosed in the relevant context linked above also works with ROCm, so the changes might really be minimal.

dfm added the AMD GPU (Issues pertaining to AMD GPUs (ROCM)) label Feb 4, 2025
mrodden (Collaborator) commented Feb 4, 2025

Looking over the block of code there that sets up the mock-device settings, I don't think this would be too big of a change to make. It comes down to making the block also run when the plugin is ROCm, checking _ROCM_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES, and then going through the same mock-settings block.
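Roughly, a self-contained sketch of that selection logic (plain arguments stand in for the CUDA_VISIBLE_DEVICES / _ROCM_VISIBLE_DEVICES flag objects in xla_bridge, so the names and defaults here are illustrative only, not the actual jax code):

```python
# Illustrative sketch only: in jax._src.xla_bridge the values come from the
# 'jax_cuda_visible_devices' / 'jax_rocm_visible_devices' config flags; here
# plain arguments stand in for those flag values.

def select_visible_devices(plugin_name: str,
                           cuda_visible_devices: str = "all",
                           rocm_visible_devices: str = "all") -> str:
    """Pick the visible-devices setting matching the active GPU plugin."""
    if plugin_name == "cuda":
        return cuda_visible_devices
    if plugin_name == "rocm":
        return rocm_visible_devices
    return "all"

# With cluster auto-detection, distributed.initialize() sets the per-process
# flag to a single local device index, e.g. "3"; only that device should then
# be exposed to the process when the backend client is created.
print(select_visible_devices("rocm", rocm_visible_devices="3"))  # -> "3"
```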

I am not sure if anyone has tried running one process per device yet with jax+rocm, but I don't really see any reason it wouldn't work.

I don't really have a slurm/mpi cluster on hand to easily test this, so if you wanted to build the change and verify it on your environment and push it up as a PR that would be helpful. Thanks!

ksebaz pushed a commit to ksebaz/jax that referenced this issue Feb 5, 2025:
This aligns rocm with cuda when using jax.distributed in combination
with one of the mechanisms for cluster-autodetection that set visible
devices in the "jax_rocm_visible_devices" flag.

Fixes jax-ml#26298
ksebaz (Author) commented Feb 5, 2025

Thanks for looking into this issue! Please find the changes in the PR mentioned above. I tested on A100 and MI300 with SLURM, and ROCm is now consistent with CUDA with respect to which devices are visible to each process.
