-
Notifications
You must be signed in to change notification settings - Fork 883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Llama/unshard on load #174
Conversation
I love this change!! Could you rebase on main and then I will review? |
Sure, after taking a brief look at the changes, I guess the new quantization makes this more complicated. It expects the entire model to be loaded. I could think of moving the the quantization into its own script, so you would run
If we really want to do quantize in convert.py it would require the unsharding there I guess. I don't know enough about the new Given that quantization should enable smaller machines, it would be nice if we could do the quantization without merging all the unquantized weights in memory first. Not sure though, if there is a clever way to achieve this. |
41c8efe
to
dff87bc
Compare
@awni I think I found a good way to refactor it to support quantize in convert.py. It will still unshard for quantization, but keeps shard loading and conversion lazy and memory friendly. Looking forward to your review. |
@dastrobu I like where this is going, but I suggest we reorganize the computation to avoid the need to unshard in the final loading script. Here's my suggestion:
Does that make sense? |
So your changes to |
a7d08be
to
7f95a25
Compare
@awni yes, it does. Thanks for your review and suggestions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great and much simpler, thanks for adding this!! I left a couple of comments, please address then we can merge.
llms/llama/convert.py
Outdated
@@ -140,6 +139,21 @@ def quantize(weights, config, args): | |||
return quantized_weights, quantized_config | |||
|
|||
|
|||
def make_shards(weights: dict, max_file_size_GiB: int = 15): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: max_file_size_gb
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as we are using 2**30 = GiB I'd suggest to use: max_file_size_gibibyte
as I find max_file_size_gb
wrong and max_file_size_gib
unreadable.
llms/llama/convert.py
Outdated
shards = [] | ||
shard, shard_size = {}, 0 | ||
for k, v in weights.items(): | ||
estimated_size = len(v.flatten()) * v.dtype.itemsize |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you check this with quantization? I think this line might break as dtype
doesn't have an itemsize
?
We really ought to expose nbytes
in python.. for consistency with numpy. For now you can do:
v.size * v.dtype.size if isintance(v, mx.array) else v.nbytes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't aware that quantization stores mx arrays already...
your suggestion seems to be a good intermediate solution. Exposing nbytes
sounds even better, I'll create a PR, sounds like a small change.
7f95a25
to
a11e3f8
Compare
Thanks, should be all fixed now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks!!
Similar to #92 I noticed that converting the llama-2 70b models takes quite a bit of RAM (succeeded at around 140GB on an 128GB machine with swaping).
However, the resulting huge weights files are still very hard to handle (e.g. upload them to HF is impossible and would required extra steps).
So I think changing the conversion algorithm a bit: keep the shards on model conversion and then unshard the weights on loading. This would be more RAM efficient and file size friendly.
This PR
I tested locally with tiny llama, llama-2-13b-chat and llama-2-70b-chat. The largest 70b model now takes around 16GB on average while converting, with peaks around 32GB. On inference it still requires around 128 GB, which makes sense, given that weights are around 128 GB on disk. With a bit of swapping one can run it on a 128GB machine, though not really productive on an Apple M3 Max 128GB: