from __future__ import annotations
from exllamav2.config import ExLlamaV2Config
from exllamav2.linear import ExLlamaV2Linear
import os, json
from safetensors.torch import load_file as safe_load_file
from torch import load as load_file
import torch
from exllamav2.compat import safe_move_tensor
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from exllamav2.model import ExLlamaV2

class ExLlamaV2Lora:

    model: ExLlamaV2
    lora_config_path: str
    lora_path: str
    lora_r: int
    lora_alpha: float
    lora_scaling: float
    config: ExLlamaV2Config
    bias_ignored: bool
    tensors: dict
    target_modules: dict

    @staticmethod
    def from_directory(model, directory, lora_scaling = 1.0):

        config_path = os.path.join(directory, "adapter_config.json")
        lora_path_bin = os.path.join(directory, "adapter_model.bin")
        lora_path_st = os.path.join(directory, "adapter_model.safetensors")
        if os.path.exists(lora_path_bin): return ExLlamaV2Lora(model, config_path, lora_path_bin, lora_scaling)
        if os.path.exists(lora_path_st): return ExLlamaV2Lora(model, config_path, lora_path_st, lora_scaling)
        raise ValueError(f"No LoRA found in {directory}")

    @torch.inference_mode
    def __init__(self,
                 model: ExLlamaV2,
                 lora_config_path: str,
                 lora_path: str,
                 lora_scaling: float = 1.0):

        self.lora_config_path = lora_config_path
        self.lora_path = lora_path
        self.model = model
        self.config = model.config
        self.tensors = {}
        self.target_modules = {}
        self.bias_ignored = False
        self.lora_scaling = lora_scaling

        # Grab relevant items from LoRA config

        with open(lora_config_path, encoding = "utf8") as f:
            read_config = json.load(f)

        self.lora_r = read_config["r"]
        self.lora_alpha = float(read_config["lora_alpha"])
        self.lora_scaling *= self.lora_alpha / self.lora_r

        if "fan_in_fan_out" in read_config and read_config["fan_in_fan_out"]:
            raise ValueError(" ## Error: fan_in_fan_out mode not supported.")

        # Load LoRA weights

        if self.lora_path.endswith(".safetensors"):
            f = safe_load_file(self.lora_path, device = "cpu")
        else:
            f = load_file(self.lora_path, map_location = "cpu")

        for key in f.keys():
            tensor = f[key]

            # Find target module for this tensor

            i = key.find("model.layers.")
            if i == -1: raise ValueError(f" ## Error: unsupported layer in {self.lora_path}: {key}")

            target_key = key[i:]
            ks = target_key.split(".")
            decoder_idx = int(ks[2])
            decoder_part = ks[3]
            decoder_layer = ".".join(ks[4:-2])
            lora_half = ks[-2]

            # Bias tensors are ignored, but only if they are numerically zero

            if lora_half == "bias":
                epsilon = 1e-6
                if torch.max(tensor) > epsilon or torch.min(tensor) < -epsilon:
                    raise ValueError(f" ## Error: unsupported bias target {self.lora_path}: {key}")
                self.bias_ignored = True
                continue

            target_module = self.model.modules_dict["model.layers." + str(decoder_idx) + "." + decoder_part + "." + decoder_layer]

            # Check that shape is compatible

            assert isinstance(target_module, ExLlamaV2Linear)

            if lora_half == "lora_A":
                in_features = tensor.shape[1]
                out_features = None
            elif lora_half == "lora_B":
                in_features = None
                out_features = tensor.shape[0]
            else: raise ValueError(f" ## Error: unsupported layer in {self.lora_path}: {key}")

            if (in_features and in_features != target_module.in_features) or (out_features and out_features != target_module.out_features):
                raise ValueError(f" ## Error: incompatible tensor shape in {self.lora_path}: {key}")

            # For efficiency, transpose adapter instead of transposing state during inference

            tensor = tensor.T.contiguous()

            # Pre-scale

            if lora_half == "lora_B" and self.lora_scaling != 1.0: tensor.mul_(self.lora_scaling)

            # Check that dtype is compatible, or convert

            if tensor.dtype == torch.bfloat16 or tensor.dtype == torch.float32:
                tensor = tensor.to(torch.float16)
            elif tensor.dtype == torch.float16:
                pass
            else: raise ValueError(f" ## Error: unsupported tensor dtype in {self.lora_path}")

            # Move to target device

            tensor = safe_move_tensor(tensor, target_module.device())

            if lora_half == "lora_A": target_module.lora_a_tensors[self] = tensor
            if lora_half == "lora_B": target_module.lora_b_tensors[self] = tensor

            # Store adapter tensor

            self.tensors[target_key] = tensor
            self.target_modules[target_key] = target_module

        self.model.update_loras()

    def unload(self):

        for k, v in self.target_modules.items():
            if self in v.lora_a_tensors: del v.lora_a_tensors[self]
            if self in v.lora_b_tensors: del v.lora_b_tensors[self]

        self.tensors = {}
        self.target_modules = {}

        self.model.update_loras()
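

# Example usage (sketch only, not executed on import): attach an adapter to an
# already-loaded ExLlamaV2 `model` and detach it again afterwards. The adapter
# directory is a placeholder, and how the adapter is consumed during generation
# (e.g. a generator accepting a `loras` argument) lives outside this file and is
# only an assumption here.
#
#   lora = ExLlamaV2Lora.from_directory(model, "/path/to/adapter", lora_scaling = 1.0)
#   ...  # run inference; the A/B tensors are now attached to the target ExLlamaV2Linear modules
#   lora.unload()  # detach the tensors and refresh the model's LoRA state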