The convergence test test_mini_models_with_logits is failing with the latest transformers #543

Open
Tcc0403 opened this issue Jan 27, 2025 · 7 comments
Labels
bug Something isn't working

Comments

@Tcc0403
Collaborator

Tcc0403 commented Jan 27, 2025

🐛 Describe the bug

In this convergence test, we compare the losses and the last step's logits to make sure that the two models (with and without the monkey patch) produce similar results.

For logits, LigerCrossEntropy stores the logits' gradients in the logits tensor itself to save memory (an in-place operation). We could still collect the actual logits for the convergence test only because shift_logits and shift_labels were created in the CausalLMLoss context, i.e., extra tensors/memory were allocated to hold the shifted copies before they were passed to LigerCrossEntropy. In other words, the in-place operations were performed on these extra tensors, not on the original logits tensor.

huggingface/transformers#35646 found a way to avoid these extra memory allocations. However, it introduced a side effect in Liger's convergence test: because new tensors are no longer allocated, nothing keeps a copy of the original logits, and the in-place write makes the test fail. The test now compares the logits from the vanilla model against the logits gradients from the patched model.
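
A minimal CPU sketch of the aliasing problem. `inplace_ce_grad_` is a hypothetical stand-in that mimics the in-place behavior described above; it is not Liger's kernel. The "new path" assumes the pad-and-shift-labels approach (pad labels with -100, shift the labels, leave the logits unshifted) that recent transformers versions appear to use; treat that detail as an assumption.

```python
import torch
import torch.nn.functional as F


def inplace_ce_grad_(logits_2d: torch.Tensor, labels_1d: torch.Tensor) -> torch.Tensor:
    """Hypothetical in-place cross-entropy: returns the mean loss and overwrites
    `logits_2d` with d(loss)/d(logits). Rows labeled -100 get a zero gradient."""
    with torch.no_grad():
        loss = F.cross_entropy(logits_2d, labels_1d, ignore_index=-100)
        valid = labels_1d != -100
        rows = valid.nonzero(as_tuple=True)[0]
        n_valid = max(int(valid.sum()), 1)
        grad = torch.softmax(logits_2d, dim=-1)
        grad[rows, labels_1d[rows]] -= 1.0  # softmax - one_hot on valid rows
        grad[~valid] = 0.0                  # ignored positions contribute nothing
        grad /= n_valid                     # mean reduction over valid tokens
        logits_2d.copy_(grad)               # the caller's buffer now holds gradients
    return loss


torch.manual_seed(0)
vocab = 8
logits = torch.randn(2, 5, vocab)  # batch of 2, so the slice below is not contiguous
labels = torch.randint(0, vocab, (2, 5))
original = logits.clone()

# Old path: .contiguous() on the sliced logits makes a real copy; the kernel writes into it.
shift_logits = logits[..., :-1, :].contiguous()
inplace_ce_grad_(shift_logits.view(-1, vocab), labels[..., 1:].reshape(-1))
print(torch.equal(logits, original))  # True

# New path (assumed): only labels are shifted; logits are passed as a view.
shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].reshape(-1)
inplace_ce_grad_(logits.view(-1, vocab), shift_labels)
print(torch.equal(logits, original))  # False: logits were overwritten by gradients
```

With a batch of 1 the slice `[..., :-1, :]` would already be contiguous, so `.contiguous()` would return a view and the original logits would be clobbered even on the old path.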

Reproduce

Install transformers at or after commit 8ebe9d7

pip install git+https://github.com/huggingface/transformers.git@8ebe9d7

Run test_mini_models_with_logits

python3 -m pytest test/convergence/test_mini_models_with_logits.py -v -rP

With some modifications, we can observe that the two tensors differ significantly: they are actually the logits and their gradients.

...
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] FAILED [  5%]
=============================================================== FAILURES ================================================================
__________________________ test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] ___________________________

model_name = 'mini_llama3', num_steps = 32, lr = 0.0001, dtype = torch.float32, loss_atol = 1e-08, loss_rtol = 2e-05
logits_atol = 0.0001, logits_rtol = 1e-05, param_atol = 0.005, param_rtol = 1e-05
...
(skipped for brevity) 
...
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 65530172
E           Mismatch at index (0, 0, 0): tensor1[(0, 0, 0)] = -0.6763022541999817, tensor2[(0, 0, 0)] = 2.3230221565806453e-11
E           Mismatch at index (0, 0, 1): tensor1[(0, 0, 1)] = 11.04751205444336, tensor2[(0, 0, 1)] = 2.8684196422545938e-06
E           Mismatch at index (0, 0, 2): tensor1[(0, 0, 2)] = 16.184457778930664, tensor2[(0, 0, 2)] = -3.9413275771948975e-06
E           Mismatch at index (0, 0, 3): tensor1[(0, 0, 3)] = -0.8037996292114258, tensor2[(0, 0, 3)] = 2.0449475446326915e-11
E           Mismatch at index (0, 0, 4): tensor1[(0, 0, 4)] = -0.6530527472496033, tensor2[(0, 0, 4)] = 2.3776682012144335e-11
E           ... and 65530167 more mismatched elements.

test/utils.py:118: AssertionError
--------------------------------------------------------- Captured stdout call ----------------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.651557922363281
Step 1, Loss: 2.179945468902588
Step 2, Loss: 0.8609191179275513
Step 3, Loss: 0.7009443044662476
Step 4, Loss: 0.6704152226448059
Step 5, Loss: 0.787455677986145
Step 6, Loss: 0.7168732285499573
Step 7, Loss: 0.974632978439331
Step 8, Loss: 0.6826125383377075
Step 9, Loss: 0.7657948136329651
Step 10, Loss: 0.7752344012260437
Step 11, Loss: 0.4787193238735199
Step 12, Loss: 0.708520770072937
Step 13, Loss: 0.6854400634765625
Step 14, Loss: 0.6153515577316284
Step 15, Loss: 0.7325770854949951
Step 16, Loss: 0.7106888294219971
Step 17, Loss: 0.7801141738891602
Step 18, Loss: 0.6898989081382751
Step 19, Loss: 0.7559919357299805
Step 20, Loss: 0.6201393008232117
Step 21, Loss: 0.38733774423599243
Step 22, Loss: 0.48507410287857056
Step 23, Loss: 0.46192264556884766
Step 24, Loss: 0.45433345437049866
Step 25, Loss: 0.42470914125442505
Step 26, Loss: 0.3781709372997284
Step 27, Loss: 0.5069137215614319
Step 28, Loss: 0.5949812531471252
Step 29, Loss: 0.3769552409648895
Step 30, Loss: 0.7420325875282288
Step 31, Loss: 0.4775380492210388
with_liger=False, output.logits[0]=tensor([[-0.6763, 11.0475, 16.1845,  ..., -0.6553, -0.0774,  0.2168],
        [-0.6783, 11.0428, 16.1866,  ..., -0.6574, -0.0813,  0.2185],
        [-0.6791, 11.0352, 16.1917,  ..., -0.6607, -0.0837,  0.2199],
        ...,
        [-0.8278, 11.5654, 15.6113,  ..., -0.7387,  0.1027,  0.1076],
        [-0.8261, 11.5653, 15.6117,  ..., -0.7419,  0.1008,  0.1118],
        [-0.7729,  1.9232,  2.2825,  ..., -1.0360, -0.8898, -0.4163]],
       device='cuda:0', grad_fn=<SelectBackward0>)
Liger kernel patches have been reverted.
Step 0, Loss: 10.651558876037598
Step 1, Loss: 2.179945468902588
Step 2, Loss: 0.860919177532196
Step 3, Loss: 0.7009443640708923
Step 4, Loss: 0.6704152822494507
Step 5, Loss: 0.7874558568000793
Step 6, Loss: 0.7168733477592468
Step 7, Loss: 0.9746332168579102
Step 8, Loss: 0.6826125383377075
Step 9, Loss: 0.7657949328422546
Step 10, Loss: 0.7752344012260437
Step 11, Loss: 0.47871941328048706
Step 12, Loss: 0.708520770072937
Step 13, Loss: 0.6854398250579834
Step 14, Loss: 0.6153515577316284
Step 15, Loss: 0.7325772643089294
Step 16, Loss: 0.7106888890266418
Step 17, Loss: 0.7801142334938049
Step 18, Loss: 0.6898989081382751
Step 19, Loss: 0.7559918761253357
Step 20, Loss: 0.6201393008232117
Step 21, Loss: 0.38733789324760437
Step 22, Loss: 0.4850741922855377
Step 23, Loss: 0.46192267537117004
Step 24, Loss: 0.4543333351612091
Step 25, Loss: 0.4247092008590698
Step 26, Loss: 0.3781709372997284
Step 27, Loss: 0.5069136619567871
Step 28, Loss: 0.5949810743331909
Step 29, Loss: 0.37695515155792236
Step 30, Loss: 0.742032527923584
Step 31, Loss: 0.47753798961639404
with_liger=True, output.logits[0]=tensor([[ 2.3230e-11,  2.8684e-06, -3.9413e-06,  ...,  2.3724e-11,
          4.2283e-11,  5.6743e-11],
        [ 2.3136e-11,  2.8491e-06, -3.9194e-06,  ...,  2.3626e-11,
          4.2030e-11,  5.6722e-11],
        [ 2.3000e-11,  2.8131e-06, -3.8772e-06,  ...,  2.3428e-11,
          4.1715e-11,  5.6513e-11],
        ...,
        [ 3.4954e-11,  8.4293e-06, -1.0283e-05,  ...,  3.8209e-11,
          8.8631e-11,  8.9067e-11],
        [ 3.5002e-11, -4.8370e-04,  4.8185e-04,  ...,  3.8076e-11,
          8.8438e-11,  8.9416e-11],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0',
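
One quick way to tell the two tensors apart, offered as a heuristic rather than part of the original test: every row of a mean-reduced cross-entropy gradient is (softmax(z) - one_hot(label)) / N, so it sums to roughly zero, and rows for ignored positions are exactly zero (consistent with the all-zero last row in the patched run above). `looks_like_ce_grad` is a hypothetical helper; the tolerance may need loosening for large vocabularies.

```python
import torch
import torch.nn.functional as F


def looks_like_ce_grad(t: torch.Tensor, atol: float = 1e-6) -> bool:
    """Hypothetical heuristic: True if every row of `t` sums to ~0, as rows of
    d(mean CE)/d(logits) do. Raw logits rows have no such constraint."""
    row_sums = t.float().sum(dim=-1)
    return bool(torch.allclose(row_sums, torch.zeros_like(row_sums), atol=atol))


# A real CE gradient passes; shifted random logits do not.
torch.manual_seed(0)
z = torch.randn(4, 10, requires_grad=True)
F.cross_entropy(z, torch.tensor([1, 2, -100, 3]), ignore_index=-100).backward()
print(looks_like_ce_grad(z.grad))            # True
print(looks_like_ce_grad(z.detach() + 5.0))  # False
```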

Versions

none

Tcc0403 mentioned this issue Jan 27, 2025
Tcc0403 added the bug label Jan 27, 2025
@DandinPower
Contributor

Because the only information we can get from LigerCrossEntropy is the loss and the logits gradient, the first idea that came to mind was to rematerialize the logits from them. However, this does not work: softmax is not invertible, since adding the same constant to every logit in a row produces the same softmax output (different logits give the same result), so the original logits can only be recovered up to a per-row constant.
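
A small float64 sketch of why rematerialization stops short. From a mean-reduced cross-entropy gradient the softmax probabilities can be recovered (up to rounding), but their log is the logits minus each row's logsumexp, and any per-row shift of the logits leaves both the softmax and the gradient unchanged. The loss, lse(z) - z[label], is shift-invariant as well, so it does not pin down the constant either.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab = 4, 6
z = torch.randn(n, vocab, dtype=torch.float64, requires_grad=True)
labels = torch.randint(0, vocab, (n,))

# Mean-reduced CE: d(loss)/dz = (softmax(z) - one_hot(labels)) / n
F.cross_entropy(z, labels).backward()

# Step 1 (works): the gradient gives back the softmax probabilities.
probs = z.grad * n + F.one_hot(labels, vocab).to(z.dtype)

# Step 2 (stuck): log(probs) is z minus its per-row logsumexp, not z itself.
log_probs = probs.log()
expected = (z - torch.logsumexp(z, dim=-1, keepdim=True)).detach()
print(torch.allclose(log_probs, expected, atol=1e-6))  # True

# Any per-row shift of z yields the same softmax, hence the same gradient.
shifted = z.detach() + torch.randn(n, 1, dtype=torch.float64)
print(torch.allclose(torch.softmax(shifted, -1), torch.softmax(z.detach(), -1), atol=1e-6))  # True
```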

If we don't want to store additional logits information, the only way to compare CrossEntropy and LigerCrossEntropy is to compute the logits gradients for CrossEntropy from the logits and the last step's labels in the run_mini_model function. Then, instead of comparing the logits in test_mini_model, we compare the logits gradients to ensure correctness.
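
A minimal sketch of computing the reference gradient. `vanilla_logits_grad` is a hypothetical helper, not an existing function in the test suite; it assumes mean reduction over non-ignored tokens and pad-and-shift label handling, both of which would have to match the loss the model actually computes. The toy `nn.Linear` stands in for a model's `lm_head`, and the check compares the helper against the gradient autograd produces when `retain_grad()` is called on the logits, which is another way to capture the reference gradient in run_mini_model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def vanilla_logits_grad(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: d(causal-LM cross-entropy)/d(logits) for the reference run.

    Works on a detached fp32 copy so the training graph is untouched. Assumes labels
    are padded and shifted (position t predicts token t+1) and mean reduction.
    """
    leaf = logits.detach().float().clone().requires_grad_(True)
    vocab = leaf.size(-1)
    shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].reshape(-1)
    loss = F.cross_entropy(leaf.reshape(-1, vocab), shift_labels, ignore_index=ignore_index)
    loss.backward()
    return leaf.grad


# Toy check against autograd on a stand-in lm_head.
torch.manual_seed(0)
lm_head = nn.Linear(16, 32)
hidden = torch.randn(2, 5, 16)
labels = torch.randint(0, 32, (2, 5))

logits = lm_head(hidden)
logits.retain_grad()  # keep the gradient of this non-leaf tensor
shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].reshape(-1)
F.cross_entropy(logits.reshape(-1, 32), shift_labels, ignore_index=-100).backward()

torch.testing.assert_close(vanilla_logits_grad(logits, labels), logits.grad)
```

In test_mini_model, this reference gradient would then be compared against the patched model's in-place buffer, for example with torch.testing.assert_close and the test's existing logits tolerances.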

@Tcc0403
Collaborator Author

Tcc0403 commented Jan 29, 2025

Makes sense to me. If logits.grad, the loss, and all other model parameters are close, I think that is enough to ensure the patching is correct. The only problem in this scenario is that the test should no longer be named with_logits (might be just me nitpicking).

@Tcc0403
Collaborator Author

Tcc0403 commented Feb 1, 2025

Update: The failure doesn't happen in bf16 because logits.float() is called before the cross entropy. As a result, LigerCrossEntropy saves the gradients into the fp32 tensor created by .float() instead of the original bf16 tensor. This also means we can further reduce memory allocation in the bf16 scenario, since LigerCrossEntropy always upcasts logits to fp32 inside the kernel anyway. See #406.
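
A short sketch of why the dtype matters. `.float()` is equivalent to `.to(torch.float32)`, which returns the tensor itself when it is already fp32 and allocates a new tensor otherwise, so an in-place write into the "upcast" logits only reaches the caller in the fp32 case.

```python
import torch

fp32 = torch.randn(2, 3)
bf16 = torch.randn(2, 3, dtype=torch.bfloat16)

# Already fp32: .float() hands back the same storage.
print(fp32.float().data_ptr() == fp32.data_ptr())  # True
# bf16: .float() allocates a fresh fp32 tensor.
print(bf16.float().data_ptr() == bf16.data_ptr())  # False

# An in-place write into the upcast tensor is visible to the caller only for fp32.
fp32.float().zero_()
bf16.float().zero_()
print(bool((fp32 == 0).all()), bool((bf16 == 0).all()))  # True False
```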

We can write a wrapper function that patches ForCausalLMLoss instead of just nn.functional.cross_entropy.
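
A rough sketch of what such a wrapper could look like in plain PyTorch. `causal_lm_loss_wrapper` and its parameters are hypothetical (not Liger's or transformers' API); it assumes the pad-and-shift-labels approach and uses `F.cross_entropy` as a placeholder for an in-place kernel such as LigerCrossEntropy. It deliberately skips the `.float()` upcast, so with an in-place kernel the caller's logits buffer would be overwritten by the gradients.

```python
from typing import Callable, Optional

import torch
import torch.nn.functional as F


def causal_lm_loss_wrapper(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    cross_entropy_fn: Callable[..., torch.Tensor] = F.cross_entropy,
    ignore_index: int = -100,
    num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
    """Hypothetical ForCausalLMLoss-style wrapper: shift labels, not logits, and
    do not upcast, so `cross_entropy_fn` sees the caller's logits buffer directly."""
    # Pad then shift labels so position t is scored against token t+1;
    # the last position of every sequence gets ignore_index.
    shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].reshape(-1)
    flat_logits = logits.view(-1, vocab_size)  # a view (requires contiguous logits), no copy
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = cross_entropy_fn(flat_logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss


# Matches the classic "shift the logits" formulation on a toy batch.
torch.manual_seed(0)
logits = torch.randn(2, 5, 7)
labels = torch.randint(0, 7, (2, 5))
ref = F.cross_entropy(logits[..., :-1, :].reshape(-1, 7), labels[..., 1:].reshape(-1))
print(torch.allclose(causal_lm_loss_wrapper(logits, labels, 7), ref, atol=1e-6))  # True
```

Shifting the labels instead of the logits costs only an extra column of integer labels, while the logits, by far the largest tensor, are passed through as a view.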

@DandinPower
Contributor

Sounds great! So by patching the original ForCausalLMLoss and removing the casting, we can avoid the additional FP32 logits allocation. However, this also means the logits will no longer be available in any scenario.

I just thought of something: reusing the logits' storage to hold their gradients is exactly the original design idea of LigerCrossEntropy. In the original scenario, however, the extra tensors created by shifting the logits and by casting meant we still retained the logits, even though the kernel overwrote its input in place.

Now, by patching the entire ForCausalLMLoss, minimizing memory usage, and no longer retaining the logits, this actually fits the original design better, right?

@Tcc0403
Collaborator Author

Tcc0403 commented Feb 1, 2025

Correct!

@jp1924
Contributor

jp1924 commented Feb 4, 2025

Um... Are you guys working on this?
Or should I open a PR separately?

@Tcc0403
Collaborator Author

Tcc0403 commented Feb 4, 2025

For the convergence test, #546 is ready to merge. We can open a separate PR for the LigerForCausalLMLoss part.
