-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathext.py
279 lines (209 loc) · 8.24 KB
/
ext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import torch
from torch.utils.cpp_extension import load
import os
import sys
import platform
extension_name = "exllamav2_ext"
verbose = True  # Print wall of text when compiling
ext_debug = False  # Compile with debug options

# Determine if we're on Windows
windows = (os.name == "nt")

# Determine if extension is already installed or needs to be built.
# Start from "no build needed" and only switch to JIT compilation when a
# prebuilt extension cannot be imported. (Previously this was initialised to
# True, which made the import probe pointless and forced a rebuild every time.)
build_jit = False
try:
    import exllamav2_ext
except ModuleNotFoundError:
    build_jit = True
if build_jit:

    # Kludge to get compilation working on Windows: if cl.exe is not already
    # on the PATH, look for a Visual Studio installation and append its
    # compiler directory to the PATH before torch tries to invoke it.
    if windows:

        def find_msvc():
            """Locate the directory of an x64-hosted, x64-targeting cl.exe.

            Searches Visual Studio 2022, 2019 and 2017 (in that order of
            preference) under both Program Files directories, across the
            BuildTools/Community/Professional/Enterprise/Preview editions.
            Within each MSVC toolset directory the newest version is tried first.

            Returns:
                The directory containing cl.exe, or None if none was found.
            """

            # Either variable may be missing from the environment (e.g. a
            # stripped-down or 32-bit environment); skip absent ones instead
            # of raising KeyError.
            program_files_dirs = [
                d for d in (os.environ.get("ProgramW6432"), os.environ.get("ProgramFiles(x86)"))
                if d
            ]

            # Possible locations for MSVC, in order of preference
            msvc_dirs = [
                a + "\\Microsoft Visual Studio\\" + b + "\\" + c + "\\VC\\Tools\\MSVC\\"
                for b in ["2022", "2019", "2017"]
                for a in program_files_dirs
                for c in ["BuildTools", "Community", "Professional", "Enterprise", "Preview"]
            ]

            def version_key(name):
                # Compare dotted version names numerically, so e.g. "14.100"
                # ranks above "14.99"; non-numeric parts rank below numeric ones.
                return [(1, int(p)) if p.isdecimal() else (0, p) for p in name.split(".")]

            for msvc_dir in msvc_dirs:
                if not os.path.isdir(msvc_dir):
                    continue

                # Prefer the latest version
                for version in sorted(os.listdir(msvc_dir), key = version_key, reverse = True):
                    compiler_dir = msvc_dir + version + "\\bin\\Hostx64\\x64"
                    if os.path.exists(compiler_dir + "\\cl.exe"):
                        return compiler_dir

            # No path found
            return None

        import subprocess

        # Check if cl.exe is already in the path
        try:
            subprocess.check_output(["where", "/Q", "cl"])

        # If not (or `where` itself is unavailable), try to find an installation
        # of Visual Studio and append the compiler dir to the path
        except (subprocess.CalledProcessError, FileNotFoundError):
            cl_path = find_msvc()
            if cl_path:
                if verbose:
                    print(" -- Injected compiler path:", cl_path)
                os.environ["path"] += ";" + cl_path
            else:
                print(" !! Unable to find cl.exe; compilation will probably fail", file = sys.stderr)

    # gcc / cl.exe flags
    extra_cflags = ["/Ox"] if windows else ["-O3"]

    if ext_debug:
        extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]

    # nvcc flags
    extra_cuda_cflags = ["-lineinfo", "-O3"]

    # ROCm builds (torch.version.hip is set) get an extra define
    if torch.version.hip:
        extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]

    # linker flags
    extra_ldflags = []

    if windows:
        extra_ldflags += ["cublas.lib"]
        # Inside a virtualenv, also link against the base interpreter's libs dir
        if sys.base_prefix != sys.prefix:
            extra_ldflags += [f"/LIBPATH:{os.path.join(sys.base_prefix, 'libs')}"]

    # sources, relative to <this file's directory>/<extension_name>
    library_dir = os.path.dirname(os.path.abspath(__file__))
    sources_dir = os.path.join(library_dir, extension_name)

    sources_ = \
    [
        "ext_bindings.cpp",
        "ext_cache.cpp",
        "ext_gemm.cpp",
        "ext_hadamard.cpp",
        "ext_norm.cpp",
        "ext_qattn.cpp",
        "ext_qmatrix.cpp",
        "ext_qmlp.cpp",
        "ext_quant.cpp",
        "ext_rope.cpp",
        "ext_safetensors.cpp",
        "ext_sampling.cpp",
        "cuda/h_add.cu",
        "cuda/h_gemm.cu",
        "cuda/lora.cu",
        "cuda/pack_tensor.cu",
        "cuda/quantize.cu",
        "cuda/q_matrix.cu",
        "cuda/q_attn.cu",
        "cuda/q_mlp.cu",
        "cuda/q_gemm.cu",
        "cuda/rms_norm.cu",
        "cuda/head_norm.cu",
        "cuda/layer_norm.cu",
        "cuda/rope.cu",
        "cuda/cache.cu",
        "cuda/util.cu",
        "cuda/comp_units/kernel_select.cu",
        "cuda/comp_units/unit_gptq_1.cu",
        "cuda/comp_units/unit_gptq_2.cu",
        "cuda/comp_units/unit_gptq_3.cu",
        "cuda/comp_units/unit_exl2_1a.cu",
        "cuda/comp_units/unit_exl2_1b.cu",
        "cuda/comp_units/unit_exl2_2a.cu",
        "cuda/comp_units/unit_exl2_2b.cu",
        "cuda/comp_units/unit_exl2_3a.cu",
        "cuda/comp_units/unit_exl2_3b.cu",
        "cpp/quantize_func.cpp",
        "cpp/profiling.cpp",
        "cpp/sampling.cpp",
        "cpp/sampling_avx2.cpp",
        "cpp/safetensors.cpp"
    ]

    sources = [os.path.join(sources_dir, s) for s in sources_]

    # Load extension (JIT-compiles via torch.utils.cpp_extension.load)
    exllamav2_ext = load \
    (
        name = extension_name,
        sources = sources,
        extra_include_paths = [sources_dir],
        verbose = verbose,
        extra_ldflags = extra_ldflags,
        extra_cuda_cflags = extra_cuda_cflags,
        extra_cflags = extra_cflags
    )
