Enables the huggingface checkpoint conversion to MaxText orbax. #1291
Conversation
Many lint-only changes in this PR. Can those be skipped? Perhaps run "bash code_style.sh" under the maxtext dir.
Compare: b28d784 to 36bdf05
True! Updated and removed those changes.
@richjames0 had this PR: #1028 to support safetensors, but looks like this is doing more?
Yes, I'm aware of this PR, but it doesn't work when I try that conversion function. More logic is required to handle the discrepancies between the Huggingface model structure and MaxText. That's why I have this PR.
Gotcha! Yeah, I was pointed to this older PR for converting an HF ckpt, but we are seeing major issues with loss when I run with it. Your comment more or less confirms the ckpt conversion was the problem. But it'll be good once this PR lands and merges.
@@ -168,6 +168,29 @@ def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
  }


def _hf_to_maxtext_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
Maybe it's less error-prone to have a function here which reverses the key/value pairs of the previous dict.
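For reference, a minimal sketch of that suggestion, assuming _hf_mapping returns a flat dict with unique string keys and values; the helper below is hypothetical and only illustrates deriving one mapping from the other.

# Sketch only: invert the existing MaxText -> HF mapping instead of maintaining a second table.
def _hf_to_maxtext_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
  maxtext_to_hf = _hf_mapping(layer_idx, expert_idx)
  return {hf_key: maxtext_key for maxtext_key, hf_key in maxtext_to_hf.items()}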
I'm actually considering deleting that mapping because it's not working. Will refactor in the follow-up PRs.
MaxText/llama_or_mistral_ckpt.py
Outdated
  return x


def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info):
This function and the following old function look very similar except for the loading function. Could they be combined?
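A rough sketch of how the two converters might be combined, assuming they differ only in how the raw checkpoint is read; the loader names and stub contents below are hypothetical placeholders, not the PR's actual functions.

import numpy as np

def _load_pytorch_checkpoint(path):
  # Hypothetical stub: the real script reads .pth shards here.
  return {"tok_embeddings.weight": np.zeros((4, 4))}

def _load_huggingface_checkpoint(path):
  # Hypothetical stub: the real script reads HF safetensors shards here.
  return {"model.embed_tokens.weight": np.zeros((4, 4))}

def convert_to_jax_weights(base_model_path, huggingface_ckpt=False):
  loader = _load_huggingface_checkpoint if huggingface_ckpt else _load_pytorch_checkpoint
  raw_weights = loader(base_model_path)
  # ...shared reshape/transpose logic for both checkpoint formats would follow here...
  return raw_weights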
Left some comments, but those could be addressed in a follow-up CL.
Thanks for the change! The old PR worked fine on my side a while ago; I'd like to take a look at this PR.
Does it convert the Huggingface checkpoint? The name mapping has the wrong direction; it's from MaxText to Huggingface. I would be surprised if it worked before.
Yes, it works. @richjames0 and I were working on those ckpts from huggingface. To be specific, we tested:
What's the issue you met? Keys not found?
@@ -0,0 +1,15 @@
"""
Copyright 2023 Google LLC
Nit: 2025
@@ -77,7 +77,7 @@ def get_data(golden_data, golden_data_index, config):
  return ids, decoder_segment_ids, decoder_positions, logits


def main(config, test_args):
def main(config, test_args):  # pylint: disable=W0621
Will bash code_style.sh work for those?
Yeah, I used bash code_style.sh and it reports that test_args is redefined within the main function. That's weird, since functionally everything is working fine, so I have to disable it for now.
Oh.... test_args shouldn't be redefined and overwritten. Probably we shouldn't disable it, but find the cause?
It's not redefined anywhere. That's the mystery.
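For context on why pylint can still flag this even though nothing is redefined inside main: W0621 (redefined-outer-name) fires when a function parameter shadows a name defined at module scope, e.g. in the __main__ block. A minimal, hypothetical reproduction (not the repo's actual script):

import argparse

# The module-level names created under __main__ below are what W0621 complains about:
# the parameters of main() shadow them, even though main never redefines anything itself.
def main(config, test_args):  # pylint: disable=W0621
  print(config, test_args)

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--config", default="base.yml")
  config, test_args = parser.parse_known_args()
  main(config, test_args)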
Looks great, thanks Lance for adding this script. I had a few minor comments.
  wk = np.reshape(wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])
  wv = np.reshape(wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])

  if model_size[:8] == "llama3.1":
Is this logic applicable only for the 3.1 version, or for all versions after 3.1 as well?
We only have 3.1 for now. Yes, it should work for, say, 3.3.
Is this change needed later? Update it to 3.3?
Basically we are still using 3.1 to represent all the versions after 3.1. We can refactor that later to be more accurate. But right now it's not only this place; other code in the codebase also depends on 3.1 to recognize the pattern.
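In other words, the string prefix acts as a family marker rather than an exact version. A tiny illustration of that assumption (the model_size values here are examples only):

# "llama3.1" is treated as the family marker for 3.1-and-later checkpoints,
# so a 3.3 checkpoint would currently be converted with a llama3.1-* model_size.
for model_size in ("llama2-7b", "llama3.1-8b", "llama3.1-70b"):
  if model_size[:8] == "llama3.1":
    print(model_size, "-> takes the 3.1+ code path")
  else:
    print(model_size, "-> takes the older code path")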
MaxText/llama_or_mistral_ckpt.py
Outdated
self_attention["value"]["kernel"][layer_idx, ...] = wv # pylint: disable=E1137 | ||
self_attention["out"]["kernel"][layer_idx, ...] = w_post # pylint: disable=E1137 | ||
|
||
self_attention["query"]["kernel"] = np.transpose(self_attention["query"]["kernel"], axes=(1, 0, 2, 3)) |
Nit: can you add a comment about the hardcoded axes (1, 0, 2, 3) and what they refer to in MaxText/JAX?
Done.
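For readers of this thread, a standalone sketch of what such a comment documents; the shapes below are assumptions for illustration and are not taken from the PR:

import numpy as np

# Assumed stacked layout after the per-layer loop: [num_layers, embed_dim, num_heads, head_dim].
num_layers, embed_dim, num_heads, head_dim = 2, 8, 4, 2
query_kernel = np.zeros((num_layers, embed_dim, num_heads, head_dim))

# axes=(1, 0, 2, 3) swaps the first two axes, moving the layer axis from position 0
# to position 1 while leaving the head and head_dim axes in place.
query_kernel = np.transpose(query_kernel, axes=(1, 0, 2, 3))
assert query_kernel.shape == (embed_dim, num_layers, num_heads, head_dim)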
layer_weight["post_self_attention_layer_norm"]["scale"][layer_idx, ...] = post_self_attention_layernorm # pylint: disable=E1137 | ||
|
||
layer_weight["pre_self_attention_layer_norm"]["scale"] = np.transpose( | ||
layer_weight["pre_self_attention_layer_norm"]["scale"], axes=(1, 0) |
Please add a comment about what the (1, 0) axes refer to here.

  if num_experts is None:
    # swap the layer index
    layer_weight["mlp"]["wi_0"]["kernel"] = np.transpose(layer_weight["mlp"]["wi_0"]["kernel"], axes=(1, 0, 2))
Please add documentation about the hardcoded axes for posterity...
MaxText/llama_or_mistral_ckpt.py
Outdated
  if huggingface_ckpt:
    return _convert_huggingface_to_jax_weights(base_model_path, model_size, model_params, mem_info)
This is probably for future PRs, but can we also consolidate the rest of the logic into _convert_pytorch_to_jax_weights() to make it cleaner?
Good point! Updated!
As far as I can see, llama_or_mistral_ckpt.py doesn't have a safetensors loader and can only load .pth files. How do you load a safetensors checkpoint?
You should be able to find this one in https://github.com/AI-Hypercomputer/maxtext/pull/1028/files
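For anyone landing on this thread, a minimal sketch of reading a safetensors shard into NumPy arrays, assuming the safetensors package is installed and a local model.safetensors file exists; this is not the exact loader used in PR #1028:

from safetensors.numpy import load_file

# Load every tensor in the shard as a NumPy array keyed by its HF parameter name.
hf_weights = load_file("model.safetensors")  # hypothetical local path
for name, tensor in hf_weights.items():
  print(name, tensor.shape, tensor.dtype)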
Sorry that we haven't merged this PR in time due to a minor comment. Please don't merge until we are aligned. Due to the urgency, I am OK with saving this in a branch or as a copy in a separate file.
@@ -518,6 +519,7 @@ inference_metadata_file: "" # path to a json file
inference_server: "MaxtextInterleavedServer" # inference server to start
inference_benchmark_test: False
enable_model_warmup: False
hf_model_path: "" # inference checkpoint correctness verification
We don't need to add it to base.yml (if it's only used in llama_or_mistral_ckpt.py)?
When I run the script, if I don't add it to base.yml, it complains it's configured in the command line but not in the config.
Interesting! How come? max_kl_div, atol, etc. are not in base.yml either.
Discussed offline and agreed to remove it from here. It's excluded in the test. Should use "--" to pass in the config.

def test_huggingface_to_maxtext_back_to_huggingface_flow():
  base_num_query_heads = 16
  head_dim = 32
Why are those 2 configs defined/hardcoded in the test?
I moved this test to a separate file. Basically this is a unit test for the permutation function. So everything else is hardcoded.
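As a companion to that explanation, a self-contained round-trip check in the same spirit, with the same hardcoded sizes; the real test's permutation helpers and shapes may differ:

import numpy as np

def test_permutation_round_trip():
  base_num_query_heads, head_dim = 16, 32  # hardcoded because only the permutation is under test
  weights = np.arange(base_num_query_heads * head_dim * head_dim, dtype=np.float32).reshape(
      base_num_query_heads * head_dim, head_dim
  )
  # Applying the same axis swap twice must restore the original array.
  permuted = np.transpose(weights, axes=(1, 0))
  restored = np.transpose(permuted, axes=(1, 0))
  np.testing.assert_array_equal(weights, restored)

test_permutation_round_trip()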
Description
This PR converts the huggingface llama checkpoint to MaxText orbax format. The purpose of the PR is to convert the Deepseek distilled checkpoint to MaxText. The original llama_or_mistral_ckpt.py only works on the Pytorch checkpoint, but not the Huggingface checkpoint. This PR fixes the workflow and documents the whole process.
Right now the accuracy still has some issues and needs further debugging.
Tests
Ran through the workflow multiple times, and the generated checkpoint works.
Checklist
Before submitting this PR, please make sure (put X in square brackets):