
multi-gpu: test_model_parallel_beam_search tests fail with "RuntimeError: Expected all tensors to be on the same device" #35762

Open
dvrogozh opened this issue Jan 18, 2025 · 1 comment · May be fixed by #35763


dvrogozh commented Jan 18, 2025

With:

On:

  • 2-card Intel(R) Data Center GPU Max 1550 (aka PVC); note: each card has 2 tiles, so in total 4 torch devices are available

test_model_parallel_beam_search tests for a number of models fail with "RuntimeError: Expected all tensors to be on the same device":

```
# TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python3 -m pytest -k test_model_parallel_beam_search \
  tests/models/aria \
  tests/models/falcon_mamba \
  tests/models/gpt2 \
  tests/models/gpt_bigcode \
  tests/models/idefics2 \
  tests/models/imagegpt \
  tests/models/instructblip \
  tests/models/instructblipvideo \
  tests/models/mamba \
  tests/models/mbart \
  tests/models/opt \
  tests/models/qwen2_vl \
  tests/models/xglm
...
FAILED tests/models/aria/test_modeling_aria.py::AriaForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:0!
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:2!
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMHAModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:0!
FAILED tests/models/imagegpt/test_modeling_imagegpt.py::ImageGPTModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/instructblip/test_modeling_instructblip.py::InstructBlipForConditionalGenerationDecoderOnlyTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:1!
FAILED tests/models/instructblipvideo/test_modeling_instructblipvideo.py::InstructBlipVideoForConditionalGenerationDecoderOnlyTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:1!
FAILED tests/models/mamba/test_modeling_mamba.py::MambaModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/mbart/test_modeling_mbart.py::MBartModelTest::test_model_parallel_beam_search - RuntimeError: Expected query, key, and value to have the same device type, but got query.device: xpu:1 key.device: xp...
FAILED tests/models/mllama/test_modeling_mllama.py::MllamaForConditionalGenerationModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:3! (when c...
FAILED tests/models/opt/test_modeling_opt.py::OPTModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and xpu:1!
FAILED tests/models/qwen2_vl/test_modeling_qwen2_vl.py::Qwen2VLModelTest::test_model_parallel_beam_search - RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (xpu:1)
FAILED tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_model_parallel_beam_search - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:1 and xpu:2!
=============================== 15 failed, 2 passed, 4399 deselected, 2 warnings in 13.95s ===============================
```

From the log, errors occur on the following lines:

```
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
hidden_states = inputs_embeds + position_embeds
hidden_states = inputs_embeds + position_embeds
hidden_states = inputs_embeds + position_embeds
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
hidden_states = inputs_embeds + position_embeds
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
attn_output = torch.nn.functional.scaled_dot_product_attention(
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
hidden_states = inputs_embeds + pos_embeds
input_ids = input_ids[attention_mask[i] == 1]
hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
```
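All of these lines share the same root cause: under `device_map`-style model parallelism, layers live on different devices, and a binary op ends up combining tensors from two of them. A minimal sketch of the failure mode and the usual fix pattern (an assumed illustration, not the exact patch from the linked PR):

```python
import torch

def add_embeddings(inputs_embeds: torch.Tensor, position_embeds: torch.Tensor) -> torch.Tensor:
    # Under model parallelism the two tensors may live on different devices
    # (e.g. xpu:0 and xpu:1), and the add raises:
    #   RuntimeError: Expected all tensors to be on the same device ...
    # Moving one operand to the other's device resolves the mismatch; on a
    # single-device host the .to() call is a no-op.
    return inputs_embeds + position_embeds.to(inputs_embeds.device)

out = add_embeddings(torch.ones(2, 4), torch.zeros(2, 4))
```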

CC: @SunMarc @ydshieh @faaany

dvrogozh added a commit to dvrogozh/transformers that referenced this issue Jan 18, 2025
Fixing the following errors in few models:
```
>       hidden_states = inputs_embeds + pos_embeds
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:2 and xpu:3!
```

Fixes: huggingface#35762
Signed-off-by: Dmitry Rogozhkin <[email protected]>
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Jan 18, 2025
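The commits above apparently normalize devices at the point where embeddings are combined. A hedged sketch of that pattern in a toy decoder embedding stage (illustrative names, not the actual transformers code):

```python
import torch
from torch import nn

class DecoderEmbeddings(nn.Module):
    """Toy stand-in for a decoder's embedding stage; structure and names
    are illustrative, not copied from transformers."""

    def __init__(self, vocab_size: int = 16, max_pos: int = 8, dim: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.embed_positions = nn.Embedding(max_pos, dim)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        # Build position ids on the device where the position table lives.
        pos_ids = torch.arange(input_ids.shape[-1],
                               device=self.embed_positions.weight.device)
        pos_embeds = self.embed_positions(pos_ids)
        # With a device_map, embed_positions may sit on another device than
        # embed_tokens; move its output before the add to avoid the
        # "Expected all tensors to be on the same device" error.
        return inputs_embeds + pos_embeds.to(inputs_embeds.device)

hidden = DecoderEmbeddings()(torch.tensor([[1, 2, 3]]))
```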
