more nits
awni committed Feb 10, 2025
1 parent 6e9542a commit 2bfec57
Showing 2 changed files with 9 additions and 20 deletions.
25 changes: 6 additions & 19 deletions llms/mlx_lm/LORA.md
@@ -76,26 +76,13 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

-### Input Masking
-There are custom functions for masking the sequence of tokens associated with the `prompt` in a completion dataset
-during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
-with masked input sequences, use the `--mask-inputs` argument.
-
-This functionality expects a ```response_template``` parameter in the configuration that is either a string representing
-a [string that indicate the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
-or its corresopnding tokens. This is used to create the mask that excludes the tokens associated from the rest of
-the sequence from loss calculations. For example (ChatML):
-
-```yaml
-response_template: "<|im_start|>assistant"
-```
-or (for the corresponding tokens of Gemma's response template)
-```yaml
-response_template: [106, 2516]
-```
+#### Prompt Masking
+
+The default training computes a loss for every token in the sample. You can
+ignore the prompt and compute loss for just the completion by passing
+`--mask-prompt`. Note this is only supported for `chat` and `completion`
+datasets. For `chat` datasets the final message in the message list is
+considered the completion. See the [dataset section](#Data) for more details.

### Evaluate

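To make the new doc wording concrete: with `--mask-prompt`, only completion tokens should contribute to the training loss, so the model is not penalized for failing to reproduce the prompt. Below is a minimal, hypothetical sketch of that idea in MLX; the function and argument names are illustrative and this is not the repo's actual trainer code.

```python
# Hypothetical sketch of a prompt-masked loss; not the repo's implementation.
import mlx.core as mx
import mlx.nn as nn


def masked_loss(model, tokens, prompt_lengths):
    # tokens: (batch, seq) array of token ids for prompt + completion.
    # prompt_lengths: (batch,) array with the number of prompt tokens per row.
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    logits = model(inputs)
    per_token = nn.losses.cross_entropy(logits, targets, reduction="none")
    # Target position i predicts original token i + 1, which belongs to the
    # completion when i + 1 >= prompt_length, i.e. i >= prompt_length - 1.
    positions = mx.arange(targets.shape[1])[None, :]
    mask = (positions >= (prompt_lengths[:, None] - 1)).astype(per_token.dtype)
    # Average only over completion positions.
    return (per_token * mask).sum() / mask.sum()
```

Averaging over just the unmasked positions is what keeps prompt tokens from influencing the gradient, which is the behavior the new `--mask-prompt` flag documents.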
4 changes: 3 additions & 1 deletion llms/mlx_lm/tuner/datasets.py
@@ -185,6 +185,8 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
    import datasets

+    mask_prompt = getattr(args, "mask_prompt", False)
+
    def create_hf_dataset(
        dataset_name,
        text_feature,
@@ -201,7 +203,7 @@ def create_hf_dataset(
        )
        if prompt_feature and completion_feature:
            return CompletionsDataset(
-                data, tokenizer, prompt_feature, completion_feature, mask_prompt
+                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
            )
        elif chat_feature:
            return ChatDataset(
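On the `datasets.py` side, the new `getattr` line makes `mask_prompt` default to `False` when the attribute is absent from the parsed arguments, and the second hunk passes `ds` (presumably the dataset loaded just above) rather than `data` into `CompletionsDataset`. A small, self-contained illustration of the `getattr` fallback; the argument container here is hypothetical, not the repo's actual args object.

```python
# Illustrates the getattr default used above; SimpleNamespace stands in for
# the parsed args/config object.
from types import SimpleNamespace

args = SimpleNamespace()  # config parsed without the flag
print(getattr(args, "mask_prompt", False))  # False: masking stays off by default

args = SimpleNamespace(mask_prompt=True)  # config that enables prompt masking
print(getattr(args, "mask_prompt", False))  # True
```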
