The convergence test test_mini_models_with_logits is failing with the latest transformers #543
Comments
Because the only information we can get in If we don't want to store additional
Makes sense to me. With
Update: The failure doesn't happen in bf16 because logits.float() is called before CrossEntropy. Therefore, LigerCrossEntropy saves the gradients in the fp32 tensor created by .float() instead of in the original bf16 tensor. But this means we can further reduce memory allocation in the bf16 scenario, since LigerCrossEntropy always upcasts logits to fp32 inside the kernel anyway. See #406. We can write a wrapper function to patch ForCausalLMLoss.
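A rough sketch of what such a wrapper could look like (hedged: the function name liger_for_causal_lm_loss is hypothetical, the transformers.loss.loss_utils patch point and the Liger import path are assumptions that may differ across versions, and handling of num_items_in_batch is omitted):

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss  # assumed import path


def liger_for_causal_lm_loss(logits, labels, vocab_size, ignore_index=-100, **kwargs):
    # Unlike the stock ForCausalLMLoss, we do NOT call logits.float() here:
    # LigerCrossEntropy already upcasts to fp32 inside the kernel, so the extra
    # fp32 copy of the full logits tensor is never materialized.
    shift_labels = torch.nn.functional.pad(labels, (0, 1), value=ignore_index)[..., 1:]
    loss_fct = LigerCrossEntropyLoss(ignore_index=ignore_index)
    return loss_fct(
        logits.view(-1, vocab_size),            # gradients are written into this storage in-place
        shift_labels.reshape(-1).to(logits.device),
    )


# Hypothetical patch point; the real attribute name/location may differ:
# import transformers.loss.loss_utils as loss_utils
# loss_utils.ForCausalLMLoss = liger_for_causal_lm_loss
```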
Sounds great! So, by patching the original ForCausalLMLoss and removing the casting, we can avoid the additional fp32 logits allocation. However, this also means that the logits information will no longer be available in any scenario. One more thought: reusing the logits storage to hold its gradients is exactly the original design idea of LigerCrossEntropy. In the original scenario, though, because of the shifted logits and the extra tensors created by casting, we still retained the logits information even while overwriting those tensors in place. Now, by patching the entire ForCausalLMLoss, minimizing memory usage, and no longer retaining the logits information, this actually fits the original design better, right?
Correct!
Um... Are you guys working on this?
For the convergence test, #546 is ready to merge. We can open a separate PR for the LigerForCausalLMLoss part.
🐛 Describe the bug
In this convergence test, we compare losses and the last step's logits to make sure the two models (w/ and w/o monkey patch) produce similar results.
For logits, LigerCrossEntropy stores the logits' gradients in the logits tensor itself to save memory (an in-place operation). The reason we could still collect the actual logits for the convergence test is that shift_logits and shift_labels were created in the CausalLMLoss context, i.e., extra tensors/memory were allocated to hold the shifted ones before passing them to LigerCrossEntropy. In other words, the in-place operations were performed on these extra tensors, not on the original logits tensors.

huggingface/transformers#35646 found a way to avoid the extra memory allocations. However, it introduced a side effect for Liger's convergence test. Because of the in-place behavior, the test now fails since no new tensors are allocated to preserve the original logits: the test ends up comparing logits from the vanilla model against logits gradients from the patched model.
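To make the failure mode concrete, here is a minimal toy sketch (not Liger's actual kernel, just an assumed illustration) of a cross-entropy that writes its gradient back into the logits storage; after backward, reading the "logits" tensor actually returns gradients:

```python
import torch


class InplaceGradCrossEntropy(torch.autograd.Function):
    """Toy cross-entropy that reuses the logits storage to hold dL/dlogits."""

    @staticmethod
    def forward(ctx, logits, target):
        probs = torch.softmax(logits, dim=-1)
        n = target.shape[0]
        loss = -torch.log(probs[torch.arange(n), target]).mean()
        # Overwrite the logits storage with the gradient (softmax - onehot) / n
        # instead of allocating a separate buffer, mirroring the in-place trick.
        grad = probs
        grad[torch.arange(n), target] -= 1.0
        logits.data.copy_(grad / n)
        ctx.save_for_backward(logits)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        return grad * grad_output, None


logits = torch.randn(4, 10, requires_grad=True)
target = torch.tensor([1, 5, 2, 7])
loss = InplaceGradCrossEntropy.apply(logits, target)
loss.backward()
# `logits` no longer holds the original values; it now holds dL/dlogits,
# which is exactly what the convergence test ends up comparing against real logits.
print(logits)
```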
Reproduce
Install transformers >= 8ebe9d7
Run test_mini_models_with_logits
With some modifications, we can observe that the two tensors differ significantly; they are actually the logits and their gradients.
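For example, a hedged sketch of the kind of check one can add inside the test (the tensors below are illustrative stand-ins; in the real test they would come from the vanilla and patched models' outputs):

```python
import torch

# Illustrative stand-ins: in the real test these come from the two models' forward passes.
expected_logits = torch.randn(2, 8, 32)                        # vanilla model: actual logits
actual_logits = torch.softmax(expected_logits, -1) / (2 * 8)   # patched model: gradient-like values

# Roughly what the convergence test does; with the new in-place behavior it
# compares logits against their gradients, so the assertion fails by a wide margin.
print((expected_logits - actual_logits).abs().max())
torch.testing.assert_close(expected_logits, actual_logits, rtol=1e-5, atol=1e-5)  # raises AssertionError
```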
Versions
none