Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Completion only fine-tuning of instruction models with collections of HF datasets #1103

Merged
merged 29 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
604be3c
Add input_masked loss calculation and batching w/ padding
chimezie Jun 7, 2024
79a0427
Replace iterate_input_masked_batches with iterate_delineated_batches,…
chimezie Nov 5, 2024
84fc1bd
Minor documentation update
chimezie Nov 5, 2024
27cd361
Updates CL lora tuner with input masking that uses default_loss (and …
chimezie Nov 6, 2024
30fd5af
Fix variable reference
chimezie Nov 6, 2024
02abeea
Update sublist search and calculation of input id length
chimezie Nov 6, 2024
71d9f8c
Fix
chimezie Nov 7, 2024
3496cbe
Add input masking for fine-tuning in documentation
chimezie Nov 10, 2024
14a75f3
Generalize HF datasets to a collection of HF dataasets via `datasets`…
chimezie Nov 4, 2024
8ec802f
Updates to LoRA documentation
chimezie Nov 4, 2024
214c79b
Fixes to config format in documentattion
chimezie Nov 4, 2024
387c45e
Fixes to references to hf_datasets
chimezie Nov 4, 2024
78c33e5
Fix keyword argument invokation
chimezie Nov 4, 2024
a4a86ad
Fix iteration over HF dataset collection
chimezie Nov 4, 2024
a5b866c
Fix index calculation
chimezie Nov 4, 2024
4890870
Add ability to fetch raw prompt and completion text from completion d…
chimezie Nov 6, 2024
69282ab
Minor fix
chimezie Nov 6, 2024
3f08dfc
Don't dupe BOS
chimezie Nov 10, 2024
5ce58e4
Update documentation
chimezie Nov 10, 2024
f989401
Default for hf_datasets configuration
chimezie Dec 6, 2024
6df285e
Synch use of special tokens with iterate_batches
chimezie Dec 6, 2024
cb87f6f
Add response template (or token) argument
chimezie Dec 8, 2024
95e1f22
Incorporate use of response template for completion masking
chimezie Dec 8, 2024
7989d0a
Move response template to LoRA configuration
chimezie Dec 8, 2024
b9748e9
Generalize the get_item method to all CompletionDatasets
chimezie Dec 9, 2024
6ace6dc
simplify collections
awni Feb 9, 2025
6e9542a
put offset in prompt, simplify
awni Feb 10, 2025
bb2c8bc
more nits
awni Feb 10, 2025
eda597b
simplify
awni Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion llms/mlx_lm/LORA.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

#### Prompt Masking

By default, training computes a loss for every token in the sample. You can
ignore the prompt and compute the loss for just the completion by passing
`--mask-prompt`. Note that this is only supported for `chat` and `completions`
datasets. For `chat` datasets, the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.

### Evaluate

To compute test set perplexity use:
Expand Down Expand Up @@ -290,11 +298,27 @@ hf_dataset:

- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
dataset. Use `chat_feature` to specify the key for a `chat` dataset.

- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.

You can specify multiple Hugging Face datasets by providing a list of records,
each with the same structure as above. For example:

```yaml
hf_dataset:
- name: "Open-Orca/OpenOrca"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
prompt_feature: "question"
completion_feature: "response"
- name: "trl-lib/ultrafeedback_binarized"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
chat_feature: "chosen"
```

- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

Expand Down
9 changes: 9 additions & 0 deletions llms/mlx_lm/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)

parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)

parser.add_argument(
"--num-layers",
type=int,
Expand Down Expand Up @@ -219,6 +227,7 @@ def train_model(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)

# Train model
train(
model=model,
Expand Down
6 changes: 6 additions & 0 deletions llms/mlx_lm/tokenizer_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from functools import partial
from typing import List

from transformers import AutoTokenizer

Expand Down Expand Up @@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
detokenizer_class,
eos_token_ids=eos_token_ids,
)


def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
    """Return ``sequence`` without a leading ``bos`` or a trailing ``eos``.

    At most one element is stripped from each end. An empty sequence, or one
    that becomes empty once the leading ``bos`` is removed, yields an empty
    result instead of raising ``IndexError``.

    Args:
        sequence: The items to trim; it is never modified in place.
        bos: Value compared against the first element.
        eos: Value compared against the last element (after any ``bos``
            removal).

    Returns:
        The trimmed sequence. When nothing needs stripping, the input object
        itself is returned.
    """
    # Guard each index access: ``sequence[0]`` / ``sequence[-1]`` raise
    # IndexError on an empty sequence.
    if sequence and sequence[0] == bos:
        sequence = sequence[1:]
    # Re-check emptiness here: an input of just ``[bos]`` is empty by now.
    if sequence and sequence[-1] == eos:
        sequence = sequence[:-1]
    return sequence
184 changes: 116 additions & 68 deletions llms/mlx_lm/tuner/datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import itertools
import json
import types
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from transformers import PreTrainedTokenizer

Expand Down Expand Up @@ -34,14 +36,24 @@ class ChatDataset:
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""

def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
self._data = [
tokenizer.apply_chat_template(
d["messages"],
tools=d.get("tools", None),
)
for d in data
]
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
chat_key: str = "messages",
mask_prompt: bool = False,
):
self._data = []
for d in data:
messages = d[chat_key]
tools = d.get("tools", None)
tokens = tokenizer.apply_chat_template(messages, tools=tools)
if mask_prompt:
messages = messages[:-1]
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
self._data.append((tokens, offset))
else:
self._data.append(tokens)

