import subprocess
import os
import shutil
import tarfile
import tempfile
import zipfile
import json

from cog import BaseModel, Input, Path

from llava.utils import disable_torch_init
from file_utils import is_url, download_file, download_weights, REPLICATE_WEIGHTS_URL, DEFAULT_WEIGHTS

# we don't use the huggingface hub cache, but we need to set this to a local folder
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/models"

def check_zip_contents(zip_path):
    # Check that the ZIP file contains 'data.json' and a folder named 'images' in its root
    error_msgs = []
    train_data_has_right_structure = True

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # List all contents of the zip file
            zip_contents = zip_ref.namelist()

            if "data.json" not in zip_contents:
                wrong_locations = [item for item in zip_contents if "data.json" in item]
                data_json_msg = f"{zip_path} does not contain a file named data.json in root. This file might be in the wrong location: {', '.join(wrong_locations)}"
                error_msgs.append(data_json_msg)
                train_data_has_right_structure = False

            if "images/" not in zip_contents:
                images_folder_msg = f"{zip_path} does not contain a folder named images in root."
                error_msgs.append(images_folder_msg)
                train_data_has_right_structure = False

            if "data.json" in zip_contents:
                # Read and load the content of 'data.json'
                with zip_ref.open("data.json", 'r') as data_json_file:
                    data_json_content = json.load(data_json_file)

                # Every image referenced in data.json must be present under images/
                for datapoint in data_json_content:
                    image_name = datapoint.get('image')
                    if not image_name:
                        # skip text-only entries (no 'image' field) rather than
                        # crashing on "images/" + None
                        continue
                    img_filename = "images/" + image_name
                    if img_filename not in zip_contents:
                        missing_file_str = f"data.json refers to image {img_filename}, but this file is missing in {zip_path}"
                        error_msgs.append(missing_file_str)
                        train_data_has_right_structure = False
    except zipfile.BadZipFile:
        badzip_msg = f"File '{zip_path}' is not a valid ZIP file or is corrupted."
        error_msgs.append(badzip_msg)
        print(badzip_msg)
        train_data_has_right_structure = False

    return train_data_has_right_structure, error_msgs

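# For reference, the archive layout this function validates, with a minimal
# data.json entry in LLaVA's conversation format (filenames here are
# illustrative, not required names):
#
#   train_data.zip
#   ├── data.json
#   └── images/
#       └── photo_0.jpg
#
#   # data.json
#   [
#     {
#       "id": "0",
#       "image": "photo_0.jpg",
#       "conversations": [
#         {"from": "human", "value": "<image>\nWhat is in this picture?"},
#         {"from": "gpt", "value": "A short description of the picture."}
#       ]
#     }
#   ]
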
def run_training(
    image_folder: Path,
    data_path: Path,
    output_dir: Path,
    num_train_epochs: int = 1,
    learning_rate: float = 2e-4,
    model_max_length: int = 2048
):
    # Command and arguments as a list
    command = [
        'python',
        '-m',
        'deepspeed.launcher.runner',
        'llava/train/train_mem.py',
        '--model_name_or_path', 'liuhaotian/llava-v1.5-13b',
        '--data_path', data_path,
        '--image_folder', image_folder,
        '--vision_tower', 'openai/clip-vit-large-patch14-336',
        '--output_dir', output_dir,
        '--lora_enable', 'True',
        '--lora_r', '128',
        '--lora_alpha', '256',
        '--mm_projector_lr', '2e-5',
        '--save_steps', '500',
        '--deepspeed', './scripts/zero3.json',
        '--version', 'v1',
        '--mm_projector_type', 'mlp2x_gelu',
        '--mm_vision_select_layer', '-2',
        '--mm_use_im_start_end', 'False',
        '--mm_use_im_patch_token', 'False',
        '--image_aspect_ratio', 'pad',
        '--group_by_modality_length', 'True',
        '--bf16', 'True',
        '--num_train_epochs', str(num_train_epochs),
        '--per_device_train_batch_size', '16',
        '--per_device_eval_batch_size', '4',
        '--gradient_accumulation_steps', '1',
        '--evaluation_strategy', 'no',
        '--save_strategy', 'steps',
        '--save_total_limit', '1',
        '--learning_rate', str(learning_rate),
        '--weight_decay', '0.',
        '--warmup_ratio', '0.03',
        '--lr_scheduler_type', 'cosine',
        '--logging_steps', '1',
        '--tf32', 'True',
        '--model_max_length', str(model_max_length),
        '--gradient_checkpointing', 'True',
        '--dataloader_num_workers', '4',
        '--lazy_preprocess', 'True',
        '--report_to', 'none'
    ]

    # Execute the command with the repo root on PYTHONPATH so llava is importable
    env = os.environ.copy()
    env['PYTHONPATH'] = os.getcwd()
    subprocess.run(command, env=env, check=True)

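# A minimal sketch of calling run_training directly, assuming a dataset that
# has already been unpacked to /tmp/data (paths here are hypothetical):
#
#   run_training(
#       image_folder=Path("/tmp/data/images"),
#       data_path=Path("/tmp/data/data.json"),
#       output_dir=Path("/tmp/data/output"),
#       num_train_epochs=1,
#   )
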
class TrainingOutput(BaseModel):
    # this must be a key named `weights`, otherwise image creation will silently fail
    # source: https://github.com/replicate/api/blob/6b73b27e0da6afbea0531bb4162e9b4f5a74d744/pkg/server/internal.go#L282
    weights: Path

def train(
    train_data: str = Input(description="HTTPS URL or local path of a zipfile containing the training data. The archive must contain a data.json file and an images/ folder at its root; data.json links the images in images/ to conversations."),
    num_train_epochs: int = Input(description="The number of training epochs", ge=1, le=1000, default=1),
    learning_rate: float = Input(description="The learning rate during training", ge=1e-10, default=2e-4),
    model_max_length: int = Input(description="The maximum length (in number of tokens) for the inputs to the model.", ge=1, default=2048),
) -> TrainingOutput:
    # Create a temporary directory to unzip train_data
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        tmp_dir = Path(tmp_dir_name)

        # Download train_data if it is a URL
        if is_url(train_data):
            local_train_data_path = tmp_dir / "train_data_archive.zip"
            download_file(train_data, local_train_data_path)
        else:
            local_train_data_path = Path(train_data)

        # Check the structure of the train_data zipfile
        train_data_has_right_structure, errors = check_zip_contents(local_train_data_path)
        if not train_data_has_right_structure:
            raise ValueError(f"There was a problem with the training data in {train_data}:\n\n" + "\n".join(errors))

        # Download base models
        for weight in DEFAULT_WEIGHTS:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        # Path to the weights file
        weights_file = Path("my_weights.tar")

        # Remove old output tar if it exists
        if weights_file.exists():
            weights_file.unlink()

        # Unzip train_data into tmp_dir
        shutil.unpack_archive(str(local_train_data_path), tmp_dir)

        # Define paths to data_path, image_folder, and output_dir within tmp_dir
        data_path = tmp_dir / "data.json"
        image_folder = tmp_dir / "images"
        output_dir = tmp_dir / "output"

        # Make sure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        # Run the training command
        run_training(
            image_folder,
            data_path,
            output_dir,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            model_max_length=model_max_length
        )

        # Tar the checkpoints and put into weights_file without compression
        with tarfile.open(str(weights_file), "w") as tar:
            tar.add(output_dir, arcname="")

        # Return the path to the weights file
        return TrainingOutput(weights=weights_file)

# todo: deal with recursive lora
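# A minimal sketch for inspecting the archive that train() returns (the
# filename matches weights_file above); it lists the checkpoint files that
# were packed from output_dir:
#
#   with tarfile.open("my_weights.tar") as tar:
#       print(tar.getnames())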