diff --git a/eval/humaneval.py b/eval/humaneval.py index 8283f7ea..3240d57b 100644 --- a/eval/humaneval.py +++ b/eval/humaneval.py @@ -5,7 +5,7 @@ from exllamav2 import model_init from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8 from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler -import argparse, contextlib +import argparse, contextlib, subprocess import util # Args @@ -20,6 +20,7 @@ parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion") parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ") parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating") +parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling") model_init.add_args(parser) args = parser.parse_args() @@ -52,6 +53,13 @@ "<|start_header_id|>assistant<|end_header_id|>\n\n" "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ", " " + ), + "gemma": ( + "user\n" + "Complete the following Python function:\n\n{{problem}}<|eot_id|>" + "model\n" + "```python\n{{problem}} ", + " " ) } @@ -192,3 +200,8 @@ print(f" -- Saving: {args.output}") write_jsonl(args.output, samples) +# Optionally launch eval script + +if args.eval: + subprocess.run(["evaluate_functional_correctness", args.output]) + diff --git a/examples/chat.py b/examples/chat.py index 16b04057..70963a9f 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -61,7 +61,7 @@ parser.add_argument("-ngram", "--ngram_decoding", action = "store_true", help = "Use n-gram speculative decoding") -parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings after each prompt") +parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings/stats after each prompt") parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context after every response") # Arrrgs @@ -235,7 +235,9 @@ def get_tokenized_context(max_len): # Stop conditions -generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer)) +sc = prompt_format.stop_conditions(tokenizer) +sc = [x for x in sc if x] +generator.set_stop_conditions(sc) # ANSI color codes @@ -393,8 +395,9 @@ def get_tokenized_context(max_len): else: sd_stats = "" + ctx_tokens = active_context.shape[-1] print() - print(col_sysprompt + f"(Response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default) + print(col_sysprompt + f"(Context: {ctx_tokens} tokens, response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default) # Optionally forget context after each response diff --git a/examples/chat_prompts.py b/examples/chat_prompts.py index 9ea0042c..00474eed 100644 --- a/examples/chat_prompts.py +++ b/examples/chat_prompts.py @@ -229,6 +229,7 @@ def subs_prompt(self): def stop_conditions(self, tokenizer): return \ [tokenizer.eos_token_id, + tokenizer.single_id("<|im_end|>"), """<|im_end|>"""] def encoding_options(self): diff --git a/examples/dynamic_gen.py b/examples/dynamic_gen.py index ec0e3192..bfeb9d96 100644 --- a/examples/dynamic_gen.py +++ b/examples/dynamic_gen.py @@ -136,6 +136,7 @@ def main(): if use_draft_model: draft_config = ExLlamaV2Config(draft_model_dir) + draft_config.arch_compat_overrides() draft_model = ExLlamaV2(draft_config) 
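Note: the new -e/--eval switch in eval/humaneval.py simply hands the finished samples file to the human-eval scorer after sampling. A minimal sketch of the same pattern follows; the default output filename is illustrative, and it assumes the OpenAI human-eval package (which installs the evaluate_functional_correctness console script) is available.

import argparse, subprocess

parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output", type = str, default = "humaneval_output.jsonl")
parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
args = parser.parse_args()

# ... generate completions and write them to args.output as JSONL ...

if args.eval:
    # Blocks until the external scorer finishes; results are printed by that script
    subprocess.run(["evaluate_functional_correctness", args.output])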
draft_cache = ExLlamaV2Cache( @@ -155,6 +156,7 @@ def main(): # 2048, which will also be the limit of the chunk size for prefill used by the dynamic generator. config = ExLlamaV2Config(model_dir) + config.arch_compat_overrides() config.max_input_len = max_chunk_size config.max_attention_size = max_chunk_size ** 2 model = ExLlamaV2(config) diff --git a/examples/inference.py b/examples/inference.py index 0353d45c..9ac28d63 100644 --- a/examples/inference.py +++ b/examples/inference.py @@ -7,6 +7,7 @@ model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_async.py b/examples/inference_async.py index 94629195..c12c6408 100644 --- a/examples/inference_async.py +++ b/examples/inference_async.py @@ -9,6 +9,7 @@ async def main(): model_dir = "/mnt/str/models/llama3-8b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) + config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_banned_strings.py b/examples/inference_banned_strings.py index c6ad1885..7648b0e6 100644 --- a/examples/inference_banned_strings.py +++ b/examples/inference_banned_strings.py @@ -9,6 +9,7 @@ model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/6.0bpw/" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_cfg.py b/examples/inference_cfg.py index 4ecb5452..b1c86e71 100644 --- a/examples/inference_cfg.py +++ b/examples/inference_cfg.py @@ -8,6 +8,7 @@ model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_dedup.py b/examples/inference_dedup.py index 335f2289..bd39fd6a 100644 --- a/examples/inference_dedup.py +++ b/examples/inference_dedup.py @@ -8,6 +8,7 @@ model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_json.py b/examples/inference_json.py index b6e7608d..eb53262f 100644 --- a/examples/inference_json.py +++ b/examples/inference_json.py @@ -13,6 +13,7 @@ model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_lora.py b/examples/inference_lora.py index 310f0213..33292fee 100644 --- a/examples/inference_lora.py +++ b/examples/inference_lora.py @@ -7,6 +7,7 @@ model_dir = "/mnt/str/models/llama2-7b-exl2/5.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_speculative.py b/examples/inference_speculative.py index 51ea311b..53123dbd 100644 
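Note: every example script gains the same one-line addition: arch_compat_overrides() is called on the config immediately after construction, before the model is instantiated, so that backend restrictions for the newly supported architectures are applied up front. The resulting load sequence (model path is a placeholder):

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config("/path/to/model")   # placeholder directory
config.arch_compat_overrides()               # disable attn backends the architecture can't use
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache, progress = True)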
--- a/examples/inference_speculative.py +++ b/examples/inference_speculative.py @@ -12,12 +12,14 @@ draft_model_dir = "/mnt/str/models/qwen2-1.5b-instruct-exl2/4.0bpw" draft_config = ExLlamaV2Config(draft_model_dir) +draft_config.arch_compat_overrides() draft_model = ExLlamaV2(draft_config) draft_cache = ExLlamaV2Cache(draft_model, max_seq_len = total_cache_tokens, lazy = True) draft_model.load_autosplit(draft_cache, progress = True) model_dir = "/mnt/str/models/qwen2-72b-instruct-exl2/6.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, max_seq_len = total_cache_tokens, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/inference_stream.py b/examples/inference_stream.py index 3f9cfa80..ce94bc4d 100644 --- a/examples/inference_stream.py +++ b/examples/inference_stream.py @@ -8,6 +8,7 @@ model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw" config = ExLlamaV2Config(model_dir) +config.arch_compat_overrides() model = ExLlamaV2(config) cache = ExLlamaV2Cache(model, lazy = True) model.load_autosplit(cache, progress = True) diff --git a/examples/util.py b/examples/util.py index 8ddc6b7f..45c8b1ee 100644 --- a/examples/util.py +++ b/examples/util.py @@ -29,6 +29,12 @@ def format_prompt(prompt_format, sp, p): f"{p}<|im_end|>\n" f"<|im_start|>assistant\n" ) + elif prompt_format == "gemma": + return ( + f"user\n" + f"{p}\n" + f"model\n" + ) def get_stop_conditions(prompt_format, tokenizer): if prompt_format == "llama": @@ -37,7 +43,8 @@ def get_stop_conditions(prompt_format, tokenizer): return [tokenizer.single_id("<|eot_id|>")] elif prompt_format == "granite": return [tokenizer.eos_token_id, "\n\nQuestion:"] - + elif prompt_format == "gemma": + return [tokenizer.eos_token_id, ""] # Cached dataset loader diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index df8b7de0..9bb7eca3 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -9,6 +9,12 @@ ["ln_2"]] layer_keys_yi_norms = [["ln1", "input_layernorm"], ["ln2", "post_attention_layernorm"]] +layer_keys_gemma2_norms = [["input_layernorm"], + ["post_attention_layernorm"], + ["pre_feedforward_layernorm"], + ["post_feedforward_layernorm"]] +layer_keys_internlm2_norms = [["attention_norm"], + ["ffn_norm"]] layer_keys_llama_attn = [["self_attn.q_proj"], ["self_attn.k_proj"], ["self_attn.v_proj"], @@ -17,6 +23,10 @@ ["self_attn.c_attn", "self_attn.k_proj"], ["self_attn.c_attn", "self_attn.v_proj"], ["self_attn.o_proj"]] +layer_keys_internlm2_attn = [["self_attn.wqkv", "self_attn.q_proj"], + ["self_attn.wqkv", "self_attn.k_proj"], + ["self_attn.wqkv", "self_attn.v_proj"], + ["self_attn.o_proj"]] layer_keys_dbrx_attn = [["self_attn.Wqkv", "self_attn.q_proj"], ["self_attn.Wqkv", "self_attn.k_proj"], ["self_attn.Wqkv", "self_attn.v_proj"], @@ -28,6 +38,9 @@ layer_keys_llama_mlp = [["mlp.down_proj"], ["mlp.gate_proj"], ["mlp.up_proj"]] +layer_keys_internlm2_mlp = [["feed_forward.w1"], + ["feed_forward.w2"], + ["feed_forward.w3"]] layer_keys_phi3_mlp = [["mlp.down_proj"], ["mlp.gate_up_proj", "mlp.gate_proj"], ["mlp.gate_up_proj", "mlp.up_proj"]] @@ -76,6 +89,10 @@ ("$h.", "model.layers."), ("$wte.", "model.embed_tokens."), ("$wpe.", "model.wpe.")] +internlm2_keymap = [("$output.", "lm_head."), + ("$model.tok_embeddings.", "model.embed_tokens."), + (".attention.", ".self_attn."), + (".wo.", ".o_proj.")] class RopeStyle(Enum): NONE = 0 @@ -100,6 +117,18 @@ def __init__(self, arch_string, read_config): 
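Note: the "gemma" prompt entries added above target Gemma's documented chat template, which wraps each turn in explicit delimiter tokens. A minimal builder along those lines (a sketch of the published template, not code taken from this diff; the helper name is hypothetical):

def format_gemma_prompt(user_message: str) -> str:
    # Gemma instruct format: <start_of_turn>/<end_of_turn> delimit each turn
    return (
        "<start_of_turn>user\n"
        f"{user_message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

# Generation should stop on the tokenizer's EOS token or on the "<end_of_turn>" string.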
self.orig_weights_transposed = False self.logit_scale_basedim = False + self.norm_key_1_post = None + self.norm_key_2_post = None + + self.swa = False + self.alternating_swa = False + + self.eager_attn_only = False + self.clamp_hidden_states = False + self.residual_stream_fp32 = False + + self.fused_qkv_altpack = False + # Mistral if arch_string == "MistralForCausalLM": @@ -305,6 +334,45 @@ def __init__(self, arch_string, read_config): self.mqa = False self.scale_attn_weights = False + # Gemma2 + + if arch_string == "Gemma2ForCausalLM": + arch_recognized = True + self.layer_keys += \ + layer_keys_gemma2_norms + \ + layer_keys_llama_attn + \ + layer_keys_llama_mlp + self.expect_keys += \ + expect_keys_gemma + self.norm_eps_key = "rms_norm_eps" + self.attention_bias_qkv = False + self.attention_bias_o = False + self.mlp_bias = False + self.mlp_gate = True + self.mlp_key_gate = ".mlp.gate_proj" + self.mlp_key_up = ".mlp.up_proj" + self.mlp_key_down = ".mlp.down_proj" + self.mlp_act_func = "gelu" + self.is_moe = False + self.norm = "rmsnorm" + self.lm_head_key = "model.embed_tokens" + self.normalize_embeddings = True + self.norm_key_1 = ".input_layernorm" + self.norm_key_1_post = ".post_attention_layernorm" + self.norm_key_2 = ".pre_feedforward_layernorm" + self.norm_key_2_post = ".post_feedforward_layernorm" + self.norm_constant_bias = 1 + self.parallel_decoder_blocks = False + self.requires_bos = True + self.rope_style = RopeStyle.NEOX + self.keymap = None + self.fused_qkv_key = None + self.mqa = False + self.scale_attn_weights = False + self.pre_post_layernorm = True + self.alternating_swa = True + self.residual_stream_fp32 = True + # StarCoder2 if arch_string == "Starcoder2ForCausalLM": @@ -586,6 +654,41 @@ def __init__(self, arch_string, read_config): self.scale_attn_weights = False self.logit_scale_basedim = True + # InternLM2 + + if arch_string == "InternLM2ForCausalLM": + arch_recognized = True + self.layer_keys += \ + layer_keys_internlm2_norms + \ + layer_keys_internlm2_attn + \ + layer_keys_internlm2_mlp + self.expect_keys += \ + expect_keys_llama + self.norm_eps_key = "rms_norm_eps" + self.attention_bias_qkv = False + self.attention_bias_o = False + self.mlp_bias = False + self.mlp_gate = True + self.mlp_key_gate = ".feed_forward.w1" + self.mlp_key_up = ".feed_forward.w3" + self.mlp_key_down = ".feed_forward.w2" + self.mlp_act_func = "silu" + self.is_moe = False + self.norm = "rmsnorm" + self.lm_head_key = "lm_head" + self.normalize_embeddings = False + self.norm_key_1 = ".attention_norm" + self.norm_key_2 = ".ffn_norm" + self.norm_constant_bias = 0 + self.parallel_decoder_blocks = False + self.requires_bos = False + self.rope_style = RopeStyle.NEOX + self.keymap = internlm2_keymap + self.fused_qkv_key = "wqkv" + self.fused_qkv_altpack = True + self.mqa = False + self.scale_attn_weights = False + # Llama (default + fallback) if arch_string != "LlamaForCausalLM" and not arch_recognized: @@ -637,6 +740,11 @@ def __init__(self, arch_string, read_config): self.expect_keys.remove(["lm_head"]) self.lm_head_key = "model.embed_tokens" + # Sanity checks + + if self.residual_stream_fp32: + assert self.norm_key_1_post and self.norm_key_2_post, \ + "FP32 residual stream only implemented for arch with post layernorms" def make_fused_mlp(self): diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 8157760d..56ce3860 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -15,6 +15,8 @@ import math # from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, 
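Note: the Gemma2 entry defines both a pre and a post normalization key for the attention and MLP sub-blocks, because Gemma2 normalizes each sub-block's output before adding it back to the residual stream. Conceptually (a simplified sketch of the layer structure, not the ExLlamaV2 module code):

def gemma2_decoder_layer(x, attn, mlp, norm_in, norm_post_attn, norm_pre_ffn, norm_post_ffn):
    # Attention sub-block: pre-norm -> attention -> post-norm -> residual add
    h = x + norm_post_attn(attn(norm_in(x)))
    # Feed-forward sub-block: pre-norm -> MLP -> post-norm -> residual add
    return h + norm_post_ffn(mlp(norm_pre_ffn(h)))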
print_vram_usage_peak import torch.nn.functional as F +import inspect +import os # from line_profiler import profile from typing import TYPE_CHECKING @@ -25,44 +27,57 @@ has_flash_attn = False has_flash_attn_with_paged = False +has_flash_attn_with_window = False +has_flash_attn_with_softcap = False +if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ: -try: - import flash_attn - flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] - is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) + try: + import flash_attn + flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] + is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) - if not is_ampere_or_newer_gpu: - print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.") + if not is_ampere_or_newer_gpu: + print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.") - if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: - from flash_attn import flash_attn_func - has_flash_attn = True + if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: + from flash_attn import flash_attn_func + has_flash_attn = True - if [2, 5, 7] <= flash_attn_ver: - from flash_attn import flash_attn_func, flash_attn_with_kvcache - import flash_attn_2_cuda as flash_attn_cuda + if [2, 5, 7] <= flash_attn_ver: + from flash_attn import flash_attn_func, flash_attn_with_kvcache + # import flash_attn_2_cuda as flash_attn_cuda - has_flash_attn = True - has_flash_attn_with_paged = True + has_flash_attn = True + has_flash_attn_with_paged = True + + signature = list(inspect.signature(flash_attn_func).parameters) + has_flash_attn_with_window = "window_size" in signature + has_flash_attn_with_softcap = "softcap" in signature + + except ModuleNotFoundError: + pass -except ModuleNotFoundError: - pass has_xformers = False -try: - import xformers.ops as xops - # LowerTriangularFromBottomRightMask was added in xformers version 2.4 - from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask - has_xformers = True -except ModuleNotFoundError: - pass +if 'EXLLAMA_NO_XFORMERS' not in os.environ: + + try: + import xformers.ops as xops + # LowerTriangularFromBottomRightMask was added in xformers version 2.4 + from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask + has_xformers = True + except ModuleNotFoundError: + pass + has_lower_right_sdpa = False -try: - from torch.nn.attention.bias import causal_lower_right - has_lower_right_sdpa = True -except ImportError: - pass +if 'EXLLAMA_NO_SDPA' not in os.environ: + try: + from torch.nn.attention.bias import causal_lower_right + has_lower_right_sdpa = True + except ImportError: + pass + def assert_paged_attn(): global has_flash_attn_with_paged @@ -75,7 +90,8 @@ class ExLlamaV2Attention(ExLlamaV2Module): name: str = "Attention" layer_idx: int - input_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + pre_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + post_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None q_proj: ExLlamaV2Linear | None k_proj: ExLlamaV2Linear | None v_proj: ExLlamaV2Linear | None @@ -97,6 +113,8 @@ class ExLlamaV2Attention(ExLlamaV2Module): has_norm: bool has_residual: bool + scaling: float + sliding_window: int class Params: @@ -164,9 +182,9 @@ def get_past_lens(self, device) -> torch.Tensor | None: self.past_lens_tensor = safe_move_tensor(self.past_lens_tensor, 
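Note: the guarded imports above mean each attention backend can now be switched off with an environment variable before exllamav2 is imported, which helps when a broken or mismatched flash-attn/xformers build is installed. For example:

import os

# Must be set before exllamav2.attn is first imported
os.environ["EXLLAMA_NO_FLASH_ATTN"] = "1"   # skip flash-attn detection entirely
os.environ["EXLLAMA_NO_XFORMERS"] = "1"     # skip xformers detection
os.environ["EXLLAMA_NO_SDPA"] = "1"         # skip the torch SDPA lower-right-bias path

from exllamav2 import ExLlamaV2  # attention falls back to the eager matmul path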
device) return self.past_lens_tensor - def get_attn_mask(self, device) -> torch.Tensor | None: + def get_attn_mask(self, device, force: bool = False) -> torch.Tensor | None: if self.attn_mask is None: - self.attn_mask = self.build_attn_mask(device) + self.attn_mask = self.build_attn_mask(device, force) elif self.attn_mask.device != device: self.attn_mask = safe_move_tensor(self.attn_mask, device) return self.attn_mask @@ -189,9 +207,9 @@ def build_single_attn_mask(self, batch_size, seq_len, past_len, device, input_ma attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part) return attn_mask - def build_attn_mask(self, device) -> torch.Tensor | None: + def build_attn_mask(self, device, force: bool = False) -> torch.Tensor | None: assert not self.multi_cache, "Building single mask for multiple caches" - if self.input_mask is None and self.seq_len == 1: return None + if self.input_mask is None and self.seq_len == 1 and not force: return None return self.build_single_attn_mask(self.batch_size, self.seq_len, self.past_len, device, self.input_mask) def build_attn_masks(self, device) -> torch.Tensor | None: @@ -274,7 +292,8 @@ def __init__(self, key: str, layer_idx: int, has_norm: bool = True, - has_residual: bool = True): + has_residual: bool = True, + sliding_window: int = 0): super().__init__(model, key) @@ -291,11 +310,14 @@ def __init__(self, if self.has_norm: if cfg.arch.norm == "layernorm": - self.input_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1) + self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1) + self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None elif cfg.arch.norm == "rmsnorm": - self.input_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1) + self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1) + self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_1_post) if cfg.arch.norm_key_1_post else None else: - self.input_layernorm = None + self.pre_layernorm = None + self.post_layernorm = None f_a = 0 f_b = cfg.num_attention_heads * cfg.head_dim @@ -303,9 +325,9 @@ def __init__(self, f_d = f_c + cfg.num_key_value_heads * cfg.head_dim f_key = (key + ".self_attn." 
+ cfg.arch.fused_qkv_key) if cfg.arch.fused_qkv_key else None - self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b) - self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c) - self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d) + self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, cfg.num_attention_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = cfg.arch.fused_qkv_altpack) + self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = cfg.arch.fused_qkv_altpack) + self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, cfg.num_key_value_heads * cfg.head_dim, cfg.arch.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = cfg.arch.fused_qkv_altpack) self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", cfg.num_attention_heads * cfg.head_dim, hidden_size, cfg.arch.attention_bias_o, prescale = cfg.scale_depth) if cfg.use_qk_norm: @@ -319,18 +341,19 @@ def __init__(self, self.k_proj, self.v_proj, self.o_proj] - if self.has_norm: - self.submodules += [self.input_layernorm] + if self.pre_layernorm: + self.submodules += [self.pre_layernorm] + if self.post_layernorm: + self.submodules += [self.post_layernorm] if cfg.use_qk_norm: - self.submodules += [self.q_norm, - self.k_norm] + self.submodules += [self.q_norm, self.k_norm] + + if cfg.query_pre_attn_scalar: + self.scaling = cfg.query_pre_attn_scalar ** (-0.5) + else: + self.scaling = 1 / math.sqrt(cfg.head_dim) - # if cfg.arch.scale_attn_weights: - # self.unscale_factor = self.layer_idx + 1 - # self.scale_factor = 1 / self.unscale_factor - # else: - self.unscale_factor = 1 - self.scale_factor = 1 + self.sliding_window = sliding_window def numel(self) -> int: @@ -340,7 +363,8 @@ def numel(self) -> int: self.v_proj.numel() + \ self.o_proj.numel() - if self.input_layernorm is not None: numel += self.input_layernorm.numel() + if self.pre_layernorm is not None: numel += self.pre_layernorm.numel() + if self.post_layernorm is not None: numel += self.post_layernorm.numel() if self.q_norm is not None: numel += self.q_norm.numel() if self.k_norm is not None: numel += self.k_norm.numel() @@ -352,7 +376,8 @@ def load(self): cfg = self.model.config - if self.input_layernorm is not None: self.input_layernorm.load() + if self.pre_layernorm is not None: self.pre_layernorm.load() + if self.post_layernorm is not None: self.post_layernorm.load() self.q_proj.load() self.k_proj.load() self.v_proj.load() @@ -374,16 +399,23 @@ def load(self): # self.temp_kv = device_tensors.get_scratch_slice(self.temp_kv_size()) if cfg.num_attention_heads != cfg.num_key_value_heads else None if self.has_norm: - norm_weight = self.input_layernorm.weight if self.input_layernorm.weight is not None else none_tensor - norm_bias = self.input_layernorm.bias if self.input_layernorm.bias is not None else none_tensor - is_rms = isinstance(self.input_layernorm, ExLlamaV2RMSNorm) - eps = self.input_layernorm.variance_epsilon + 
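Note: query_pre_attn_scalar (read from the model config further down, in config.py) now overrides the default attention scale, since Gemma2 scales queries by a configured constant rather than by head_dim. The selection above amounts to the following (helper name is illustrative):

import math

def attention_scale(query_pre_attn_scalar, head_dim):
    # Gemma2-style configs provide query_pre_attn_scalar; otherwise use 1/sqrt(head_dim)
    if query_pre_attn_scalar:
        return query_pre_attn_scalar ** (-0.5)
    return 1.0 / math.sqrt(head_dim)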
norm_weight = self.pre_layernorm.weight if self.pre_layernorm.weight is not None else none_tensor + norm_bias = self.pre_layernorm.bias if self.pre_layernorm.bias is not None else none_tensor + is_rms = isinstance(self.pre_layernorm, ExLlamaV2RMSNorm) + eps = self.pre_layernorm.variance_epsilon else: norm_weight = none_tensor norm_bias = none_tensor is_rms = False eps = 0 + if self.post_layernorm is not None: + post_norm_weight = self.post_layernorm.weight if self.post_layernorm.weight is not None else none_tensor + post_norm_bias = self.post_layernorm.bias if self.post_layernorm.bias is not None else none_tensor + else: + post_norm_weight = none_tensor + post_norm_bias = none_tensor + if self.q_norm is None: q_norm = none_tensor else: @@ -417,7 +449,10 @@ def load(self): self.has_residual, cfg.arch.rope_style.value, q_norm, - k_norm + k_norm, + post_norm_weight, + post_norm_bias, + cfg.arch.residual_stream_fp32 ) @@ -426,7 +461,8 @@ def unload(self): ext_c.free_q_attn(self.q_handle) self.q_handle = None - if self.input_layernorm is not None: self.input_layernorm.unload() + if self.pre_layernorm is not None: self.pre_layernorm.unload() + if self.post_layernorm is not None: self.post_layernorm.unload() if self.q_proj is not None: self.q_proj.unload() if self.k_proj is not None: self.k_proj.unload() if self.v_proj is not None: self.v_proj.unload() @@ -445,8 +481,10 @@ def weight_footprint(self): self.k_proj.weight_footprint() + \ self.v_proj.weight_footprint() + \ self.o_proj.weight_footprint() - if self.input_layernorm is not None: - fp += self.input_layernorm.weight_footprint() + if self.pre_layernorm is not None: + fp += self.pre_layernorm.weight_footprint() + if self.post_layernorm is not None: + fp += self.post_layernorm.weight_footprint() if self.q_norm is not None: fp += self.q_norm.weight_footprint() if self.k_norm is not None: @@ -475,7 +513,7 @@ def scratch_space(self): def temp_state_size(self): cfg = self.model.config - return cfg.max_input_len * cfg.max_batch_size * cfg.num_attention_heads * cfg.head_dim * 2 + 128 + return cfg.max_input_len * cfg.max_batch_size * max(cfg.num_attention_heads * cfg.head_dim, cfg.hidden_size) * 2 + 128 def temp_q_size(self): @@ -530,7 +568,8 @@ def temp_attn_size(self): def set_device_idx(self, idx): super().set_device_idx(idx) - if self.input_layernorm is not None: self.input_layernorm.set_device_idx(idx) + if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx) + if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx) self.q_proj.set_device_idx(idx) self.k_proj.set_device_idx(idx) self.v_proj.set_device_idx(idx) @@ -614,7 +653,7 @@ def forward_paged(self, ) else: residual = hidden_states - hidden_states = self.input_layernorm.forward(hidden_states) if self.has_norm else hidden_states + hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states q = self.q_proj.forward(hidden_states, loras = loras) k = self.k_proj.forward(hidden_states, loras = loras) v = self.v_proj.forward(hidden_states, loras = loras) @@ -652,30 +691,31 @@ def forward_paged(self, if cache.q_block == 1: cache.get_kv_state(self.layer_idx, batch_size, 0, attn_params.max_cache_seqlen, page_size, cache_seqlens, block_table) - # attn_output = flash_attn_with_kvcache( - # q = q, - # k = k, - # v = v, - # k_cache = k_cache, - # v_cache = v_cache, - # cache_seqlens = cache_seqlens_a, - # block_table = block_table, - # causal = True - # ) - attn_output, _ = flash_attn_cuda.fwd_kvcache( - q, k_cache, v_cache, k, v, - 
cache_seqlens_a, - None, None, - None, - block_table, - None, - None, - 1 / math.sqrt(cfg.head_dim), - True, - -1, -1, - True, - 0, + flash_kwargs = {} + if self.sliding_window: + # assert has_flash_attn_with_window, \ + # "Installed version of flash-attn does not support sliding window" + if has_flash_attn_with_window: + flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) + if cfg.attn_logit_softcapping: + # assert has_flash_attn_with_softcap, \ + # "Installed version of flash-attn does not support softcapping" + if has_flash_attn_with_softcap: + flash_kwargs["softcap"] = cfg.attn_logit_softcapping + + attn_output = flash_attn_with_kvcache( + q = q, + k = k, + v = v, + k_cache = k_cache, + v_cache = v_cache, + cache_seqlens = cache_seqlens_a, + block_table = block_table, + causal = True, + softmax_scale = self.scaling, + **flash_kwargs ) + attn_output = attn_output.view((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table) @@ -694,6 +734,8 @@ def forward_paged(self, ) else: hidden_states = self.o_proj.forward(attn_output, loras = loras) + if self.post_layernorm: + hidden_states = self.post_layernorm.forward(hidden_states) if self.has_residual: hidden_states += residual @@ -702,32 +744,50 @@ def forward_paged(self, def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): - if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa: + q_states = q_states.transpose(1, 2) + k_states = k_states.transpose(1, 2) + v_states = v_states.transpose(1, 2) + + # SDPA - q_states = q_states.transpose(1, 2) - k_states = k_states.transpose(1, 2) - v_states = v_states.transpose(1, 2) + if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) + if self.sliding_window and k_states.shape[2] >= self.sliding_window: + k_states = k_states[:, :, -self.sliding_window:, :] + v_states = v_states[:, :, -self.sliding_window:, :] + attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) - attn_output = F.scaled_dot_product_attention(q_states, k_states, v_states, attn_mask_lr) + attn_output = F.scaled_dot_product_attention( + q_states, + k_states, + v_states, + attn_mask_lr, + scale = self.scaling + ) - else: + # Matmul attn - q_states = q_states.transpose(1, 2) - k_states = k_states.transpose(1, 2) - v_states = v_states.transpose(1, 2) + else: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) k_states = k_states.transpose(-1, -2) attn_weights = torch.matmul(q_states, k_states) - attn_weights *= 1 / math.sqrt(cfg.head_dim) + attn_weights *= self.scaling attn_mask = attn_params.get_attn_mask(attn_weights.device) - if attn_mask is not None: attn_weights = attn_weights + attn_mask + + if cfg.attn_logit_softcapping: + ext_c.softcap_(attn_weights, cfg.attn_logit_softcapping) + if attn_mask is not None: + attn_weights = attn_weights + attn_mask + if self.sliding_window and k_states.shape[-1] >= self.sliding_window: + attn_weights = attn_weights[:, :, :, -self.sliding_window:] + v_states = v_states[:, :, -self.sliding_window:, :] + attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) @@ -740,11 +800,25 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, 
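Note: the ext_c.softcap_ call used in the matmul attention path applies Gemma2-style logit soft-capping in place. Its reference math in plain PyTorch (equivalent transform, not the CUDA kernel itself):

import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squashes logits into (-cap, cap): cap * tanh(scores / cap)
    return cap * torch.tanh(scores / cap)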
attn_para def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): + flash_kwargs = {} + if self.sliding_window: + # assert has_flash_attn_with_window, \ + # "Installed version of flash-attn does not support sliding window" + if has_flash_attn_with_window: + flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) + if cfg.attn_logit_softcapping: + # assert has_flash_attn_with_softcap, \ + # "Installed version of flash-attn does not support softcapping" + if has_flash_attn_with_softcap: + flash_kwargs["softcap"] = cfg.attn_logit_softcapping + attn_output = flash_attn_func( q_states, k_states, v_states, - causal = True + causal = True, + softmax_scale = self.scaling, + **flash_kwargs ) attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) return attn_output @@ -752,6 +826,12 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg): + # assert not self.sliding_window, \ + # "Sliding window not currently supported for xformers" + + # assert not cfg.attn_logit_softcapping, \ + # "Softcap not yet supported for xformers" + # xformers memory_efficient_attention, could be beneficial if your device's architecture is less than sm_80 are almost the same. But the martix operation @@ -770,7 +850,8 @@ def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_p q_states, k_states, v_states, - attn_bias = LowerTriangularFromBottomRightMask() + attn_bias = LowerTriangularFromBottomRightMask(), + scale = self.scaling ) attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim)) @@ -914,6 +995,9 @@ def forward(self, pass_lora_temp ) + if cfg.arch.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + return hidden_states @@ -942,7 +1026,7 @@ def forward_torch(self, # Project q, k, v residual = hidden_states - post_norm = self.input_layernorm.forward(hidden_states) if self.has_norm else hidden_states + post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states query_states = self.q_proj.forward(post_norm, loras = loras) key_states = self.k_proj.forward(post_norm, loras = loras) @@ -1013,10 +1097,20 @@ def forward_torch(self, attn_proj = self.o_proj.forward(attn_output, loras = loras) + # Post layernorm + + if self.post_layernorm: + attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = cfg.arch.residual_stream_fp32) + # Add residual connection hidden_states = (attn_proj + residual) if self.has_residual else attn_proj + if cfg.arch.residual_stream_fp32: + hidden_states = hidden_states.float() + elif cfg.arch.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + if intermediates: return {"post_norm": post_norm, "attn_output": attn_output, diff --git a/exllamav2/config.py b/exllamav2/config.py index 4d3e3ffd..162ef9d2 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -100,6 +100,10 @@ class ExLlamaV2Config: scale_depth: float scale_emb: float use_qk_norm: bool + query_pre_attn_scalar: float | None + final_logit_softcapping: float | None + attn_logit_softcapping: float | None + sliding_window: int checkpoint_fused_mlp: bool @@ -162,9 +166,9 @@ def prepare(self, no_tensors: bool = False): # Load generation_config.json - self.generation_config_path = os.path.join(self.model_dir, "generation_config.json") - if os.path.exists(self.generation_config_path): - with 
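Note: after the residual add, the hidden-state handling introduced for Gemma2 comes down to two options: promote the residual stream to FP32, or clamp it to the FP16 range to avoid overflow. Schematically (a simplified sketch of the forward_torch tail above; the helper name is illustrative):

import torch

def finish_residual(x: torch.Tensor, residual: torch.Tensor,
                    residual_fp32: bool = False, clamp_fp16: bool = False) -> torch.Tensor:
    h = x + residual
    if residual_fp32:
        return h.float()                    # carry the residual stream in FP32
    if clamp_fp16:
        return h.clamp_(-65504.0, 65504.0)  # keep values inside the FP16 range
    return h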
open(self.generation_config_path, encoding = "utf8") as f: + generation_config_path = os.path.join(self.model_dir, "generation_config.json") + if os.path.exists(generation_config_path): + with open(generation_config_path, encoding = "utf8") as f: gen_config = json.load(f) self.generation_config = {} try: @@ -175,8 +179,7 @@ def prepare(self, no_tensors: bool = False): self.generation_config['eos_token_id'] = [eos_token_id_as_int] else: self.generation_config['eos_token_id'] = None - - + # Model architecture assert len(read_config["architectures"]) == 1, "Multiple architectures defined in config.json" @@ -218,6 +221,8 @@ def prepare(self, no_tensors: bool = False): self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False) + self.query_pre_attn_scalar = read(read_config, float, "query_pre_attn_scalar", None) + # MLP params if self.arch.default_inner_dim_mult is not None: @@ -243,6 +248,9 @@ def prepare(self, no_tensors: bool = False): else: self.scale_depth = scale_depth / math.sqrt(self.num_hidden_layers) + self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None) + self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None) + # Positional embeddings self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0) @@ -254,6 +262,8 @@ def prepare(self, no_tensors: bool = False): "n_positions"], 2048) self.original_max_seq_len = self.max_seq_len + self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0) + rs = read(read_config, dict, "rope_scaling", None) if rs: scaling_type = rs.get("type", None) @@ -332,4 +342,54 @@ def prepare(self, no_tensors: bool = False): if not match: raise ValueError(f" ## Could not find {prefix}.* in model") - x = 0 \ No newline at end of file + x = 0 + + + def arch_compat_overrides(self, quiet: bool = False, warn_only = False): + + from exllamav2.attn import ( + has_flash_attn, + has_flash_attn_with_window, + has_flash_attn_with_softcap, + has_xformers + ) + + warnings = [] + + if self.arch.eager_attn_only: + warnings.append(" !! Warning: Architecture currently supports only eager attention") + if not warn_only: + warnings.append(" !! Warning: flash-attn, xformers and SDPA are disabled") + self.no_flash_attn = True + self.no_xformers = True + self.no_sdpa = True + else: + warnings.append(" !! Warning: flash-attn, xformers and SDPA should be disabled for correct inference") + + if has_flash_attn and not self.no_flash_attn: + disable = False + if self.attn_logit_softcapping and not has_flash_attn_with_softcap: + warnings.append(" !! Warning: model requires softcap, not supported in installed version of flash-attn") + disable = True + if (self.arch.swa or self.arch.alternating_swa) and not has_flash_attn_with_window: + warnings.append(" !! Warning: model requires SWA, not supported in installed version of flash-attn") + disable = True + if disable and not warn_only: + warnings.append(" !! Warning: disabling flash-attn") + self.no_flash_attn = True + + if has_xformers and not self.no_xformers: + disable = False + if self.attn_logit_softcapping: + warnings.append(" !! Warning: model requires softcap, not supported in xformers") + disable = True + if self.arch.swa or self.arch.alternating_swa: + warnings.append(" !! Warning: model requires SWA, not supported in xformers") + disable = True + if disable and not warn_only: + warnings.append(" !! 
Warning: disabling xformers") + self.no_xformers = True + + if not quiet: + for w in warnings: + print(w) diff --git a/exllamav2/conversion/compile.py b/exllamav2/conversion/compile.py index 22733965..4a9320fe 100644 --- a/exllamav2/conversion/compile.py +++ b/exllamav2/conversion/compile.py @@ -32,6 +32,8 @@ def _dsize(d): def get_f_module(job, module): + if module is None: return None + mod_dict = {} module.load() w = module.get_weight() @@ -77,7 +79,10 @@ def compile_model(job, save_fn, model): if isinstance(module, ExLlamaV2Attention): - d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.pre_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.post_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d) @@ -86,7 +91,10 @@ def compile_model(job, save_fn, model): if isinstance(module, ExLlamaV2MLP): has_gate = model.config.arch.mlp_gate - d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.pre_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.post_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) if has_gate: d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d) diff --git a/exllamav2/conversion/convert_exl2.py b/exllamav2/conversion/convert_exl2.py index 8a1f194f..2f56cf77 100644 --- a/exllamav2/conversion/convert_exl2.py +++ b/exllamav2/conversion/convert_exl2.py @@ -183,8 +183,8 @@ def save_job(): config = ExLlamaV2Config() config.model_dir = job['in_dir'] -config.qkv_embed = False config.prepare() +config.arch_compat_overrides() # Tokenizer diff --git a/exllamav2/conversion/measure.py b/exllamav2/conversion/measure.py index f42b6706..9315f916 100644 --- a/exllamav2/conversion/measure.py +++ b/exllamav2/conversion/measure.py @@ -159,6 +159,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p total_numel += module.v_proj.numel() total_numel += module.o_proj.numel() + max_accuracy = 0.0 (q_, k_, v_, o_) = (-1, -1, -1, -1) for (q, k, v, o) in qmaps: @@ -177,6 +178,8 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p accuracy = test_error(module, hidden_states, target_states, cache, attn_params) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -187,6 +190,10 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p "o_proj": qjobs[3][o].get_dict() } results.append(r) + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + for x in ["k_proj", "v_proj", "o_proj"] + (["q_proj"] if not keep_q else []): if x in quantizers: del quantizers[x] @@ -216,6 +223,7 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa total_numel += module.up_proj.numel() total_numel += module.down_proj.numel() + max_accuracy = 0.0 if has_gate: 
(g_, u_, d_) = (-1, -1, -1) @@ -234,6 +242,8 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa accuracy = test_error(module, hidden_states, target_states, cache, attn_params) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -259,6 +269,8 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa accuracy = test_error(module, hidden_states, target_states, cache, attn_params) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -267,6 +279,10 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa "down_proj": qjobs[2][d].get_dict() } results.append(r) + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + for x in ["up_proj", "down_proj", "gate_proj"]: if x in quantizers: del quantizers[x] @@ -311,6 +327,7 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att total_numel += sum(module.w3[i].numel() for i in range(num_experts)) total_numel += sum(module.w2[i].numel() for i in range(num_experts)) + max_accuracy = 0.0 (g_, u_, d_) = (-1, -1, -1) for (g, u, d) in qmaps: @@ -328,6 +345,8 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att accuracy = test_error(module, hidden_states, target_states, cache, attn_mask) print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + max_accuracy = max(accuracy, max_accuracy) + torch.cuda.empty_cache() r = { "accuracy": accuracy, @@ -337,6 +356,10 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att "w2": qjobs[2][d].get_dict() } results.append(r) + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + return results @@ -515,9 +538,24 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): for i in range(len(hidden_states)): x = hidden_states[i].to("cuda:0") + if torch.isnan(x).any(): + print(" ## Measurement/inference error (2)") + os._exit(1) + if torch.isinf(x).any(): + print(" ## Measurement/inference error (3)") + os._exit(1) + outputs = module.forward(x, cache, attn_params, intermediates = True) target_device = "cuda:0" if i < hidden_state_offload_layers else "cpu" + for k, v in outputs.items(): + if torch.isnan(v).any(): + print(f" ## Measurement/inference error (2): {k}") + os._exit(1) + if torch.isinf(v).any(): + print(f" ## Measurement/inference error (3): {k}") + os._exit(1) + # Hessians if mode == "self_attn": diff --git a/exllamav2/conversion/quantize.py b/exllamav2/conversion/quantize.py index c6dd3abf..16f2ff23 100644 --- a/exllamav2/conversion/quantize.py +++ b/exllamav2/conversion/quantize.py @@ -21,6 +21,7 @@ import torch.nn.functional as F import gc from exllamav2.conversion.bot_status import print_stage +from exllamav2.ext import exllamav2_ext as ext_c, none_tensor def list_live_tensors(): @@ -470,6 +471,10 @@ def quant(job, save_fn, model): output = module.forward(x, cache, attn_params) if module.padding > 0: output = output[:, :, :-module.padding] + if model.config.final_logit_softcapping: + output = output.contiguous() + ext_c.softcap_(output, model.config.final_logit_softcapping) + logits = output[:, :-1, :] logits = logits.float() + 1e-10 target_ids = cal_ids[i:i+1, 1:].to("cuda:0") diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py index 
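Note: the measurement pass now aborts early when a forward pass produces non-finite values instead of carrying a corrupted measurement forward. The guard added above is essentially (helper name is illustrative):

import os
import torch

def assert_finite(name: str, t: torch.Tensor):
    # Abort the measurement pass if hidden states have gone non-finite
    if torch.isnan(t).any():
        print(f" ## Measurement/inference error (2): {name}")
        os._exit(1)
    if torch.isinf(t).any():
        print(f" ## Measurement/inference error (3): {name}")
        os._exit(1)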
b52c02ac..411dccf9 100644 --- a/exllamav2/embedding.py +++ b/exllamav2/embedding.py @@ -89,6 +89,8 @@ def forward(self, loras = None, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: + cfg = self.model.config + # If input IDs contain negative values, assume they are padding tokens from a model with not pad_token_id # defined @@ -111,7 +113,7 @@ def forward(self, # Create combined tensor on the target device batch_size, seq_len = input_ids.shape - hidden_size = self.model.config.hidden_size + hidden_size = cfg.hidden_size combined_embeddings = torch.empty(batch_size, seq_len, hidden_size, device = indexed_embeddings.device, dtype = indexed_embeddings.dtype) @@ -124,14 +126,19 @@ def forward(self, standard_mask_ = standard_mask[i] input_ids_ = input_ids[i] standard_ids_ = input_ids_[standard_mask_] - standard_embeddings_ = self.embedding(standard_ids_) + if loras is not None and loras[0].embed_tokens is not None: + standard_embeddings_ = loras[0].embed_tokens(standard_ids_) + else: + standard_embeddings_ = self.embedding(standard_ids_) standard_embeddings_ = safe_move_tensor(standard_embeddings_, indexed_embeddings.device) combined_embeddings[i][standard_mask_] = standard_embeddings_ # Normalization - if self.model.config.arch.normalize_embeddings: - combined_embeddings *= self.model.config.hidden_size ** 0.5 + if cfg.arch.residual_stream_fp32: + combined_embeddings = combined_embeddings.float() + if cfg.arch.normalize_embeddings: + combined_embeddings *= cfg.hidden_size ** 0.5 # Extract indexed embeddings and insert in-place @@ -144,10 +151,15 @@ def forward(self, # Call embedding module if no indexed embeddings else: - hidden_states = self.embedding.forward(hidden_states) - - if self.model.config.arch.normalize_embeddings: - hidden_states *= self.model.config.hidden_size ** 0.5 + if loras is not None and loras[0].embed_tokens is not None: + hidden_states = loras[0].embed_tokens(hidden_states) + else: + hidden_states = self.embedding(hidden_states) + + if cfg.arch.residual_stream_fp32: + hidden_states = hidden_states.float() + if cfg.arch.normalize_embeddings: + hidden_states *= cfg.hidden_size ** 0.5 if intermediates: return {"hidden_states": hidden_states} diff --git a/exllamav2/exllamav2_ext/cuda/layer_norm.cu b/exllamav2/exllamav2_ext/cuda/layer_norm.cu index 9d112338..e65b21fc 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cu @@ -24,7 +24,8 @@ typedef void (*fp_layer_norm_kernel) const float, const float, const int, - const int + const int, + const bool ); template @@ -37,7 +38,8 @@ __global__ void layer_norm_kernel const float epsilon, const float r_dim, const int rows, - const int dim + const int dim, + const bool add_residual ) { int warp_id = threadIdx.x / WARP_SIZE; @@ -149,7 +151,10 @@ __global__ void layer_norm_kernel half2 nh = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); if (b) nh = __hadd2(nh, b2[column]); // Optional bias - y_row[column] = nh; + if (add_residual) + y_row[column] = __hadd2(nh, y_row[column]); + else + y_row[column] = nh; } } @@ -185,7 +190,8 @@ void layer_norm_cuda half* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual ) { dim3 blockDim, gridDim; @@ -198,5 +204,5 @@ void layer_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_layer_norm_kernel kernel = pick_layer_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim); + kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual); } diff --git 
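Note: the add_residual flag threaded through layer_norm_cuda lets the kernel accumulate the normalized result into the output buffer instead of overwriting it, which is what the new post-layernorm path requires. Reference semantics in PyTorch (a sketch, not the kernel itself):

import torch
import torch.nn.functional as F

def layer_norm_ref(x, weight, bias, y, eps: float, add_residual: bool = False):
    # y is the destination buffer; with add_residual the normalized x is added to it
    n = F.layer_norm(x.float(), x.shape[-1:], weight.float(),
                     bias.float() if bias is not None else None, eps).half()
    return y + n if add_residual else n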
a/exllamav2/exllamav2_ext/cuda/layer_norm.cuh b/exllamav2/exllamav2_ext/cuda/layer_norm.cuh index 2b10d4d6..f780794c 100644 --- a/exllamav2/exllamav2_ext/cuda/layer_norm.cuh +++ b/exllamav2/exllamav2_ext/cuda/layer_norm.cuh @@ -14,7 +14,8 @@ void layer_norm_cuda half* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual = false ); #endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cu b/exllamav2/exllamav2_ext/cuda/q_attn.cu index bb8b2d4d..5a165fcc 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cu +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cu @@ -93,7 +93,10 @@ QAttn::QAttn bool _has_residual, int _rope_style, half* _q_norm, - half* _k_norm + half* _k_norm, + half* _post_layernorm, + half* _post_layernorm_bias, + bool _residual_fp32 ): layernorm(_layernorm), layernorm_bias(_layernorm_bias), @@ -117,7 +120,10 @@ QAttn::QAttn has_residual(_has_residual), rope_style(_rope_style), q_norm(_q_norm), - k_norm(_k_norm) + k_norm(_k_norm), + post_layernorm(_post_layernorm), + post_layernorm_bias(_post_layernorm_bias), + residual_fp32(_residual_fp32) { } @@ -128,7 +134,7 @@ QAttn::~QAttn() void QAttn::forward_cuda_1 ( cublasHandle_t cublas_handle, - half* x, + void* x, int batch_size, int q_len, int past_len, @@ -142,14 +148,14 @@ void QAttn::forward_cuda_1 half* lora_temp ) { - half* norm_state = x; + half* norm_state = (half*) x; if (layernorm) { if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, q_len * batch_size, hidden_size); + rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, q_len * batch_size, hidden_size, false, residual_fp32, false); else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, q_len * batch_size, hidden_size); + layer_norm_cuda((half*)x, layernorm, layernorm_bias, temp_state, norm_epsilon, q_len * batch_size, hidden_size); norm_state = temp_state; } @@ -195,14 +201,25 @@ void QAttn::forward_cuda_2 ( cublasHandle_t cublas_handle, const half* attn_output, - half* hidden_state, + void* hidden_state, int q_len, int batch_size, const std::vector& loras, half* lora_temp ) { - gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, !has_residual, temp_dq); + if (!post_layernorm) + { + gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, (half*) hidden_state, q_len * batch_size, o_proj->width, o_proj->height, !has_residual, temp_dq); + } + else + { + gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, temp_state, q_len * batch_size, o_proj->width, o_proj->height, true, temp_dq); + if (layernorm_is_rms) + rms_norm_cuda(temp_state, post_layernorm, hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true, false, residual_fp32); + else + layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, (half*) hidden_state, norm_epsilon, q_len * batch_size, hidden_size, true); + } - apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size); + apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, (half*) hidden_state, lora_temp, q_len * batch_size); } diff --git a/exllamav2/exllamav2_ext/cuda/q_attn.cuh b/exllamav2/exllamav2_ext/cuda/q_attn.cuh index 16a47bd5..da9abd43 100644 --- a/exllamav2/exllamav2_ext/cuda/q_attn.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_attn.cuh @@ -19,6 +19,8 @@ public: half* layernorm; half* layernorm_bias; + half* post_layernorm; + half* post_layernorm_bias; bool 
layernorm_is_rms; float norm_epsilon; @@ -50,12 +52,13 @@ public: std::unordered_map> o_proj_lora; bool has_residual; + bool residual_fp32; int rope_style; QAttn ( half* _layernorm, - half* _layermorm_bias, + half* _layernorm_bias, bool _layernorm_is_rms, float _norm_epsilon, QMatrix* _q_proj, @@ -76,7 +79,10 @@ public: bool _has_residual, int _rope_style, half* _q_norm, - half* _k_norm + half* _k_norm, + half* _post_layernorm, + half* _post_layernorm_bias, + bool _residual_fp32 ); ~QAttn(); @@ -84,7 +90,7 @@ public: void forward_cuda_1 ( cublasHandle_t cublas_handle, - half* x, + void* x, int batch_size, int q_len, int past_len, @@ -102,7 +108,7 @@ public: ( cublasHandle_t cublas_handle, const half* attn_output, - half* hidden_state, + void* hidden_state, int q_len, int batch_size, const std::vector& loras, diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index 97c263c9..810aebdb 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -32,7 +32,10 @@ QMLP::QMLP half* _temp_dq, int _max_rows, bool _act_gelu, - bool _has_residual + bool _has_residual, + half* _post_layernorm, + half* _post_layernorm_bias, + bool _residual_fp32 ): layernorm(_layernorm), layernorm_bias(_layernorm_bias), @@ -47,7 +50,10 @@ QMLP::QMLP temp_dq(_temp_dq), max_rows(_max_rows), act_gelu(_act_gelu), - has_residual(_has_residual) + has_residual(_has_residual), + post_layernorm(_post_layernorm), + post_layernorm_bias(_post_layernorm_bias), + residual_fp32(_residual_fp32) { } @@ -57,7 +63,7 @@ QMLP::~QMLP() { void QMLP::forward_ ( cublasHandle_t cublas_handle, - half* x, + void* x, int rows, int columns, const std::vector& loras, @@ -77,14 +83,14 @@ void QMLP::forward_ // Layernorm - half* norm_state = x; + half* norm_state = (half*) x; if (layernorm) { if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, rows, columns); + rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, rows, columns, false, residual_fp32, false); else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, rows, columns); + layer_norm_cuda((half*) x, layernorm, layernorm_bias, temp_state, norm_epsilon, rows, columns); norm_state = temp_state; } @@ -114,11 +120,25 @@ void QMLP::forward_ kernel<<>>(temp_a, rows, intermediate_size, NULL, 0); } - // Down proj + // Down proj without post_layernorm + + if (!post_layernorm) + { + gemm_half_q_half_cuda(cublas_handle, temp_a, down, (half*) x, rows, columns, intermediate_size, !has_residual, temp_dq); + } + + // Down proj with post_layernorm - gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, !has_residual, temp_dq); + else + { + gemm_half_q_half_cuda(cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq); + if (layernorm_is_rms) + rms_norm_cuda(temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32); + else + layer_norm_cuda(temp_state, post_layernorm, post_layernorm_bias, (half*) x, norm_epsilon, rows, columns, true); + } - apply_loras_cuda(cublas_handle, down_proj_lora, loras, down, temp_a, x, lora_temp, rows); + apply_loras_cuda(cublas_handle, down_proj_lora, loras, down, temp_a, (half*) x, lora_temp, rows); } diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh index 179965b2..bb56d36e 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cuh @@ -15,6 +15,8 @@ public: half* layernorm; half* 
layernorm_bias; + half* post_layernorm; + half* post_layernorm_bias; bool layernorm_is_rms; float norm_epsilon; @@ -36,11 +38,12 @@ public: bool act_gelu; bool has_residual; + bool residual_fp32; QMLP ( half* _layernorm, - half* _layermorm_bias, + half* _layernorm_bias, bool _layernorm_is_rms, float _norm_epsilon, QMatrix* _gate, @@ -52,7 +55,10 @@ public: half* _temp_dq, int _max_rows, bool _act_gelu, - bool _has_residual + bool _has_residual, + half* _post_layernorm, + half* _post_layernorm_bias, + bool _residual_fp32 ); ~QMLP(); @@ -60,7 +66,7 @@ public: void forward_ ( cublasHandle_t cublas_handle, - half* x, + void* x, int rows, int columns, const std::vector& loras, @@ -108,7 +114,7 @@ public: QMoEMLP ( half* _layernorm, - half* _layermorm_bias, + half* _layernorm_bias, bool _layernorm_is_rms, float _norm_epsilon, half* _gate, diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cu b/exllamav2/exllamav2_ext/cuda/rms_norm.cu index eb0d5711..f72bda11 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cu +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cu @@ -17,32 +17,36 @@ typedef void (*fp_rms_norm_kernel) ( + const void*, const half*, - const half*, - half*, + void*, const float, const float, const int, - const int + const int, + const bool, + const bool, + const bool ); template __global__ void rms_norm_kernel ( - const half* __restrict__ x, + const void* __restrict__ x, const half* __restrict__ w, - half* __restrict__ y, + void* __restrict__ y, const float epsilon, const float r_dim, const int rows, - const int dim + const int dim, + const bool add_residual, + const bool input_fp32, + const bool output_fp32 ) { int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int row = blockIdx.x; - const half2* x_row = (const half2*) (x + row * dim); - half2* y_row = (half2*) (y + row * dim); const half2* w2 = (const half2*) w; // Compute sum of squares for each block @@ -50,21 +54,45 @@ __global__ void rms_norm_kernel float sum = 0.0f; float itemf[blocks_per_warp][2]; - #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) + if (!input_fp32) + { + const half2* x_row = (const half2*) (((half*)x) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) break; + + half2 x2 = x_row[column]; + float f0 = __half2float(__low2half(x2)); + float f1 = __half2float(__high2half(x2)); + f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f)); + f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f)); + itemf[i][0] = f0; + itemf[i][1] = f1; + sum = fma(f0, f0, sum); + sum = fma(f1, f1, sum); + } + } + else { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim / 2) break; - - half2 x2 = x_row[column]; - float f0 = __half2float(__low2half(x2)); - float f1 = __half2float(__high2half(x2)); - f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f)); - f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f)); - itemf[i][0] = f0; - itemf[i][1] = f1; - sum = fma(f0, f0, sum); - sum = fma(f1, f1, sum); + const float2* x_row = (const float2*) (((float*)x) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) break; + + float2 x2 = x_row[column]; + float f0 = x2.x; + float f1 = x2.y; + itemf[i][0] = f0; + itemf[i][1] = f1; + sum = fma(f0, f0, sum); + sum = fma(f1, f1, sum); + } } // Shuffle to sum across lanes @@ -90,20 +118,58 @@ __global__ void rms_norm_kernel // Normalize x, scaling by w 
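Note: the extended rms_norm kernel accepts FP16 or FP32 input, writes FP16 or FP32 output, and can likewise accumulate into the destination buffer. Its reference semantics (a plain PyTorch sketch of the extended signature, not the CUDA code; the FP16 input clamping done by the kernel is omitted):

import torch

def rms_norm_ref(x, w, eps: float, y = None,
                 add_residual: bool = False, output_fp32: bool = False):
    xf = x.float()                                      # input may be half or float
    rm = torch.rsqrt(xf.pow(2).mean(-1, keepdim = True) + eps)
    n = xf * rm * w.float()                             # scale by the norm weight
    if add_residual:
        n = n + y.float()                               # accumulate into the destination
    return n if output_fp32 else n.half()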
- #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) + if (!output_fp32) + { + half2* y_row = (half2*) (((half*)y) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) return; + half2 w2_ = w2[column]; + + float x_itemf0 = itemf[i][0]; + float x_itemf1 = itemf[i][1]; + float w_itemf0 = __half2float(__low2half(w2_)); + float w_itemf1 = __half2float(__high2half(w2_)); + float n0 = x_itemf0 * w_itemf0 * rmf; + float n1 = x_itemf1 * w_itemf1 * rmf; + if (add_residual) + y_row[column] = __hadd2(y_row[column], __halves2half2(__float2half_rn(n0), __float2half_rn(n1))); + else + y_row[column] = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); + } + } + else { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim / 2) return; - half2 w2_ = w2[column]; - - float x_itemf0 = itemf[i][0]; - float x_itemf1 = itemf[i][1]; - float w_itemf0 = __half2float(__low2half(w2_)); - float w_itemf1 = __half2float(__high2half(w2_)); - float n0 = x_itemf0 * w_itemf0 * rmf; - float n1 = x_itemf1 * w_itemf1 * rmf; - y_row[column] = __halves2half2(__float2half_rn(n0), __float2half_rn(n1)); + float2* y_row = (float2*) (((float*)y) + row * dim); + + #pragma unroll + for (int i = 0; i < blocks_per_warp; i++) + { + int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; + if (column >= dim / 2) return; + half2 w2_ = w2[column]; + + float x_itemf0 = itemf[i][0]; + float x_itemf1 = itemf[i][1]; + float w_itemf0 = __half2float(__low2half(w2_)); + float w_itemf1 = __half2float(__high2half(w2_)); + float n0 = x_itemf0 * w_itemf0 * rmf; + float n1 = x_itemf1 * w_itemf1 * rmf; + if (add_residual) + { + float2 y2 = y_row[column]; + y2.x += n0; + y2.y += n1; + y_row[column] = y2; + } + else + { + y_row[column] = make_float2(n0, n1); + } + } } } @@ -133,12 +199,15 @@ fp_rms_norm_kernel pick_rms_norm_kernel(const int blocks_per_warp) void rms_norm_cuda ( - const half* x, + const void* x, const half* w, - half* y, + void* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual, + const bool input_fp32, + const bool output_fp32 ) { dim3 blockDim, gridDim; @@ -151,5 +220,5 @@ void rms_norm_cuda int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2); fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, y, epsilon, r_dim, rows, dim); + kernel<<>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32); } diff --git a/exllamav2/exllamav2_ext/cuda/rms_norm.cuh b/exllamav2/exllamav2_ext/cuda/rms_norm.cuh index 4cb0fea9..2168c0f8 100644 --- a/exllamav2/exllamav2_ext/cuda/rms_norm.cuh +++ b/exllamav2/exllamav2_ext/cuda/rms_norm.cuh @@ -8,12 +8,15 @@ void rms_norm_cuda ( - const half* x, + const void* x, const half* w, - half* y, + void* y, const float epsilon, const int rows, - const int dim + const int dim, + const bool add_residual = false, + const bool input_fp32 = false, + const bool output_fp32 = false ); #endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/cuda/softcap.cu b/exllamav2/exllamav2_ext/cuda/softcap.cu new file mode 100644 index 00000000..e2bd6aaa --- /dev/null +++ b/exllamav2/exllamav2_ext/cuda/softcap.cu @@ -0,0 +1,78 @@ +#include "softcap.cuh" +#include "util.cuh" +#include "../config.h" +#include "matrix_view.cuh" + +#define NUM_THREADS 256 + +__global__ void cuda_softcap_kernel +( + float* __restrict__ x, + const uint64_t numel, + const float scale 
+) +{ + uint64_t idx = (uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x; + if (idx >= numel) return; + + float v = x[idx]; + v /= scale; + v = tanhf(v); + v *= scale; + x[idx] = v; +} + +void softcap_cuda_ +( + float* x, + const uint64_t numel, + const float scale +) +{ + dim3 blockDim, gridDim; + blockDim.x = NUM_THREADS; + gridDim.x = DIVIDE(numel, NUM_THREADS); + + cuda_softcap_kernel<<>>(x, numel, scale); +} + +// TODO: Profile + +__global__ void h_cuda_softcap_kernel +( + half* __restrict__ x, + const uint64_t numel, + const float scale +) +{ + uint64_t idx = (uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x; + idx *= 2; + if (idx >= numel) return; + half2* x2 = (half2*)(x + idx); + half2 v01 = *x2; + float v0 = __low2float(v01); + float v1 = __high2float(v01); + v0 /= scale; + v1 /= scale; + v0 = tanhf(v0); + v1 = tanhf(v1); + v0 *= scale; + v1 *= scale; + v01 = __floats2half2_rn(v0, v1); + *x2 = v01; +} + +void h_softcap_cuda_ +( + half* x, + const uint64_t numel, + const float scale +) +{ + dim3 blockDim, gridDim; + blockDim.x = NUM_THREADS; + gridDim.x = DIVIDE(numel / 2, NUM_THREADS); + + h_cuda_softcap_kernel<<>>(x, numel, scale); +} + diff --git a/exllamav2/exllamav2_ext/cuda/softcap.cuh b/exllamav2/exllamav2_ext/cuda/softcap.cuh new file mode 100644 index 00000000..4ea2e6f8 --- /dev/null +++ b/exllamav2/exllamav2_ext/cuda/softcap.cuh @@ -0,0 +1,24 @@ +#ifndef _softcap_cuh +#define _softcap_cuh + +#include +#include +#include +#include +#include + +void softcap_cuda_ +( + float* x, + const uint64_t numel, + const float scale +); + +void h_softcap_cuda_ +( + half* x, + const uint64_t numel, + const float scale +); + +#endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_bindings.cpp b/exllamav2/exllamav2_ext/ext_bindings.cpp index 0b61aea3..0fbb5a43 100644 --- a/exllamav2/exllamav2_ext/ext_bindings.cpp +++ b/exllamav2/exllamav2_ext/ext_bindings.cpp @@ -19,6 +19,7 @@ #include "ext_gemm.h" #include "ext_norm.h" #include "ext_rope.h" +#include "ext_element.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -113,4 +114,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // rope m.def("rope_", &rope_, "rope_"); + + // element + + m.def("softcap_", &softcap_, "softcap_"); } \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_element.cpp b/exllamav2/exllamav2_ext/ext_element.cpp new file mode 100644 index 00000000..968c88af --- /dev/null +++ b/exllamav2/exllamav2_ext/ext_element.cpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" +#include "cuda/softcap.cuh" +#include "cpp/util.h" + +// Apply softcapping inplace: x = scale * tanh(x/scale) + +void softcap_ +( + torch::Tensor x, + float scale +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + uint64_t numel = x.numel(); + + if (x.dtype() == torch::kFloat) + { + softcap_cuda_ + ( + (float*) x.data_ptr(), + numel, + scale + ); + } + else if (x.dtype() == torch::kHalf) + { + h_softcap_cuda_ + ( + (half*) x.data_ptr(), + numel, + scale + ); + } + else + { + TORCH_CHECK(false, "softcap_ wrong dtype"); + } +} diff --git a/exllamav2/exllamav2_ext/ext_element.h b/exllamav2/exllamav2_ext/ext_element.h new file mode 100644 index 00000000..97e6c706 --- /dev/null +++ b/exllamav2/exllamav2_ext/ext_element.h @@ -0,0 +1,6 @@ + +void softcap_ +( + torch::Tensor x, + float scale +); diff --git a/exllamav2/exllamav2_ext/ext_norm.cpp b/exllamav2/exllamav2_ext/ext_norm.cpp index c424491e..5b7377c6 100644 
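The new element-wise op above applies logit softcapping in place, x = scale * tanh(x / scale), with separate kernels for FP32 and FP16 tensors and a `softcap_` binding that dispatches on dtype; later in this patch, model.py calls it on the final logits when `config.final_logit_softcapping` is set. A minimal sketch of the behaviour, assuming a built extension and a CUDA device; the tensor shape, the scale value of 30.0 and the tolerance are illustrative only:

import torch
from exllamav2.ext import exllamav2_ext as ext_c

x = torch.randn(4, 256, dtype = torch.half, device = "cuda")
reference = 30.0 * torch.tanh(x.float() / 30.0)   # out-of-place reference of the same transform
ext_c.softcap_(x, 30.0)                           # in-place, picks the float or half kernel by dtype
assert torch.allclose(x.float(), reference, atol = 1e-2)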
--- a/exllamav2/exllamav2_ext/ext_norm.cpp +++ b/exllamav2/exllamav2_ext/ext_norm.cpp @@ -28,9 +28,9 @@ void rms_norm float epsilon ) { - TORCH_CHECK_DTYPE(x, kHalf); + bool input_fp32 = x.dtype() == torch::kFloat; + bool output_fp32 = y.dtype() == torch::kFloat; TORCH_CHECK_DTYPE(w, kHalf); - TORCH_CHECK_DTYPE(y, kHalf); TORCH_CHECK_SHAPES(x, 1, w, 0, 1); TORCH_CHECK_SHAPES(x, 0, y, 0, 1); TORCH_CHECK_SHAPES(x, 1, y, 1, 1); @@ -42,12 +42,15 @@ void rms_norm rms_norm_cuda ( - (half*) x.data_ptr(), + (void*) x.data_ptr(), (half*) w.data_ptr(), - (half*) y.data_ptr(), + (void*) y.data_ptr(), epsilon, rows, - dim + dim, + false, + input_fp32, + output_fp32 ); } diff --git a/exllamav2/exllamav2_ext/ext_qattn.cpp b/exllamav2/exllamav2_ext/ext_qattn.cpp index c4452489..a9351a09 100644 --- a/exllamav2/exllamav2_ext/ext_qattn.cpp +++ b/exllamav2/exllamav2_ext/ext_qattn.cpp @@ -39,7 +39,10 @@ uintptr_t make_q_attn bool has_residual, int rope_style, torch::Tensor q_norm, - torch::Tensor k_norm + torch::Tensor k_norm, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias, + bool residual_fp32 ) { QMatrix* qm_q_proj = reinterpret_cast (q_q_proj); @@ -48,6 +51,7 @@ uintptr_t make_q_attn QMatrix* qm_o_proj = reinterpret_cast (q_o_proj); TORCH_CHECK_DTYPE_OPT(layernorm, kHalf); + TORCH_CHECK_DTYPE_OPT(post_layernorm, kHalf); if (qm_q_proj && !layernorm.is_meta()) TORCH_CHECK(qm_q_proj->height == layernorm.size(0), "q_proj is wrong shape") if (qm_k_proj && !layernorm.is_meta()) TORCH_CHECK(qm_k_proj->height == layernorm.size(0), "k_proj is wrong shape") @@ -78,7 +82,10 @@ uintptr_t make_q_attn has_residual, rope_style, (half*) q_norm.is_meta() ? NULL : (half*) q_norm.data_ptr(), - (half*) k_norm.is_meta() ? NULL : (half*) k_norm.data_ptr() + (half*) k_norm.is_meta() ? NULL : (half*) k_norm.data_ptr(), + (half*) post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(), + (half*) post_layernorm_bias.is_meta() ? 
NULL : (half*) post_layernorm_bias.data_ptr(), + residual_fp32 ); return reinterpret_cast (attn); @@ -111,7 +118,9 @@ void q_attn_forward_1 ) { QAttn* attn = reinterpret_cast (q_attn); - TORCH_CHECK_DTYPE(x, kHalf); + if (attn->residual_fp32) { TORCH_CHECK_DTYPE(x, kFloat); } + else { TORCH_CHECK_DTYPE(x, kHalf); } + TORCH_CHECK_DTYPE_OPT(past_lens, kInt); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); @@ -147,7 +156,8 @@ void q_attn_forward_2 ) { QAttn* attn = reinterpret_cast (q_attn); - TORCH_CHECK_DTYPE(x, kHalf); + if (attn->residual_fp32) { TORCH_CHECK_DTYPE(x, kFloat); } + else { TORCH_CHECK_DTYPE(x, kHalf); } const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); diff --git a/exllamav2/exllamav2_ext/ext_qattn.h b/exllamav2/exllamav2_ext/ext_qattn.h index dfb300b6..938e90b1 100644 --- a/exllamav2/exllamav2_ext/ext_qattn.h +++ b/exllamav2/exllamav2_ext/ext_qattn.h @@ -23,7 +23,10 @@ uintptr_t make_q_attn bool has_residual, int rope_style, torch::Tensor q_norm, - torch::Tensor k_norm + torch::Tensor k_norm, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias, + bool residual_fp32 ); void free_q_attn diff --git a/exllamav2/exllamav2_ext/ext_qmlp.cpp b/exllamav2/exllamav2_ext/ext_qmlp.cpp index 827ac789..d089a5c0 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.cpp +++ b/exllamav2/exllamav2_ext/ext_qmlp.cpp @@ -30,7 +30,10 @@ uintptr_t make_q_mlp torch::Tensor temp_dq, int max_rows, bool act_gelu, - bool has_residual + bool has_residual, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias, + bool residual_fp32 ) { QMatrix* qm_gate = reinterpret_cast (q_gate); @@ -38,6 +41,7 @@ uintptr_t make_q_mlp QMatrix* qm_down = reinterpret_cast (q_down); TORCH_CHECK_DTYPE_OPT(layernorm, kHalf); + TORCH_CHECK_DTYPE_OPT(post_layernorm, kHalf); if (qm_gate && !layernorm.is_meta()) TORCH_CHECK(qm_gate->height == layernorm.size(0), "gate_proj is wrong shape") if (!layernorm.is_meta()) TORCH_CHECK(qm_up->height == layernorm.size(0), "up_proj is wrong shape") @@ -56,7 +60,10 @@ uintptr_t make_q_mlp (half*) temp_dq.data_ptr(), max_rows, act_gelu, - has_residual + has_residual, + (half*) post_layernorm.is_meta() ? NULL : (half*) post_layernorm.data_ptr(), + (half*) post_layernorm_bias.is_meta() ? 
NULL : (half*) post_layernorm_bias.data_ptr(), + residual_fp32 ); return reinterpret_cast (mlp); @@ -80,7 +87,8 @@ void q_mlp_forward_ ) { QMLP* mlp = reinterpret_cast (q_mlp); - TORCH_CHECK_DTYPE(x, kHalf); + if (mlp->residual_fp32) { TORCH_CHECK_DTYPE(x, kFloat); } + else { TORCH_CHECK_DTYPE(x, kHalf); } const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); @@ -93,7 +101,7 @@ void q_mlp_forward_ mlp->forward_ ( at::cuda::getCurrentCUDABlasHandle(), - (half*) x.data_ptr(), + (void*) x.data_ptr(), rows, dim, loras, diff --git a/exllamav2/exllamav2_ext/ext_qmlp.h b/exllamav2/exllamav2_ext/ext_qmlp.h index 45e7ed1a..3269bfae 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.h +++ b/exllamav2/exllamav2_ext/ext_qmlp.h @@ -14,7 +14,10 @@ uintptr_t make_q_mlp torch::Tensor temp_dq, int max_rows, bool act_gelu, - bool has_residual + bool has_residual, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias, + bool residual_fp32 ); void free_q_mlp diff --git a/exllamav2/ext.py b/exllamav2/ext.py index 2fc40401..53681510 100644 --- a/exllamav2/ext.py +++ b/exllamav2/ext.py @@ -213,6 +213,7 @@ def find_msvc(): "ext_rope.cpp", "ext_safetensors.cpp", "ext_sampling.cpp", + "ext_element.cpp", "cuda/h_add.cu", "cuda/h_gemm.cu", "cuda/lora.cu", @@ -228,6 +229,7 @@ def find_msvc(): "cuda/rope.cu", "cuda/cache.cu", "cuda/util.cu", + "cuda/softcap.cu", "cuda/comp_units/kernel_select.cu", "cuda/comp_units/unit_gptq_1.cu", "cuda/comp_units/unit_gptq_2.cu", diff --git a/exllamav2/generator/base.py b/exllamav2/generator/base.py index c70cea29..236735e0 100644 --- a/exllamav2/generator/base.py +++ b/exllamav2/generator/base.py @@ -15,6 +15,7 @@ import threading from exllamav2.generator.hooks import ExLlamaV2PostSamplingHook, ExLlamaV2PostSamplingResult from exllamav2.embedding import EMBEDDING_INDEX +from exllamav2.util import cuda_sync_active class ExLlamaV2BaseGenerator: @@ -46,7 +47,7 @@ def warmup(self): input_ids = torch.zeros((1, 2), dtype = torch.long) self.model.forward(input_ids, cache = None, input_mask = None, preprocess_only = True) - torch.cuda.synchronize() + cuda_sync_active() def full(self): diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index ad01c140..52924421 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -6,6 +6,7 @@ from exllamav2.cache import ExLlamaV2CacheBase, ExLlamaV2Cache_8bit from exllamav2.attn import ExLlamaV2Attention, assert_paged_attn from exllamav2.ext import exllamav2_ext as ext_c, none_tensor +from exllamav2.util import cuda_sync_active from concurrent.futures import ThreadPoolExecutor from exllamav2.compat import pairwise @@ -238,10 +239,10 @@ def __init__( model: ExLlamaV2, cache: ExLlamaV2CacheBase, tokenizer: ExLlamaV2Tokenizer, - max_batch_size: int = 16, + max_batch_size: int = None, max_seq_len: int | None = None, max_chunk_size: int | None = None, - max_q_size: int = 16, + max_q_size: int = 8, draft_model: ExLlamaV2 | None = None, draft_cache: ExLlamaV2CacheBase | None = None, num_draft_tokens: int = 4, @@ -267,7 +268,7 @@ def __init__( :param max_batch_size: The maximum number of sequences to process in parallel. The generator will also limit this - dynamically considering the available cache space. + dynamically considering the available cache space. Specify None to calculate automatically :param max_seq_len: Maximum length of each individual sequence. Defaults to the model's max_seq_len. 
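As the updated docstring above notes, `max_batch_size` may now be left as `None`, in which case the generator sizes it from the model config (`max_input_len // max_q_size` in the hunk that follows). A hedged usage sketch; `model`, `cache` and `tokenizer` are assumed to be already loaded:

from exllamav2.generator import ExLlamaV2DynamicGenerator

generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
    max_batch_size = None,   # sized automatically from max_input_len // max_q_size
    max_q_size = 8,          # new default; num_draft_tokens must not exceed this when drafting
)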
@@ -324,7 +325,13 @@ def __init__( self.draft_model = draft_model self.draft_cache = draft_cache - self.num_draft_tokens = num_draft_tokens if (draft_model or use_ngram_draft) else 0 + + if draft_model or use_ngram_draft: + assert num_draft_tokens <= max_q_size, \ + "num_draft_tokens cannot be larger than max_q_size." + self.num_draft_tokens = num_draft_tokens + else: + self.num_draft_tokens = 0 if draft_model: assert draft_cache is not None, \ @@ -343,12 +350,16 @@ def __init__( assert not isinstance(cache, ExLlamaV2Cache_8bit), \ "Dynamic generator does not currently work with 8-bit cache. Use either FP16 or Q4." - model_max_q = cfg.max_batch_size * cfg.max_input_len - req_max_q = max_q_size * max_batch_size - assert req_max_q <= model_max_q, \ - f"Model has max_batch_size * max_input_len = {cfg.max_batch_size} * {cfg.max_input_len} tokens, " + \ - f"generator requires max_batch_size * max_q_size = {max_batch_size} * {max_q_size} tokens." - self.max_batch_size = max_batch_size + if not max_batch_size: + max_batch_size = cfg.max_input_len // max_q_size + self.max_batch_size = max_batch_size + else: + model_max_q = cfg.max_batch_size * cfg.max_input_len + req_max_q = max_q_size * max_batch_size + assert req_max_q <= model_max_q, \ + f"Model has max_batch_size * max_input_len = {cfg.max_batch_size} * {cfg.max_input_len} tokens, " + \ + f"generator requires max_batch_size * max_q_size = {max_batch_size} * {max_q_size} tokens." + self.max_batch_size = max_batch_size if max_seq_len is not None: assert max_seq_len <= model.config.max_seq_len, \ @@ -1013,7 +1024,7 @@ def iterate_draftmodel_gen(self, results: list): for job in self.active_jobs: if not job.is_prefill_done(): continue if job.time_first_token is None: - torch.cuda.synchronize() + cuda_sync_active() job.time_first_token = time.time() job_ids = job.get_input_ids_list() input_ids_list += job_ids @@ -1091,7 +1102,7 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None): logit_mapping.append(len(input_ids_list)) if not job.is_prefill_done(): continue if job.time_first_token is None: - torch.cuda.synchronize() + cuda_sync_active() job.time_first_token = time.time() if draft_tokens is None: job_ids = job.get_input_ids_list(add_to_cache = True) @@ -1921,7 +1932,7 @@ def emit( self.held_k_tokens.append(next_k_tokens) self.held_k_probs.append(next_k_probs) if self.return_logits: - self.held_logits.append(logits) + self.held_logits.append(logits[:1, :, :]) # Stop if we reach max_new_tokens diff --git a/exllamav2/generator/dynamic_async.py b/exllamav2/generator/dynamic_async.py index d9e1fee7..8e006822 100644 --- a/exllamav2/generator/dynamic_async.py +++ b/exllamav2/generator/dynamic_async.py @@ -22,7 +22,9 @@ async def _run_iteration(self): try: while True: async with self.condition: - await self.condition.wait_for(lambda: len(self.jobs) > 0) + # Unlock if there's no jobs or if the parent task is cancelled + await self.condition.wait_for(lambda: len(self.jobs) > 0 or self.iteration_task.cancelled()) + results = self.generator.iterate() for result in results: job = result["job"] @@ -31,6 +33,9 @@ async def _run_iteration(self): if result["eos"]: del self.jobs[job] await asyncio.sleep(0) + except asyncio.CancelledError: + # Silently return on cancel + return except Exception as e: # If the generator throws an exception it won't pertain to any one ongoing job, so push it to all of them for async_job in self.jobs.values(): @@ -48,6 +53,9 @@ async def _notify_condition(self): async def close(self): 
self.iteration_task.cancel() + + # Force a re-check of the condition to unlock the loop + await self._notify_condition() try: await self.iteration_task except asyncio.CancelledError: diff --git a/exllamav2/layernorm.py b/exllamav2/layernorm.py index 39209e6f..112ee8f0 100644 --- a/exllamav2/layernorm.py +++ b/exllamav2/layernorm.py @@ -100,6 +100,7 @@ def forward(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, # TODO: **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: output_shape = hidden_states.shape @@ -126,6 +127,7 @@ def forward_torch(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, # TODO: **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: hidden_states = self.layernorm(hidden_states) diff --git a/exllamav2/linear.py b/exllamav2/linear.py index be505eee..9858b3f1 100644 --- a/exllamav2/linear.py +++ b/exllamav2/linear.py @@ -53,7 +53,8 @@ def __init__(self, f_key: str = None, f_beg: int = None, f_end: int = None, - is_sub_module: bool = True): + is_sub_module: bool = True, + altpack_qkv: bool = False): super().__init__(model, key) self.is_sub_module = is_sub_module @@ -85,6 +86,7 @@ def __init__(self, self.f_key = f_key self.f_beg = f_beg self.f_end = f_end + self.altpack_qkv = altpack_qkv self.assumed_footprint = in_features * (out_features + self.padding) * 2 + 128 @@ -94,7 +96,7 @@ def load(self, w: dict | nn.Parameter | tuple | None = None, device_tensors: bool = True): - if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features) + if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv) if w is None: w = self.load_weight() # Load quantized linear layer from dictionary @@ -239,6 +241,14 @@ def forward(self, # Linear forward + if self.key == 'lm_head' and loras is not None and loras[0].lm_head is not None: + hidden_states_out = loras[0].lm_head(hidden_states) + + if intermediates: + return {"hidden_states": hidden_states_out} + else: + return hidden_states_out + if self.q_handle is not None and not force_recons: output_shape = hidden_states.shape[:-1] + (self.out_features,) diff --git a/exllamav2/lora.py b/exllamav2/lora.py index c4b5ea68..1dd99cbe 100644 --- a/exllamav2/lora.py +++ b/exllamav2/lora.py @@ -53,6 +53,13 @@ def __init__(self, self.target_modules = {} self.bias_ignored = False self.lora_scaling = lora_scaling + self.embed_tokens = None + self.lm_head = None + + # Compatibility check + + assert not self.model.config.arch.residual_stream_fp32, \ + "LoRAs not (yet) supported for models with FP32 residual stream" # Grab relevant items from LoRA config @@ -77,6 +84,29 @@ def __init__(self, tensor = f[key] # Find target + if key.endswith(f'{self.config.arch.lm_head_key}.weight'): + if tensor.dtype == torch.bfloat16: + tensor = tensor.to(torch.float16) + elif tensor.dtype == torch.float32: + tensor = tensor.to(torch.float16) + target_module = self.model.modules_dict["lm_head"] + tensor = safe_move_tensor(tensor, target_module.device()) + self.lm_head = torch.nn.Linear(target_module.in_features, tensor.shape[0], bias = False, device = "meta") + self.lm_head.weight = torch.nn.Parameter(tensor, requires_grad=False) + continue + elif key.endswith(f'embed_tokens.weight'): + if tensor.dtype == torch.bfloat16: + tensor = tensor.to(torch.float16) + elif tensor.dtype == torch.float32: + tensor = tensor.to(torch.float16) + target_module = 
self.model.modules_dict["model.embed_tokens"] + tensor = safe_move_tensor(tensor, target_module.device()) + self.embed_tokens = torch.nn.Embedding(tensor.shape[0], self.config.hidden_size, self.config.pad_token_id, device = "meta") + weight = torch.nn.Parameter(tensor, requires_grad=False) + if self.model.config.scale_emb != 1: + weight *= self.model.config.scale_emb + self.embed_tokens.weight = weight + continue i = key.find("model.layers.") if i == -1: raise ValueError(f" ## Error: unsupported layer in {self.lora_path}: {key}") diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index 63e4bc2b..58a644aa 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -20,7 +20,8 @@ class ExLlamaV2MLP(ExLlamaV2Module): name: str = "MLP" layer_idx: int - post_attention_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + pre_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + post_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None gate_proj: ExLlamaV2Linear | None up_proj: ExLlamaV2Linear | None down_proj: ExLlamaV2Linear | None @@ -56,19 +57,24 @@ def __init__(self, if self.has_norm: if cfg.arch.norm == "layernorm": - self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2) + self.pre_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2) + self.post_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None elif cfg.arch.norm == "rmsnorm": - self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2) + self.pre_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2) + self.post_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2_post) if cfg.arch.norm_key_2_post else None else: - self.post_attention_layernorm = None + self.pre_layernorm = None + self.post_layernorm = None self.up_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_up, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c) self.down_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_down, cfg.intermediate_size, cfg.hidden_size, self.model.config.arch.mlp_bias, prescale = cfg.scale_depth) self.submodules = [self.up_proj, self.down_proj] - if self.has_norm: - self.submodules += [self.post_attention_layernorm] + if self.pre_layernorm: + self.submodules += [self.pre_layernorm] + if self.post_layernorm: + self.submodules += [self.post_layernorm] if cfg.arch.mlp_gate: self.gate_proj = ExLlamaV2Linear(model, key + cfg.arch.mlp_key_gate, cfg.hidden_size, cfg.intermediate_size, self.model.config.arch.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b) @@ -85,8 +91,10 @@ def numel(self) -> int: if self.model.config.arch.mlp_gate: numel += self.gate_proj.numel() - if self.post_attention_layernorm is not None: - numel += self.post_attention_layernorm.numel() + if self.pre_layernorm is not None: + numel += self.pre_layernorm.numel() + if self.post_layernorm is not None: + numel += self.post_layernorm.numel() return numel @@ -96,8 +104,10 @@ def load(self): cfg = self.model.config - if self.post_attention_layernorm is not None: - self.post_attention_layernorm.load() + if self.pre_layernorm is not None: + self.pre_layernorm.load() + if self.post_layernorm is not None: + self.post_layernorm.load() if cfg.checkpoint_fused_mlp: w12 = self.load_weight(self.key + cfg.arch.fused_mlp_key_12) @@ -119,16 +129,23 @@ device_tensors.begin_scratch_alloc() if self.has_norm: - norm_weight = self.post_attention_layernorm.weight if 
self.post_attention_layernorm.weight is not None else none_tensor - norm_bias = self.post_attention_layernorm.bias if self.post_attention_layernorm.bias is not None else none_tensor - is_rms = isinstance(self.post_attention_layernorm, ExLlamaV2RMSNorm) - eps = self.post_attention_layernorm.variance_epsilon + norm_weight = self.pre_layernorm.weight if self.pre_layernorm.weight is not None else none_tensor + norm_bias = self.pre_layernorm.bias if self.pre_layernorm.bias is not None else none_tensor + is_rms = isinstance(self.pre_layernorm, ExLlamaV2RMSNorm) + eps = self.pre_layernorm.variance_epsilon else: norm_weight = none_tensor norm_bias = none_tensor is_rms = False eps = 0 + if self.post_layernorm is not None: + post_norm_weight = self.post_layernorm.weight if self.post_layernorm.weight is not None else none_tensor + post_norm_bias = self.post_layernorm.bias if self.post_layernorm.bias is not None else none_tensor + else: + post_norm_weight = none_tensor + post_norm_bias = none_tensor + self.q_handle = ext_c.make_q_mlp(norm_weight, norm_bias, is_rms, @@ -142,7 +159,10 @@ def load(self): device_tensors.get_scratch_slice(self.temp_dq_size()), cfg.max_input_len * cfg.max_batch_size, cfg.arch.mlp_act_func == "gelu", - self.has_residual) + self.has_residual, + post_norm_weight, + post_norm_bias, + cfg.arch.residual_stream_fp32) def unload(self): @@ -151,7 +171,8 @@ def unload(self): ext_c.free_q_mlp(self.q_handle) self.q_handle = None - if self.post_attention_layernorm is not None: self.post_attention_layernorm.unload() + if self.pre_layernorm is not None: self.pre_layernorm.unload() + if self.post_layernorm is not None: self.post_layernorm.unload() if self.gate_proj is not None: self.gate_proj.unload() self.up_proj.unload() self.down_proj.unload() @@ -167,8 +188,10 @@ def weight_footprint(self) -> int: if self.gate_proj is not None: fp += self.gate_proj.weight_footprint() - if self.post_attention_layernorm is not None: - fp += self.post_attention_layernorm.weight_footprint() + if self.pre_layernorm is not None: + fp += self.pre_layernorm.weight_footprint() + if self.post_layernorm is not None: + fp += self.post_layernorm.weight_footprint() return fp @@ -219,8 +242,10 @@ def temp_dq_size(self) -> int: def set_device_idx(self, idx: int): super().set_device_idx(idx) - if self.post_attention_layernorm is not None: - self.post_attention_layernorm.set_device_idx(idx) + if self.pre_layernorm is not None: + self.pre_layernorm.set_device_idx(idx) + if self.post_layernorm is not None: + self.post_layernorm.set_device_idx(idx) if self.gate_proj is not None: self.gate_proj.set_device_idx(idx) self.up_proj.set_device_idx(idx) self.down_proj.set_device_idx(idx) @@ -236,6 +261,8 @@ def forward(self, loras: list[ExLlamaV2Lora] | None = None, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: + cfg = self.model.config + if self.q_handle is None or intermediates: return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs) @@ -251,6 +278,9 @@ def forward(self, pass_loras, pass_lora_temp) + if cfg.arch.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + return hidden_states @@ -266,15 +296,15 @@ def forward_torch(self, cfg = self.model.config residual = hidden_states - post_norm = self.post_attention_layernorm.forward(hidden_states) \ - if self.has_norm else hidden_states + post_norm = self.pre_layernorm.forward(hidden_states) \ + if self.pre_layernorm else hidden_states if self.gate_proj is not None: gate = self.gate_proj.forward(post_norm, loras = 
loras) if cfg.arch.mlp_act_func == "silu": y = F.silu(gate) elif cfg.arch.mlp_act_func == "gelu": - y = F.gelu(gate) + y = F.gelu(gate, approximate = "tanh") up = self.up_proj.forward(post_norm, loras = loras) y *= up y.clamp_(min = -65504.0, max = 65504.0) @@ -283,11 +313,18 @@ def forward_torch(self, if cfg.arch.mlp_act_func == "silu": y = F.silu(up) elif cfg.arch.mlp_act_func == "gelu": - y = F.gelu(up) + y = F.gelu(up, approximate = "tanh") down = self.down_proj.forward(y, loras = loras) + if self.post_layernorm: + down = self.post_layernorm.forward(down, output_fp32 = cfg.arch.residual_stream_fp32) hidden_states = down + residual if self.has_residual else down + if cfg.arch.residual_stream_fp32: + hidden_states = hidden_states.float() + elif cfg.arch.clamp_hidden_states: + hidden_states = hidden_states.clamp(-65504, 65504) + if intermediates: return {"post_norm": post_norm, "pre_down": y, diff --git a/exllamav2/model.py b/exllamav2/model.py index 7296b6e2..d21031df 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -53,6 +53,7 @@ # from exllamav2.util import list_live_tensors, print_vram_usage, set_snapshot, diff_snapshot, print_vram_usage_peak from exllamav2.util import get_basic_progress # from line_profiler import profile +from exllamav2.ext import exllamav2_ext as ext_c, none_tensor def _torch_device(idx): @@ -226,7 +227,13 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False): pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx) self.modules += [pd] else: - attn = ExLlamaV2Attention(self, layer_key, layer_idx) + if self.config.arch.alternating_swa: + swa = self.config.sliding_window if not bool(layer_idx % 2) else 0 + elif self.config.arch.swa: + swa = self.config.sliding_window + else: + swa = 0 + attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa) if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx) else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx) self.modules += [attn, mlp] @@ -417,7 +424,7 @@ def load_gen( def load_autosplit( self, cache: ExLlamaV2CacheBase, - reserve_vram: int | None = None, + reserve_vram: int | list[int] | None = None, last_id_only: bool = False, callback: Callable[[int, int], None] | None = None, callback_gen: Callable[[int, int], None] | None = None, @@ -443,7 +450,7 @@ def callback_pb(a, b): def load_autosplit_gen( self, cache: ExLlamaV2CacheBase, - reserve_vram: int | None = None, + reserve_vram: int | list[int] | None = None, last_id_only: bool = False, callback: Callable[[int, int], None] | None = None, callback_gen: Callable[[int, int], None] | None = None @@ -466,6 +473,8 @@ def load_autosplit_gen( if reserve_vram is None: reserve_vram = [192 * 1024**2] + [64 * 1024**2] * (num_devices - 1) + elif isinstance(reserve_vram, int): + reserve_vram = [reserve_vram] * num_devices reserved_vram_tensors = [] minimum_reserve_tensor = None @@ -672,6 +681,7 @@ def forward(self, return_last_state: bool = False, position_offsets: torch.Tensor | None = None, abort_event: threading.Event | None = None, + cpu_logits: bool = False, **kwargs) \ -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None: """ @@ -708,6 +718,11 @@ def forward(self, :param abort_event: Optional event that, if set, will abort the forward pass. Function will return None if aborted. + :param cpu_logits: + If True, logits are collected and returned in system RAM. 
This is somewhat slower but can prevent + out-of-memory errors when computing logits for all positions in a long sequence, such as during a + perplexity test. + :return: FP16 logits tensor, shape (batch_size, q_len, vocab_size) (optional) state tensor, shape (batch_size, q_len, hidden_size) @@ -810,6 +825,8 @@ def forward(self, if abort_event and abort_event.is_set(): return if not _preprocess_only: + if cpu_logits: + r["logits"] = r["logits"].cpu() result = r["logits"] if result is None else torch.cat((result, r["logits"]), dim = 1) chunk_begin = chunk_end @@ -915,6 +932,9 @@ def forward_chunk(self, # if x is not None and self.config.logit_scale != 1: # x.mul_(self.config.logit_scale) + if x is not None and self.config.final_logit_softcapping: + ext_c.softcap_(x, self.config.final_logit_softcapping) + # Set padding logits to -inf if x is not None: diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py index 905c3bc0..17730ce2 100644 --- a/exllamav2/model_init.py +++ b/exllamav2/model_init.py @@ -21,6 +21,9 @@ def add_args(parser): parser.add_argument("-ept", "--experts_per_token", type = int, help = "Override MoE model's default number of experts per token") parser.add_argument("-lq4", "--load_q4", action = "store_true", help = "Load weights in Q4 mode") parser.add_argument("-fst", "--fast_safetensors", action = "store_true", help = "Use alternative safetensors loader (with direct I/O when available)") + parser.add_argument("-ic", "--ignore_compatibility", action = "store_true", help = "Do not override model config options in case of compatibility issues") + parser.add_argument("-chunk", "--chunk_size", type = int, help = "Chunk size ('input length')") + def print_options(args): @@ -39,6 +42,8 @@ def print_options(args): if hasattr(args, "fast_safetensors") and args.fast_safetensors: print_opts += ["fast_safetensors"] if args.experts_per_token is not None: print_opts += [f"experts_per_token: {args.experts_per_token}"] if args.load_q4: print_opts += ["load_q4"] + if args.ignore_compatibility: print_opts += ["ignore_compatibility"] + if args.chunk_size is not None: print_opts += [f"chunk_size: {args.chunk_size}"] print(f" -- Options: {print_opts}") @@ -105,6 +110,14 @@ def init(args, if args.low_mem: config.set_low_mem() if args.load_q4: config.load_in_q4 = True + if args.chunk_size is not None: + config.max_input_len = args.chunk_size + config.max_attention_size = args.chunk_size ** 2 + + # Compatibility warnings + + config.arch_compat_overrides(warn_only = args.ignore_compatibility) + # Load model # If --gpu_split auto, return unloaded model. 
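The new `cpu_logits` flag documented above moves each forward chunk's logits to system RAM rather than accumulating them in VRAM. A minimal sketch, assuming a loaded `model` and `cache` and a long tokenized sequence `input_ids`; the 2048-token threshold mirrors the heuristic used later in this patch in test_inference.py:

logits = model.forward(
    input_ids,                               # shape (1, seq_len)
    cache = cache,
    cpu_logits = input_ids.numel() > 2048,   # keep full-sequence logits on the CPU for long inputs
)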
Model must be loaded with model.load_autosplit() supplying cache # created in lazy mode diff --git a/exllamav2/module.py b/exllamav2/module.py index 1c5863dd..bbd393cf 100644 --- a/exllamav2/module.py +++ b/exllamav2/module.py @@ -126,13 +126,13 @@ def load_weight(self, bias = tensors["bias"].half() if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2: tensor = tensor.T - return nn.Parameter(tensor), nn.Parameter(bias) + return nn.Parameter(tensor, requires_grad = False), nn.Parameter(bias, requires_grad = False) else: tensors = self.load_multi(key, ["weight"]) tensor = tensors["weight"].half() # if self.model.config.arch.orig_weights_transposed: # tensor = tensor.T - return nn.Parameter(tensor) + return nn.Parameter(tensor, requires_grad = False) # No weights found for key @@ -144,27 +144,43 @@ def load_weight_fused(self, f_beg: int, f_end: int, in_feat: int, - out_feat: int): + out_feat: int, + altpack_qkv: bool): res = [] for key in [f_key, f_key + ".weight", f_key + ".bias"]: - filename = self.model.config.tensor_file_map.get(key) + cfg = self.model.config + filename = cfg.tensor_file_map.get(key) if not filename: continue - stfile = STFile.open(filename, fast = self.model.config.fasttensors, keymap = self.model.config.arch.keymap) + stfile = STFile.open(filename, fast = cfg.fasttensors, keymap = cfg.arch.keymap) # tensor = stfile.get_tensor(key, device = self.device()).half() tensor = stfile.get_tensor(key, device = "cpu", cached = True, out_dtype = torch.half) - if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2: + + if cfg.arch.orig_weights_transposed and len(tensor.shape) == 2: tensor = tensor.T + + if altpack_qkv: + ts = tensor.shape + h, gs, d = cfg.num_key_value_heads, cfg.num_key_value_groups + 2, cfg.head_dim + tensor = tensor.view(h, gs, d, -1).transpose(0, 1).reshape(ts) + tensor = tensor[f_beg:f_end] + + if altpack_qkv: + ts = tensor.shape + h, gs, d = cfg.num_key_value_heads, (f_end - f_beg) // cfg.num_key_value_heads // cfg.head_dim, cfg.head_dim + tensor = tensor.view(gs, h, d, -1).transpose(0, 1).reshape(ts) + if not key.endswith(".bias"): if in_feat != out_feat and \ tensor.shape[1] == out_feat and \ tensor.shape[0] == in_feat: tensor = tensor.T + tensor = tensor.contiguous().to(self.device()) - res.append(nn.Parameter(tensor)) + res.append(nn.Parameter(tensor, requires_grad = False)) if len(res) == 2: return res[0], res[1] if len(res) == 1: return res[0] diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py index 68518c5d..77d13a48 100644 --- a/exllamav2/rmsnorm.py +++ b/exllamav2/rmsnorm.py @@ -97,11 +97,17 @@ def forward(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: output_shape = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - norm = torch.empty_like(hidden_states) + + if not output_fp32: + norm = torch.empty_like(hidden_states, dtype = torch.half) + else: + norm = torch.empty_like(hidden_states, dtype = torch.float) + ext_c.rms_norm(hidden_states, self.weight, norm, self.variance_epsilon) hidden_states = norm.view(output_shape) @@ -118,14 +124,17 @@ def forward_torch(self, past_len = None, intermediates: bool = False, loras = None, + output_fp32 = False, **kwargs) -> torch.Tensor | dict[str: torch.Tensor]: - hidden_states[hidden_states == -float('inf')] = -65504.0 - hidden_states[hidden_states == float('inf')] = 65504.0 + # hidden_states.clamp_(-65504.0, 65504.0) variance = 
hidden_states.to(torch.float32).pow(2).mean(-1, keepdim = True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = hidden_states.to(self.weight.dtype) + + if not output_fp32: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states *= self.weight if intermediates: diff --git a/exllamav2/server/websocket_actions.py b/exllamav2/server/websocket_actions.py index 12a6aa07..f84f2853 100644 --- a/exllamav2/server/websocket_actions.py +++ b/exllamav2/server/websocket_actions.py @@ -87,11 +87,11 @@ def lefttrim_token(request, ws, server, response): text = request["text"] length = int(request["trimmed_length"]) - ids = server.tokenizer.cached_encode_str(text) + ids = server.tokenizer.encode(text, encode_special_tokens = True) if ids.shape[-1] <= length: response["trimmed_text"] = text else: - response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:])[0] + response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:], decode_special_tokens = True)[0] async def infer(request, ws, server, response): diff --git a/exllamav2/tokenizer/tokenizer.py b/exllamav2/tokenizer/tokenizer.py index fbb458cc..1b9f4396 100644 --- a/exllamav2/tokenizer/tokenizer.py +++ b/exllamav2/tokenizer/tokenizer.py @@ -137,6 +137,17 @@ def __init__(self, config, lazy_init = False, force_json = False): with open(added_tokens_path, encoding = "utf8") as f: self.extended_piece_to_id.update(json.load(f)) + # Add special tokens from tokenizer_config.json + + if self.tokenizer_config_dict and "added_tokens_decoder" in self.tokenizer_config_dict: + atd = self.tokenizer_config_dict["added_tokens_decoder"] + for (k, v) in atd.items(): + if not v["special"]: + continue + token_id = int(k) + token_str = v["content"] + self.extended_piece_to_id[token_str] = token_id + # Remove unspecial added tokens that exist in the base tokenizer already, but only if they decode correctly # see https://github.com/huggingface/tokenizers/issues/1392 diff --git a/exllamav2/util.py b/exllamav2/util.py index d0387229..170241de 100644 --- a/exllamav2/util.py +++ b/exllamav2/util.py @@ -110,6 +110,17 @@ def torch_slice(self, a: int | None, b: int | None): return s +def cuda_sync_active(): + """ + Calling torch.cuda.synchronize() will create a CUDA context on CUDA:0 even if that device is not being used. + This function synchronizes only devices actively used by Torch in the current process. 
+ """ + for device_id in range(torch.cuda.device_count()): + device = torch.device(f'cuda:{device_id}') + if torch.cuda.memory_allocated(device) > 0: + torch.cuda.synchronize(device) + + def get_basic_progress(): progress = Progress( TextColumn("[progress.description]{task.description}"), diff --git a/exllamav2/version.py b/exllamav2/version.py index 32efefd0..283b03a0 100644 --- a/exllamav2/version.py +++ b/exllamav2/version.py @@ -1 +1 @@ -__version__ = "0.1.6" \ No newline at end of file +__version__ = "0.1.7" \ No newline at end of file diff --git a/model_diff.py b/model_diff.py index 9eeae0c7..5238477e 100644 --- a/model_diff.py +++ b/model_diff.py @@ -14,7 +14,7 @@ import pandas, fastparquet import torch import torch.nn.functional as F -from conversion.tokenize import get_tokens +from exllamav2.conversion.tokenize import get_tokens from exllamav2.util import list_live_tensors import gc @@ -47,6 +47,8 @@ config[1].prepare() config[0].max_batch_size = 1 config[1].max_batch_size = 1 +config[0].arch_compat_overrides() +config[1].arch_compat_overrides() model = (ExLlamaV2(config[0]), ExLlamaV2(config[1])) model[0].load(lazy = True) diff --git a/setup.py b/setup.py index 44985fea..b3d13a42 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ "exllamav2/exllamav2_ext/ext_rope.cpp", "exllamav2/exllamav2_ext/ext_safetensors.cpp", "exllamav2/exllamav2_ext/ext_sampling.cpp", + "exllamav2/exllamav2_ext/ext_element.cpp", "exllamav2/exllamav2_ext/cuda/h_add.cu", "exllamav2/exllamav2_ext/cuda/h_gemm.cu", "exllamav2/exllamav2_ext/cuda/lora.cu", @@ -58,6 +59,7 @@ "exllamav2/exllamav2_ext/cuda/rope.cu", "exllamav2/exllamav2_ext/cuda/cache.cu", "exllamav2/exllamav2_ext/cuda/util.cu", + "exllamav2/exllamav2_ext/cuda/softcap.cu", "exllamav2/exllamav2_ext/cuda/comp_units/kernel_select.cu", "exllamav2/exllamav2_ext/cuda/comp_units/unit_gptq_1.cu", "exllamav2/exllamav2_ext/cuda/comp_units/unit_gptq_2.cu", diff --git a/test_inference.py b/test_inference.py index cd884d5b..2c6f67ce 100644 --- a/test_inference.py +++ b/test_inference.py @@ -50,6 +50,7 @@ parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache") parser.add_argument("-eq6", "--eval_token_q6", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q6 cache") parser.add_argument("-eq8", "--eval_token_q8", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q8 cache") +parser.add_argument("-ecl", "--eval_context_lens", action = "store_true", help = "Evaluate perplexity at range of context lengths") # parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)") parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)") parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt") @@ -83,6 +84,9 @@ if args.gpu_split: print(" ## Can only use one GPU when streaming layers") sys.exit() + if args.eval_context_lens and args.stream_layers: + print(" ## eval_context_lens not compatible with stream_layers") + sys.exit() if args.eval_dataset: if args.length and args.eval_length != args.length: print(" !! 
Overriding model context length to match eval row length") @@ -279,13 +283,24 @@ boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long) eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1) - logprob_sum = 0.0 - logprob_count = 0 + if args.eval_context_lens: + logprob_sum = [] + logprob_count = [] + else: + logprob_sum = 0.0 + logprob_count = 0 + + def ppl(input_ids__, logits__, lengths__, bins = False): - def ppl(input_ids__, logits__, lengths__): + logits_device = model.modules[-1].device() - logprob_sum_ = 0.0 - logprob_count_ = 0 + if bins: + num_bins = (max(lengths__) + 255) // 256 + logprob_sum_ = [0.0] * num_bins + logprob_count_ = [0] * num_bins + else: + logprob_sum_ = 0.0 + logprob_count_ = 0 assert logits__.shape[0] == input_ids__.shape[0] ll = logits__.shape[1] @@ -295,19 +310,28 @@ def ppl(input_ids__, logits__, lengths__): logits_ = logits__[bi:bi+1, cl:, :] input_ids_ = input_ids__[bi:bi+1, cl:] - chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1 + if bins: + chunksize = 256 + else: + chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1 b_ = 0 while b_ < logits_.shape[1]: a_ = b_ b_ = min(b_ + chunksize, logits_.shape[1]) - logits_f = logits_[:, a_:b_, :].float() + 1e-10 - target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device) + logits_f = logits_[:, a_:b_, :].to(logits_device).float() + 1e-10 + target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_f.device) log_probs = F.log_softmax(logits_f, dim=-1) token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) - logprob_sum_ += token_log_probs.sum().item() - logprob_count_ += target_ids.numel() + if bins: + # for cbin in range(a_ // 256 + 1): + cbin = a_ // 256 + logprob_sum_[cbin] += token_log_probs.sum().item() + logprob_count_[cbin] += target_ids.numel() + else: + logprob_sum_ += token_log_probs.sum().item() + logprob_count_ += target_ids.numel() return logprob_sum_, logprob_count_ @@ -376,18 +400,33 @@ def ppl(input_ids__, logits__, lengths__): input_ids = input_ids[:, :] if cache is not None: cache.current_seq_len = 0 - logits = model.forward(input_ids, cache) + logits = model.forward(input_ids, cache, cpu_logits = input_ids.numel() > 2048) logits = logits[:, :-1, :] - logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1]) - logprob_sum += logprob_sum__ - logprob_count += logprob_count__ - - print() - - mean_log_prob = logprob_sum / logprob_count - perplexity = math.exp(-mean_log_prob) - print(f" -- Evaluation perplexity: {perplexity:.4f}") + logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1], args.eval_context_lens) + if args.eval_context_lens: + while len(logprob_sum) < len(logprob_sum__): + logprob_sum.append(0.0) + logprob_count.append(0) + for j in range(len(logprob_sum__)): + logprob_sum[j] += logprob_sum__[j] + logprob_count[j] += logprob_count__[j] + else: + logprob_sum += logprob_sum__ + logprob_count += logprob_count__ + + if not args.eval_context_lens: + print() + mean_log_prob = logprob_sum / logprob_count + perplexity = math.exp(-mean_log_prob) + print(f" -- Evaluation perplexity: {perplexity:.4f}") + else: + print() + for j in range(len(logprob_sum__)): + mean_log_prob = logprob_sum[j] / logprob_count[j] + perplexity = math.exp(-mean_log_prob) + dl = min((j + 1) * 256, eval_length) + print(f" -- Evaluation perplexity: {dl} {perplexity:.4f}") def test_ppl_token(): global logprob_sum, logprob_count, i, input_ids
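The `--eval_context_lens` path above accumulates log-probabilities into 256-token bins so that perplexity can be reported as a function of context length. A small sketch of the reporting step, assuming per-bin `logprob_sum` and `logprob_count` lists like those built above; the helper name is hypothetical:

import math

def binned_perplexity(logprob_sum, logprob_count, eval_length, bin_size = 256):
    # one (context_length, perplexity) pair per bin, matching the print loop above
    return [
        (min((j + 1) * bin_size, eval_length), math.exp(-s / c))
        for j, (s, c) in enumerate(zip(logprob_sum, logprob_count))
    ]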