
Enables the huggingface checkpoint conversion to MaxText orbax. #1291

Merged
merged 19 commits into main from lance-deepseek on Feb 21, 2025

Conversation

wang2yn84
Collaborator

Description

This PR converts Hugging Face Llama checkpoints to the MaxText Orbax format. The goal is to convert the DeepSeek distilled checkpoint to MaxText. The original llama_or_mistral_ckpt.py only works on PyTorch checkpoints, not Hugging Face checkpoints. This PR fixes the workflow and documents the whole process.

Right now the accuracy still has some issues and needs further debugging.

Tests

Ran through the workflow multiple times; the generated checkpoint works.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@singh-mitali left a comment

Many lint-only changes in this PR. Can those be skipped? Perhaps run "bash code_style.sh" under the maxtext dir.

@wang2yn84
Collaborator Author

Many lint-only changes in this PR. Can those be skipped? Perhaps run "bash code_style.sh" under the maxtext dir.

True! Updated and removed those changes.

@anfals
Collaborator

anfals commented Feb 20, 2025

@richjames0 had this PR: #1028 to support safetensors, but it looks like this is doing more?

@wang2yn84
Collaborator Author

@richjames0 had this PR: #1028 to support safetensors, but it looks like this is doing more?

Yes, I'm aware of that PR, but it doesn't work when I try its conversion function. More logic is required to handle the discrepancies between the Hugging Face model structure and MaxText's. That's why I have this PR.

@anfals
Collaborator

anfals commented Feb 20, 2025

@richjames0 had this PR: #1028 to support safetensors, but it looks like this is doing more?

Yes, I'm aware of that PR, but it doesn't work when I try its conversion function. More logic is required to handle the discrepancies between the Hugging Face model structure and MaxText's. That's why I have this PR.

Gotcha! Yeah, I was pointed to that older PR for converting an HF ckpt, but we are seeing major issues with loss when I run with it. Your comment more or less confirms that the checkpoint conversion was the problem. It'll be good once this PR lands and merges.

@@ -168,6 +168,29 @@ def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
}


def _hf_to_maxtext_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
Collaborator

It might be less error-prone to have a function here that reverses the keys/values of the previous dict.
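For illustration, a minimal sketch of the reversal suggested here, assuming the forward mapping is one-to-one (names are illustrative, not from the PR):

def _reverse_mapping(forward: dict) -> dict:
  """Invert a MaxText->HF name mapping into an HF->MaxText mapping."""
  reversed_mapping = {v: k for k, v in forward.items()}
  # A dict comprehension silently drops colliding keys, so check the inversion is lossless.
  assert len(reversed_mapping) == len(forward), "forward mapping is not one-to-one"
  return reversed_mapping

# Hypothetical usage, deriving the reverse table from the existing forward table:
# hf_to_maxtext = _reverse_mapping(_hf_mapping(layer_idx, expert_idx))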

Collaborator Author

I'm actually considering deleting that mapping because it's not working. Will refactor in follow-up PRs.

return x


def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info):
Collaborator

This function and the following old function look very similar except for the loading function. Could they be combined?
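One way to combine them, assuming the two functions differ only in how the raw checkpoint is loaded, is to inject the loader as a parameter; a sketch with illustrative names, not the PR's actual refactor:

def _convert_checkpoint_to_jax_weights(base_model_path, model_size, model_params, mem_info, load_fn):
  """Shared conversion body; only the checkpoint-loading step differs per format."""
  chkpt_vars = load_fn(base_model_path)  # e.g. a .pth loader or a safetensors loader
  # ... the common weight mapping / reshaping logic shared by both converters would follow here ...
  return chkpt_vars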

Collaborator

@singh-mitali left a comment

Left some comments, but those could be addressed in a follow-up CL.

@RissyRan
Collaborator

Thanks for the change! The old PR worked fine on my side a while ago; I'd like to take a look at this PR.

@wang2yn84
Collaborator Author

Thanks for the change! The old PR worked fine on my side a while ago; I'd like to take a look at this PR.

Does it convert the Hugging Face checkpoint? The name mapping has the wrong direction; it's from MaxText to Hugging Face. I would be surprised if it worked before.

@RissyRan
Collaborator

Thanks for the change! The old PR worked fine on my side a while ago; I'd like to take a look at this PR.

Does it convert the Hugging Face checkpoint? The name mapping has the wrong direction; it's from MaxText to Hugging Face. I would be surprised if it worked before.

Yes, it works. @richjames0 and I were working on those checkpoints from Hugging Face. To be specific, we tested:

  1. download safetensors from https://huggingface.co/mistralai/Mixtral-8x22B-v0.1
  2. run the script with that PR

What issue are you seeing? Keys not found?

@@ -0,0 +1,15 @@
"""
Copyright 2023 Google LLC
Collaborator

Nit: 2025

@@ -77,7 +77,7 @@ def get_data(golden_data, golden_data_index, config):
return ids, decoder_segment_ids, decoder_positions, logits


def main(config, test_args):
def main(config, test_args): # pylint: disable=W0621
Collaborator

Will bash code_style.sh work for those?

Collaborator Author

Yeah, I used bash code_style.sh and it reports that test_args is redefined within the main function. That's weird, and functionally everything works fine, so I had to disable it for now.

Collaborator

Oh... test_args shouldn't be redefined and overwritten. Probably we shouldn't disable the check but find the cause?

Collaborator Author

It's not redefined anywhere; that's the mystery.
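For context, pylint W0621 (redefined-outer-name) typically fires when a module-level name shadows a function parameter, even if the function never touches the outer variable; a minimal, hypothetical reproduction:

def main(config, test_args):  # W0621: Redefining name 'test_args' from outer scope
  print(config, test_args)

if __name__ == "__main__":
  # The module-level `test_args` below is what pylint treats as the "outer scope" name.
  test_args = []  # placeholder for the parsed CLI args in the real script
  main("config", test_args)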

Collaborator

@vipannalla left a comment

Looks great, thanks Lance for adding this script. I had a few minor comments.

wk = np.reshape(wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])
wv = np.reshape(wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])

if model_size[:8] == "llama3.1":
Collaborator

Is this logic applicable only to version 3.1, or to all versions after 3.1 as well?

Collaborator Author

We only have 3.1 for now. Yes, it should work for, say, 3.3.

Collaborator

Is this change needed later? Update it to 3.3?

Collaborator Author

Basically we are still using 3.1 to represent all versions after 3.1. We can refactor that later to be more accurate, but right now it's not only this place; other code in the codebase also depends on 3.1 to recognize the pattern.
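Concretely, the branch keys off a string prefix of the configured model_size, so a later version would need to be registered under a 3.1-style name to take this path (the model name below is illustrative):

model_size = "llama3.1-70b"  # a 3.3 checkpoint would also be configured under a 3.1-style name
if model_size[:8] == "llama3.1":
  pass  # the 3.1-specific handling applies here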

self_attention["value"]["kernel"][layer_idx, ...] = wv # pylint: disable=E1137
self_attention["out"]["kernel"][layer_idx, ...] = w_post # pylint: disable=E1137

self_attention["query"]["kernel"] = np.transpose(self_attention["query"]["kernel"], axes=(1, 0, 2, 3))
Collaborator

Nit: can you add a comment about the hardcoded axes (1, 0, 2, 3) and what they refer to in MaxText/JAX?

Collaborator Author

Done.
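For reference, a small illustration of what the hardcoded axes do; the dimension names are assumptions about the stacked layout, not taken from the script:

import numpy as np

# Assume the stacked query kernel is laid out as [layer, embed, num_heads, head_dim].
kernel = np.zeros((4, 8, 2, 16))

# axes=(1, 0, 2, 3) swaps the first two dimensions, producing
# [embed, layer, num_heads, head_dim]; the 2-D scale arrays use axes=(1, 0) analogously.
kernel = np.transpose(kernel, axes=(1, 0, 2, 3))
print(kernel.shape)  # (8, 4, 2, 16)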

layer_weight["post_self_attention_layer_norm"]["scale"][layer_idx, ...] = post_self_attention_layernorm # pylint: disable=E1137

layer_weight["pre_self_attention_layer_norm"]["scale"] = np.transpose(
layer_weight["pre_self_attention_layer_norm"]["scale"], axes=(1, 0)
Collaborator

Please add a comment about what the (1, 0) axes refer to here.


if num_experts is None:
# swap the layer index
layer_weight["mlp"]["wi_0"]["kernel"] = np.transpose(layer_weight["mlp"]["wi_0"]["kernel"], axes=(1, 0, 2))
Collaborator

Please add documentation about the hardcoded axes for posterity...

Comment on lines 475 to 476
if huggingface_ckpt:
return _convert_huggingface_to_jax_weights(base_model_path, model_size, model_params, mem_info)
Collaborator

This is probably for future PRs, but can we also consolidate the rest of the logic into _convert_pytorch_to_jax_weights() to make it cleaner?

Collaborator Author

Good point! Updated!
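A sketch of that consolidation, using the function names mentioned in this thread but with assumed signatures, not the merged code:

def convert_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info):
  """Thin dispatcher: pick the converter based on the source checkpoint format."""
  if huggingface_ckpt:
    return _convert_huggingface_to_jax_weights(base_model_path, model_size, model_params, mem_info)
  # The original .pth path lives in its own helper so this entry point stays small.
  return _convert_pytorch_to_jax_weights(base_model_path, model_size, model_params, mem_info)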

@wang2yn84
Collaborator Author

Yes, it works. @richjames0 and I were working on those checkpoints from Hugging Face. To be specific, we tested:

  1. download safetensors from https://huggingface.co/mistralai/Mixtral-8x22B-v0.1
  2. run the script with that PR

What issue are you seeing? Keys not found?

As far as I can see, llama_or_mistral_ckpt.py doesn't have a safetensors loader and can only load .pth files. How do you load a safetensors checkpoint?

@RissyRan
Collaborator

As far as I can see, llama_or_mistral_ckpt.py doesn't have a safetensors loader and can only load .pth files. How do you load a safetensors checkpoint?

You should be able to find it in https://github.com/AI-Hypercomputer/maxtext/pull/1028/files:

# From PR #1028; safe_open comes from the safetensors library, while
# max_logging and _HFNamespaceMapper are helpers defined in MaxText / that PR.
from safetensors import safe_open

def load_safetensors_checkpoint(ckpt_paths):
  chkpt_vars_raw = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    max_logging.log(f"Loading checkpoint path {i+1} of {len(ckpt_paths)} ...")
    with safe_open(ckpt_path, framework="pt") as f:
      for k in f.keys():
        # Each safetensors shard should contribute unique keys.
        assert k not in chkpt_vars_raw
        chkpt_vars_raw[k] = f.get_tensor(k)
  # Wrap the raw dict in the namespace mapper defined alongside this function.
  chkpt_vars = [_HFNamespaceMapper(chkpt_vars_raw)]
  return chkpt_vars

Sorry that we haven't merged that PR in time due to a minor comment. Please don't merge until we are aligned. Due to the urgency, I am OK with saving this in a branch or a copy in a separate file.

@@ -518,6 +519,7 @@ inference_metadata_file: "" # path to a json file
inference_server: "MaxtextInterleavedServer" # inference server to start
inference_benchmark_test: False
enable_model_warmup: False
hf_model_path: "" # inference checkpoint correctness verification
Collaborator

We don't need to add this to base.yml (if it's only used in llama_or_mistral_ckpt.py)?

Collaborator Author

When I run the script, if I don't add it to base.yml, it complains that it's configured on the command line but not in the config.

Collaborator

Interesting! How come? max_kl_div, atol, etc. are not in base.yml either.

Collaborator Author

Discussed offline and agreed to remove it from here. It's excluded in the test; we should use "--" to pass in the config.



def test_huggingface_to_maxtext_back_to_huggingface_flow():
base_num_query_heads = 16
head_dim = 32
Collaborator

Why are those two configs defined/hardcoded in the test?

Collaborator Author

I moved this test to a separate file. Basically this is a unit test for the permutation function, so everything else is hardcoded.
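For illustration, a self-contained test in the same spirit, with hardcoded shapes and hypothetical permute/unpermute helpers standing in for the functions under test:

import numpy as np

def permute_to_maxtext(w, num_heads, head_dim):
  # Hypothetical stand-in: [num_heads * head_dim, embed] -> [embed, num_heads, head_dim]
  return np.transpose(w, (1, 0)).reshape(w.shape[1], num_heads, head_dim)

def unpermute_to_huggingface(w):
  # Inverse of the above: [embed, num_heads, head_dim] -> [num_heads * head_dim, embed]
  embed, num_heads, head_dim = w.shape
  return np.transpose(w.reshape(embed, num_heads * head_dim), (1, 0))

def test_huggingface_to_maxtext_back_to_huggingface_flow():
  base_num_query_heads, head_dim, embed = 16, 32, 64  # hardcoded, as in the unit test
  w = np.random.normal(size=(base_num_query_heads * head_dim, embed))
  roundtrip = unpermute_to_huggingface(permute_to_maxtext(w, base_num_query_heads, head_dim))
  np.testing.assert_allclose(w, roundtrip)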

@copybara-service bot merged commit e7038bc into main on Feb 21, 2025
12 of 20 checks passed
@copybara-service bot deleted the lance-deepseek branch on February 21, 2025 at 21:04