On a system with 2 Intel(R) Data Center GPU Max 1550 cards (aka PVC); note that each card has 2 tiles, so 4 torch devices are available in total.

test_model_parallel_beam_search tests for a number of models fail with "RuntimeError: Expected all tensors to be on the same device":
```
# TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python3 -m pytest -k test_model_parallel_beam_search \
  tests/models/aria \
  tests/models/falcon_mamba \
  tests/models/gpt2 \
  tests/models/gpt_bigcode \
  tests/models/idefics2 \
  tests/models/imagegpt \
  tests/models/instructblip \
  tests/models/instructblipvideo \
  tests/models/mamba \
  tests/models/mbart \
  tests/models/opt \
  tests/models/qwen2_vl \
  tests/models/xglm
...
FAILED tests/models/aria/test_modeling_aria.py::AriaForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:0!
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:2!
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMHAModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:0!
FAILED tests/models/imagegpt/test_modeling_imagegpt.py::ImageGPTModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/instructblip/test_modeling_instructblip.py::InstructBlipForConditionalGenerationDecoderOnlyTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:1!
FAILED tests/models/instructblipvideo/test_modeling_instructblipvideo.py::InstructBlipVideoForConditionalGenerationDecoderOnlyTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:1!
FAILED tests/models/mamba/test_modeling_mamba.py::MambaModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/mbart/test_modeling_mbart.py::MBartModelTest::test_model_parallel_beam_search - RuntimeError: Expected query, key, and value to have the same device type, but got query.device: xpu:1 key.device: xp...
FAILED tests/models/mllama/test_modeling_mllama.py::MllamaForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:3! (when c...
FAILED tests/models/opt/test_modeling_opt.py::OPTModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/qwen2_vl/test_modeling_qwen2_vl.py::Qwen2VLModelTest::test_model_parallel_beam_search - RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (xpu:1)
FAILED tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:2!
=============================== 15 failed, 2 passed, 4399 deselected, 2 warnings in 13.95s ===============================
```
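For reference, the spec.py passed via TRANSFORMERS_TEST_DEVICE_SPEC above is a small device-spec file that tells the transformers test suite which torch device to target. A minimal sketch for XPU might look like the following; the variable names follow transformers' testing utilities and should be treated as assumptions to verify against the transformers version in use:

```python
# spec.py -- device spec consumed by transformers via TRANSFORMERS_TEST_DEVICE_SPEC.
# Variable names below are assumptions based on transformers' testing utilities;
# verify them against the transformers version in use.
import torch

DEVICE_NAME = "xpu"                       # torch device string to run tests on
MANUAL_SEED_FN = torch.xpu.manual_seed    # seeding hook for the test device
EMPTY_CACHE_FN = torch.xpu.empty_cache    # cache-clearing hook
DEVICE_COUNT_FN = torch.xpu.device_count  # number of visible devices
```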
From the log, errors occur on the source lines listed below.

dvrogozh added a commit to dvrogozh/transformers referencing this issue on Jan 18, 2025, with the message:

```
Fixing the following errors in a few models:

> hidden_states = inputs_embeds + pos_embeds
E RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:3!

Fixes: huggingface#35762
Signed-off-by: Dmitry Rogozhkin <[email protected]>
```
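The fix pattern behind a failure like `hidden_states = inputs_embeds + pos_embeds` can be sketched independently of any particular model: when model parallelism places modules on different devices, one operand is moved to the other operand's device before the elementwise add. A minimal, CPU-runnable sketch (`add_pos_embeds` is a hypothetical helper for illustration, not transformers code):

```python
import torch

def add_pos_embeds(inputs_embeds: torch.Tensor, pos_embeds: torch.Tensor) -> torch.Tensor:
    # When a model is split across devices, the positional embeddings may live
    # on a different device than the input embeddings, which raises the
    # "Expected all tensors to be on the same device" RuntimeError. The usual
    # fix is to move pos_embeds onto inputs_embeds' device before adding.
    return inputs_embeds + pos_embeds.to(inputs_embeds.device)

# CPU-only demonstration of the pattern (no multi-GPU setup needed);
# pos_embeds broadcasts over the batch dimension.
hidden_states = add_pos_embeds(torch.zeros(2, 4, 8), torch.ones(4, 8))
print(hidden_states.shape)  # torch.Size([2, 4, 8])
```

`Tensor.to(device)` is a no-op when the tensor is already on the target device, so the extra call costs nothing in the single-device case.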
transformers/src/transformers/models/aria/modeling_aria.py, line 1116 in 7d4b3dd
transformers/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py, line 312 in 7d4b3dd
transformers/src/transformers/models/gpt2/modeling_gpt2.py, line 821 in 7d4b3dd
transformers/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py, line 962 in 7d4b3dd (hit by both GPTBigCodeModelTest and GPTBigCodeMHAModelTest)
transformers/src/transformers/models/idefics2/modeling_idefics2.py, line 1307 in 7d4b3dd
transformers/src/transformers/models/imagegpt/modeling_imagegpt.py, line 778 in 7d4b3dd
transformers/src/transformers/models/instructblip/modeling_instructblip.py, line 1609 in 7d4b3dd
transformers/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py, line 1644 in 7d4b3dd
transformers/src/transformers/models/mamba/modeling_mamba.py, line 264 in 7d4b3dd
transformers/src/transformers/models/mbart/modeling_mbart.py, line 494 in 7d4b3dd
transformers/src/transformers/models/mllama/modeling_mllama.py, line 1489 in 7d4b3dd
transformers/src/transformers/models/opt/modeling_opt.py, line 885 in 7d4b3dd
transformers/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py, line 1489 in 7d4b3dd
transformers/src/transformers/models/xglm/modeling_xglm.py, line 597 in 7d4b3dd
CC: @SunMarc @ydshieh @faaany