On the ROCm platform, jax.distributed with the cluster-autodetection mechanisms does not work the way it does with CUDA. While jax._src.distributed.initialize() sets the visible devices correctly for both 'cuda' and 'rocm', xla_bridge only queries the cuda_visible_devices flag when the client is created (jax/jax/_src/xla_bridge.py, line 643 at commit 124e123).
This results in node-local processes seeing all local devices, which can lead to OOM errors in case of over-subscription, or to hangs or errors in RCCL communication, depending on the specific setup. At the moment only setups with a single process per node are therefore possible with ROCm, whereas with CUDA one process per device also works.
To my understanding and from my tests, this could be fixed easily by also querying the 'jax_rocm_visible_devices' flag in xla_bridge, and I would be happy to provide a PR in case the current behavior is not the intended one. According to a small set of simple tests, the whole gpu-mock setup enclosed in the relevant context linked above also works with ROCm, so the changes might really be minimal.
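For context, the logic in question looks roughly like the sketch below. This is paraphrased for illustration, not the literal JAX source; the function and flag-object names are assumptions:

```python
# Simplified sketch of the device-visibility logic in jax/_src/xla_bridge.py
# around client creation (names are illustrative, not the actual source).

def _options_from_jax_configs(plugin_name):
    options = {}
    # Only the cuda branch consults the visible-devices flag today. For
    # plugin_name == "rocm" this block is skipped entirely, so every
    # node-local process ends up seeing all local devices.
    if plugin_name == "cuda":
        visible_devices = CUDA_VISIBLE_DEVICES.value  # 'jax_cuda_visible_devices'
        if visible_devices != "all":
            options["visible_devices"] = [
                int(x) for x in visible_devices.split(",")
            ]
    return options
```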
Looking over the block of code there that sets up the mock device settings: yeah, I don't think this would be too big of a change to make. I think it comes down to making the block also run if the plugin is rocm, checking _ROCM_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES, and then running through the same mock-settings block.
I am not sure if anyone has tried running one process per device yet with jax+rocm, but I don't really see any reason it wouldn't work.
I don't really have a slurm/mpi cluster on hand to easily test this, so if you wanted to build the change and verify it on your environment and push it up as a PR that would be helpful. Thanks!
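For concreteness, the change described above might look roughly like this (a sketch only; the flag-object names and surrounding structure are assumptions, not the actual JAX source):

```python
# Sketch of the proposed change: the branch now also runs for rocm and
# reads the matching flag, i.e. 'jax_rocm_visible_devices' instead of
# 'jax_cuda_visible_devices' (flag-object names assumed for illustration).

def _options_from_jax_configs(plugin_name):
    options = {}
    if plugin_name in ("cuda", "rocm"):
        visible_devices = (
            CUDA_VISIBLE_DEVICES.value
            if plugin_name == "cuda"
            else ROCM_VISIBLE_DEVICES.value
        )
        if visible_devices != "all":
            options["visible_devices"] = [
                int(x) for x in visible_devices.split(",")
            ]
        # ...the same mock-GPU settings block would then apply to both
        # platforms unchanged...
    return options
```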
This aligns rocm with cuda when using jax.distributed in combination
with one of the mechanisms for cluster-autodetection that set visible
devices in the "jax_rocm_visible_devices" flag.
Fixes jax-ml#26298
Thanks for looking into this issue! Please find the changes in the PR mentioned above. I tested on A100 and MI300 under SLURM, and rocm is now consistent with cuda with respect to which devices are visible to each process.
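For anyone landing here, a minimal per-process script for the one-process-per-device setup discussed above could look like this (the launch command and task counts are illustrative):

```python
# Launched once per GPU, e.g.:
#   srun --ntasks-per-node=8 --gpus-per-node=8 python check_devices.py
# With cluster autodetection, initialize() picks up the SLURM environment
# and records the process-local device in the jax_cuda_visible_devices /
# jax_rocm_visible_devices flag, so with the fix each process sees exactly
# one local device on both platforms.
import jax

jax.distributed.initialize()  # coordinator address / process id auto-detected

print(
    f"process {jax.process_index()}/{jax.process_count()}: "
    f"local devices = {jax.local_devices()}"
)
```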