OlivierDehaene committed
Commit · f0792b2
1 Parent(s): 02857b0
fuse qkv
- model.safetensors +2 -2
- modeling_gpt2_mq.py +15 -25
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ecb114682d0f35efe851ed162500e9d7babdf7c8c008fdcf76ec679e9a788533
+size 4903278480
modeling_gpt2_mq.py
CHANGED
@@ -13,7 +13,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
 from transformers.utils import logging
-
 from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY

 logger = logging.get_logger(__name__)
@@ -130,10 +129,7 @@ class GPT2MQAttention(nn.Module):
         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention not implemented for MQA")
         else:
-
-            self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
-            # Keys and values are shared across heads
-            self.kv_attn = nn.Linear(self.embed_dim, 2 * self.head_dim)
+            self.attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)
         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -143,13 +139,13 @@ class GPT2MQAttention(nn.Module):
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        # query: (b, num_heads * sq, head_dim)
+        # query: (b, sq * num_heads, head_dim)
         # key: (b, head_dim, sk)
         # value: (b, sk, head_dim)
         batch_size = query.size(0)
         query_length = query.size(1) // self.num_heads
         key_length = key.size(2)
-        # (b, num_heads * sq, head_dim) x (b, head_dim, sk) -> (b, num_heads * sq, sk)
+        # (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)

         if self.scale_attn_weights:
             query *= self.inv_norm_factor
@@ -157,7 +153,7 @@ class GPT2MQAttention(nn.Module):
         attn_weights = torch.bmm(query, key)

         # -> (b, num_heads, sq, sk)
-        attn_weights = attn_weights.view(batch_size, self.num_heads, query_length, key_length)
+        attn_weights = attn_weights.view(batch_size, query_length, self.num_heads, key_length)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -174,13 +170,13 @@ class GPT2MQAttention(nn.Module):

         # Mask heads if we want to
         if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+            raise NotImplementedError

         # (b, num_heads, sq, sk) -> (b, num_heads * sq, sk)
-        _attn_weights = attn_weights.view(batch_size, self.num_heads * query_length, key_length)
+        _attn_weights = attn_weights.view(batch_size, query_length * self.num_heads, key_length)
         # (b, num_heads * sq, sk) x (b, sk, head_dim) -> (b, num_heads * sq, head_dim)
         attn_output = torch.bmm(_attn_weights, value)
-        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+        attn_output = attn_output.view(batch_size, query_length, self.num_heads, self.head_dim)

         return attn_output, attn_weights

@@ -188,10 +184,8 @@ class GPT2MQAttention(nn.Module):
         """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
-        batch_size, num_heads, seq_length, head_dim = tensor.shape
-
-        tensor = tensor.permute(0, 2, 1, 3)
-        return tensor.reshape(batch_size, seq_length, num_heads * head_dim)
+        batch_size, seq_length, num_heads, head_dim = tensor.shape
+        return tensor.view(batch_size, seq_length, num_heads * head_dim)

     def forward(
         self,
@@ -207,17 +201,14 @@ class GPT2MQAttention(nn.Module):
         if encoder_hidden_states is not None:
             raise NotImplementedError("Cross-attention not implemented for MQA")
         else:
-            query = self.q_attn(hidden_states)
-            key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)
+            qkv = self.attn(hidden_states)
+            query, key, value = qkv.split([self.embed_dim, self.head_dim, self.head_dim], dim=2)

         batch_size, seq_length = query.shape[:2]
-        # (query_length, batch, num_heads, head_dim)
-        # (batch, num_heads * query_length, head_dim)\

-        # (batch, query_length, hidden_size) -> (batch, num_heads, query_length, head_dim)
-
-
-        query = query.reshape(batch_size, self.num_heads * seq_length, self.head_dim)
+        # (batch, query_length, hidden_size) -> (batch, query_length * num_heads, head_dim)
+        # forced to reshape here
+        query = query.reshape(batch_size, seq_length * self.num_heads, self.head_dim)

         key = key.transpose(1, 2)  # (batch_size, head_dim, seq_length)

@@ -360,8 +351,7 @@ class GPT2CustomModel(GPT2Model):
             past_key_values_length=past_key_values_length,
         )

-        attention_mask = attention_mask.unsqueeze(
-            *attention_mask.shape[1:])
+        attention_mask = attention_mask.unsqueeze(2).expand(batch_size, attention_mask.shape[1], self.config.num_attention_heads, attention_mask.shape[2])

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]