layernorm.py

from __future__ import annotations

import torch
from torch import nn
from exllamav2.module import ExLlamaV2Module
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from exllamav2.model import ExLlamaV2


# LayerNorm module: applies layer normalization via a fused extension kernel
# (forward) or torch.nn.LayerNorm (forward_torch)
class ExLlamaV2LayerNorm(ExLlamaV2Module):

    name: str = "LayerNorm"

    layernorm: nn.LayerNorm | None
    weight: nn.Parameter | None
    bias: nn.Parameter | None
    variance_epsilon: float


    def __init__(self,
                 model: ExLlamaV2,
                 key: str):
        super().__init__(model, key)

        self.layernorm = None
        self.weight = None
        self.bias = None
        self.variance_epsilon = 1e-6


    def load(self):

        # Weights may load as a bare weight tensor or as a (weight, bias) tuple
        w = self.load_weight()

        if isinstance(w, tuple):
            weight = w[0]
            bias = w[1]
        else:
            weight = w
            bias = None

        assert isinstance(weight, nn.Parameter)
        assert bias is None or isinstance(bias, nn.Parameter)

        self.layernorm = nn.LayerNorm(self.model.config.hidden_size,
                                      elementwise_affine = True,
                                      bias = bias is not None)
        self.layernorm.weight = weight
        self.weight = weight

        if bias is not None:
            self.layernorm.bias = bias
            self.bias = bias

        self.variance_epsilon = self.model.config.norm_eps


    def numel(self):

        return 0
        # return self.layernorm.weight.data.numel()


    def unload(self):

        self.layernorm = None
        self.weight = None
        self.bias = None


    def get_weight(self) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

        if self.bias is not None: return self.weight, self.bias
        return self.weight


    def weight_footprint(self) -> int:

        hidden_size = self.model.config.hidden_size
        return hidden_size * 2


    def scratch_space_fixed(self) -> int:

        return 0


    def scratch_space(self) -> int:

        return 0


    def forward(self,
                hidden_states: torch.Tensor,
                cache = None,
                attn_params = None,
                past_len = None,
                intermediates: bool = False,
                loras = None,
                **kwargs) -> torch.Tensor | dict[str, torch.Tensor]:

        # Flatten to (tokens, hidden_size), normalize with the fused extension
        # kernel, then restore the original shape
        output_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        norm = torch.empty_like(hidden_states)
        ext_c.layer_norm(hidden_states,
                         self.weight.data,
                         self.bias.data if self.bias is not None else none_tensor,
                         norm,
                         self.variance_epsilon)
        hidden_states = norm.view(output_shape)

        if intermediates:
            return {"hidden_states": hidden_states}
        else:
            return hidden_states


    def forward_torch(self,
                      hidden_states: torch.Tensor,
                      cache = None,
                      attn_params = None,
                      past_len = None,
                      intermediates: bool = False,
                      loras = None,
                      **kwargs) -> torch.Tensor | dict[str, torch.Tensor]:

        # Reference path using torch's built-in nn.LayerNorm
        hidden_states = self.layernorm(hidden_states)

        if intermediates:
            return {"hidden_states": hidden_states}
        else:
            return hidden_states
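

# Editorial sketch (not part of the exllamav2 API): a plain-PyTorch reference of
# what the fused ext_c.layer_norm kernel above is expected to compute, assuming
# standard LayerNorm semantics over the last dimension, i.e. the same math as
# forward_torch. The name layer_norm_reference is hypothetical.
def layer_norm_reference(x: torch.Tensor,
                         weight: torch.Tensor,
                         bias: torch.Tensor | None = None,
                         eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension with biased variance, then apply the
    # learned elementwise affine transform
    mean = x.mean(dim = -1, keepdim = True)
    var = ((x - mean) ** 2).mean(dim = -1, keepdim = True)
    y = (x - mean) / torch.sqrt(var + eps) * weight
    if bias is not None:
        y = y + bias
    return y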