def __getitem__(self, idx: int):
return self._data[idx]
Expand All @@ -63,16 +75,36 @@ def __init__(
tokenizer: PreTrainedTokenizer,
prompt_key: str,
completion_key: str,
mask_prompt: bool,
):
self._data = [
tokenizer.apply_chat_template(
self._data = []
for d in data:
tokens = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]},
],
)
for d in data
]
if mask_prompt:
offset = len(
tokenizer.apply_chat_template(
[{"role": "user", "content": d[prompt_key]}]
)
)
self._data.append((tokens, offset))
else:
self._data.append(tokens)

def __getitem__(self, idx: int):
return self._data[idx]

def __len__(self):
return len(self._data)


class ConcatenatedDataset:
def __init__(self, data: List[Any]):
self._data = list(itertools.chain(*data))

def __getitem__(self, idx: int):
return self._data[idx]
Expand All @@ -84,18 +116,26 @@ def __len__(self):
def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config,
):
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
mask_prompt = getattr(config, "mask_prompt", False)
prompt_feature = getattr(config, "prompt_feature", "prompt")
text_feature = getattr(config, "text_feature", "text")
completion_feature = getattr(config, "completion_feature", "completion")
chat_feature = getattr(config, "chat_feature", "messages")
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data, tokenizer)
if prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(
data, tokenizer, prompt_feature, completion_feature, mask_prompt
)
elif chat_feature in sample:
return ChatDataset(
data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
)
elif text_feature in sample:
if mask_prompt:
raise ValueError("Prompt masking not supported for text dataset.")
return Dataset(data, tokenizer, text_key=text_feature)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
Expand All @@ -106,15 +146,14 @@ def create_dataset(
def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
return create_dataset(data, tokenizer, config)

names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
Expand All @@ -124,8 +163,7 @@ def load_subset(path):
def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
config,
):
from datasets import exceptions, load_dataset

Expand All @@ -136,9 +174,7 @@ def load_hf_dataset(

train, valid, test = [
(
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
create_dataset(dataset[n], tokenizer, config)
if n in dataset.keys()
else []
)
Expand All @@ -154,61 +190,73 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets

hf_args = args.hf_dataset
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")

def create_hf_dataset(split: str = None):
def create_hf_dataset(dataset_name, config, split, hf_config):
ds = datasets.load_dataset(
dataset_name,
split=split,
**hf_args.get("config", {}),
**hf_config,
)
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(ds, tokenizer, text_key=text_feature)
return create_dataset(ds, tokenizer, config)

dataset_collection = args.hf_dataset
if isinstance(dataset_collection, dict):
dataset_collection = [dataset_collection]

collection = []
for ds in dataset_collection:
ds_name = ds["name"]
print(f"Loading Hugging Face dataset {ds_name}.")
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
config = types.SimpleNamespace(**ds)
hf_config = ds.get("config", {})
if args.train:
train_split = ds.get("train_split", "train[:80%]")
valid_split = ds.get("valid_split", "train[-10%:]")
train = create_hf_dataset(
ds_name,
config,
train_split,
hf_config,
)
valid = create_hf_dataset(
ds_name,
config,
valid_split,
hf_config,
)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset."
train, valid = [], []

if args.test:
test_split = ds.get("test_split")
test = create_hf_dataset(
ds_name,
config,
test_split,
hf_config,
)
else:
test = []

if args.train:
train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split)
valid = create_hf_dataset(split=valid_split)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(split=hf_args.get("test_split"))
else:
test = []
collection.append((train, valid, test))

return train, valid, test
if len(collection) == 1:
return collection[0]

# Otherwise concatenate them
return tuple(map(ConcatenatedDataset, zip(*collection)))


def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", False):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)

prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
)
train, valid, test = load_local_dataset(data_path, tokenizer, args)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
)
train, valid, test = load_hf_dataset(args.data, tokenizer, args)

if args.train and len(train) == 0:
raise ValueError(
Expand Down
Loading