# Short alias through which the wrappers below call into the extension module
ext_c = exllamav2_ext

# Placeholder handed to the C++ extension wherever an optional tensor argument
# is absent, standing in for None/NULL. It lives on the "meta" device, so it
# carries only shape/dtype metadata and no data.
none_tensor = torch.empty(1, 1, device = "meta")
# Group map needed for irregular group sizes
def make_group_map(q_groups: torch.Tensor, num_qrows: int) -> torch.Tensor:
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
row = 0
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype = torch.short, device = q_groups.device)
# Create Q matrix
def make_q_matrix(w: dict,
                  temp_dq: torch.Tensor,
                  key: str = None,
                  prescale: float = 1,
                  max_dq_rows = 0):
    """Create a quantized-matrix object in the C++ extension from a weight dict.

    The layout is detected from the keys of ``w``:

    * EXL2 (``"q_weight"`` present): ``q_scale_max`` is multiplied in place by
      ``prescale / 256``. ``q_perm`` and ``q_invperm`` are cast to int16
      (``.short()``). ``q_group_map`` is built with ``make_group_map`` if it
      is not already present.
    * GPTQ (``"qweight"`` present): ``scales`` is multiplied in place by
      ``prescale`` when it is not 1, and cast to half if it is float32. If
      ``g_idx`` is present and not all zeros (act-order), empty int16
      ``q_perm``/``q_invperm`` buffers of length ``qweight.shape[0] * 8`` are
      allocated. They are presumably filled by the extension (TODO confirm),
      and ``g_idx`` is passed on the CPU.

    NOTE: ``w`` is mutated in place, so calling this twice on the same dict
    applies ``prescale`` twice.

    Args:
        w: Tensor dict for one linear layer, in EXL2 or GPTQ layout.
        temp_dq: Scratch tensor forwarded unchanged to the extension.
        key: Not used by this function; kept for interface compatibility.
        prescale: Extra scale factor folded into the stored scales.
        max_dq_rows: Forwarded unchanged to the extension.

    Returns:
        The result of ``ext_c.make_q_matrix``, or None when ``w`` contains
        neither ``"q_weight"`` nor ``"qweight"``.
    """

    if "q_weight" in w:
        # EXL2
        w["q_scale_max"] *= prescale / 256
        w["q_perm"] = w["q_perm"].short()
        w["q_invperm"] = w["q_invperm"].short()

        if "q_group_map" not in w:
            w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])

        leading_args = (w["q_weight"],
                        w["q_perm"],
                        w["q_invperm"],
                        w["q_scale"],
                        w["q_scale_max"],
                        w["q_groups"],
                        w["q_group_map"],
                        none_tensor,   # GPTQ qzeros slot
                        none_tensor,   # GPTQ scales slot
                        none_tensor)   # GPTQ g_idx slot

    elif "qweight" in w:
        # GPTQ
        if prescale != 1:
            w["scales"] *= prescale
        if w["scales"].dtype == torch.float:
            w["scales"] = w["scales"].half()

        uses_act_order = "g_idx" in w and not (w["g_idx"] == 0).all().item()

        if uses_act_order:
            # GPTQ with g_idx (act_order)
            w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
            w["q_invperm"] = torch.empty_like(w["q_perm"])
            leading_args = (w["qweight"],
                            w["q_perm"],
                            w["q_invperm"],
                            none_tensor,   # EXL2 q_scale slot
                            none_tensor,   # EXL2 q_scale_max slot
                            none_tensor,   # EXL2 q_groups slot
                            none_tensor,   # EXL2 q_group_map slot
                            w["qzeros"],
                            w["scales"],
                            w["g_idx"].cpu())
        else:
            # GPTQ without g_idx: no permutation and no g_idx are passed
            leading_args = (w["qweight"],
                            none_tensor,
                            none_tensor,
                            none_tensor,
                            none_tensor,
                            none_tensor,
                            none_tensor,
                            w["qzeros"],
                            w["scales"],
                            none_tensor)

    else:
        # Unrecognised layout
        return None

    return ext_c.make_q_matrix(*leading_args,
                               w.get("bias", none_tensor),
                               temp_dq,
                               max_dq_rows)