Skip to content

Commit

Permalink
[config] Add a llama2-7b-int8 test config for decode
Browse files Browse the repository at this point in the history
Also add a small accuracy check on the prefix of the generated text.
  • Loading branch information
xy12181 committed Feb 19, 2025
1 parent 8350fef commit c842825
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
29 changes: 28 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,34 @@
"version": "0.2.0",
"configurations": [
{
"name": "Debug MaxText Decode",
"name": "Debug MaxText Decode (llama2-7b-int8)",
"type": "python",
"request": "launch",
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"program": "${workspaceFolder}/MaxText/decode.py",
"args": ["MaxText/configs/base.yml",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"dataset_path=gs://test-maxtext-dataset",
"model_name=llama2-7b",
"load_parameters_path=gs://msingh-bkt/checkpoints/quant_llama2-7b-chat/20241120034012/int8_",
"tokenizer_path=assets/tokenizer.llama2",
"per_device_batch_size=8",
"max_prefill_predict_length=8",
"max_target_length=20",
"weight_dtype=bfloat16",
"ici_fsdp_parallelism=1",
"ici_tensor_parallelism=-1",
"scan_layers=false",
"quantization=int8",
"checkpoint_is_quantized=true",
"attention=dot_product",
"autoregressive_decode_assert=travel and explore new places"]
},
{
"name": "Debug MaxText Decode (Test)",
"type": "python",
"request": "launch",
"console": "integratedTerminal",
Expand Down
7 changes: 3 additions & 4 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ def main(argv: Sequence[str]) -> None:
output = tokenizer_model.decode(results)
print(f"Input `{text}` -> `{output}`")

if config.autoregressive_decode_assert != "":
assert (
output == config.autoregressive_decode_assert
), f"generated text mismatch {output=} {config.autoregressive_decode_assert=}"
assert output.startswith(
config.autoregressive_decode_assert
), f"generated text mismatch {output=}, {config.autoregressive_decode_assert=}"


def validate_config(config):
Expand Down

0 comments on commit c842825

Please sign in to comment.