OlivierDehaene committed
Commit f0792b2 · 1 Parent(s): 02857b0
Files changed (2):
  1. model.safetensors +2 -2
  2. modeling_gpt2_mq.py +15 -25
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ba58d7bbc20355cd3083e789a88fa6b9016ec36ffaf113e94df03d1449ecadf6
-size 4903283827
+oid sha256:ecb114682d0f35efe851ed162500e9d7babdf7c8c008fdcf76ec679e9a788533
+size 4903278480
modeling_gpt2_mq.py CHANGED
@@ -13,7 +13,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
 from transformers.utils import logging
-
 from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY
 
 logger = logging.get_logger(__name__)
@@ -130,10 +129,7 @@ class GPT2MQAttention(nn.Module):
         if self.is_cross_attention:
             raise NotImplementedError("Cross-attention not implemented for MQA")
         else:
-            # self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-            self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
-            # Keys and values are shared across heads
-            self.kv_attn = nn.Linear(self.embed_dim, 2 * self.head_dim)
+            self.attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.head_dim)
         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
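
The four removed lines amount to replacing the separate q_attn / kv_attn projections with one fused linear layer of width embed_dim + 2 * head_dim (full-width queries, one shared key/value head). A minimal sketch of the idea, with illustrative sizes that are assumptions rather than values read from this checkpoint's config:

import torch
import torch.nn as nn

# Assumed sizes for illustration only
embed_dim, num_heads = 2048, 16
head_dim = embed_dim // num_heads  # 128

# Multi-query attention: per-head queries, a single shared key/value head,
# so the fused projection is embed_dim + 2 * head_dim wide instead of 3 * embed_dim.
attn = nn.Linear(embed_dim, embed_dim + 2 * head_dim)

hidden_states = torch.randn(2, 10, embed_dim)  # (batch, seq, embed_dim)
qkv = attn(hidden_states)
query, key, value = qkv.split([embed_dim, head_dim, head_dim], dim=2)
print(query.shape, key.shape, value.shape)
# torch.Size([2, 10, 2048]) torch.Size([2, 10, 128]) torch.Size([2, 10, 128])
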
@@ -143,13 +139,13 @@ class GPT2MQAttention(nn.Module):
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
 
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        # query: (b, num_heads * sq, head_dim)
+        # query: (b, sq * num_heads, head_dim)
         # key: (b, head_dim, sk)
         # value: (b, sk, head_dim)
         batch_size = query.size(0)
         query_length = query.size(1) // self.num_heads
         key_length = key.size(2)
-        # (b, num_heads * sq, head_dim) x (b, head_dim, sk) -> (b, num_heads * sq, sk)
+        # (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)
 
         if self.scale_attn_weights:
             query *= self.inv_norm_factor
@@ -157,7 +153,7 @@ class GPT2MQAttention(nn.Module):
         attn_weights = torch.bmm(query, key)
 
         # -> (b, num_heads, sq, sk)
-        attn_weights = attn_weights.view(batch_size, self.num_heads, query_length, key_length)
+        attn_weights = attn_weights.view(batch_size, query_length, self.num_heads, key_length)
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -174,13 +170,13 @@ class GPT2MQAttention(nn.Module):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+            raise NotImplementedError
 
         # (b, num_heads, sq, sk) -> (b, num_heads * sq, sk)
-        _attn_weights = attn_weights.view(batch_size, self.num_heads * query_length, key_length)
+        _attn_weights = attn_weights.view(batch_size, query_length * self.num_heads, key_length)
         # (b, num_heads * sq, sk) x (b, sk, head_dim) -> (b, num_heads * sq, head_dim)
         attn_output = torch.bmm(_attn_weights, value)
-        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+        attn_output = attn_output.view(batch_size, query_length, self.num_heads, self.head_dim)
 
         return attn_output, attn_weights
 
@@ -188,10 +184,8 @@ class GPT2MQAttention(nn.Module):
         """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
-        batch_size, num_heads, seq_length, head_dim = tensor.shape
-
-        tensor = tensor.permute(0, 2, 1, 3)
-        return tensor.reshape(batch_size, seq_length, num_heads * head_dim)
+        batch_size, seq_length, num_heads, head_dim = tensor.shape
+        return tensor.view(batch_size, seq_length, num_heads * head_dim)
 
     def forward(
         self,
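
Taken together, the _attn and _merge_heads changes keep the query in a (batch, seq * num_heads, head_dim) layout so that a single torch.bmm against the shared key and value covers every head, and the output can be merged back with a plain view. A shape-only sketch under the same assumed sizes as above:

import torch

batch_size, num_heads, head_dim = 2, 16, 128
sq, sk = 10, 12  # query and key lengths

query = torch.randn(batch_size, sq * num_heads, head_dim)
key = torch.randn(batch_size, head_dim, sk)    # shared across heads
value = torch.randn(batch_size, sk, head_dim)  # shared across heads

attn_weights = torch.bmm(query, key)                             # (b, sq * num_heads, sk)
attn_weights = attn_weights.view(batch_size, sq, num_heads, sk)  # per-head scores for masking/softmax
# ... scaling, masking, softmax and dropout would happen here ...
_attn_weights = attn_weights.view(batch_size, sq * num_heads, sk)
attn_output = torch.bmm(_attn_weights, value)                    # (b, sq * num_heads, head_dim)
attn_output = attn_output.view(batch_size, sq, num_heads, head_dim)

# _merge_heads no longer needs a permute: heads already sit next to head_dim
merged = attn_output.view(batch_size, sq, num_heads * head_dim)
print(merged.shape)  # torch.Size([2, 10, 2048])
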
@@ -207,17 +201,14 @@ class GPT2MQAttention(nn.Module):
         if encoder_hidden_states is not None:
             raise NotImplementedError("Cross-attention not implemented for MQA")
         else:
-            query = self.q_attn(hidden_states)
-            key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)
+            qkv = self.attn(hidden_states)
+            query, key, value = qkv.split([self.embed_dim, self.head_dim, self.head_dim], dim=2)
 
         batch_size, seq_length = query.shape[:2]
-        # (query_length, batch, num_heads, head_dim)
-        # (batch, num_heads * query_length, head_dim)\
 
-        # (batch, query_length, hidden_size) -> (batch, num_heads, query_length, head_dim)
-        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
-        # -> (batch, num_heads * query_length, head_dim)
-        query = query.reshape(batch_size, self.num_heads * seq_length, self.head_dim)
+        # (batch, query_length, hidden_size) -> (batch, query_length * num_heads, head_dim)
+        # forced to reshape here
+        query = query.reshape(batch_size, seq_length * self.num_heads, self.head_dim)
 
         key = key.transpose(1, 2)  # (batch_size, head_dim, seq_length)
 
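
The "forced to reshape here" comment in the new forward is there because query is a narrow slice of the fused qkv output and therefore not contiguous, so .view would raise while .reshape falls back to a copy. A small sketch of that behaviour (same assumed sizes):

import torch

batch_size, seq_length, num_heads, head_dim = 2, 10, 16, 128
embed_dim = num_heads * head_dim

qkv = torch.randn(batch_size, seq_length, embed_dim + 2 * head_dim)
query, key, value = qkv.split([embed_dim, head_dim, head_dim], dim=2)

print(query.is_contiguous())  # False: a strided slice of qkv
try:
    query.view(batch_size, seq_length * num_heads, head_dim)
except RuntimeError as err:
    print("view fails:", err)

# reshape copies when it has to, so it always succeeds here
query = query.reshape(batch_size, seq_length * num_heads, head_dim)
print(query.shape)  # torch.Size([2, 160, 128])
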
@@ -360,8 +351,7 @@ class GPT2CustomModel(GPT2Model):
             past_key_values_length=past_key_values_length,
         )
 
-        attention_mask = attention_mask.unsqueeze(1).expand(batch_size, self.config.num_attention_heads,
-                                                            *attention_mask.shape[1:])
+        attention_mask = attention_mask.unsqueeze(2).expand(batch_size, attention_mask.shape[1], self.config.num_attention_heads, attention_mask.shape[2])
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
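
Since _attn now views the scores as (batch, sq, num_heads, sk), the causal mask is expanded with the head dimension in position 2 instead of position 1. A shape-only sketch with a dummy mask:

import torch

batch_size, num_heads, sq, sk = 2, 16, 10, 12

# Stand-in for the 3D mask produced earlier in the model: (batch, sq, sk)
attention_mask = torch.zeros(batch_size, sq, sk)

old = attention_mask.unsqueeze(1).expand(batch_size, num_heads, sq, sk)  # (b, num_heads, sq, sk)
new = attention_mask.unsqueeze(2).expand(batch_size, sq, num_heads, sk)  # (b, sq, num_heads, sk)
print(old.shape, new.shape)
# torch.Size([2, 16, 10, 12]) torch.Size([2, 10, 16, 12])
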
 
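
Because modeling_gpt2_mq.py ships with the checkpoint as remote code, a commit like this is normally consumed through from_pretrained with trust_remote_code=True. The repository id below is a placeholder, not something stated on this page:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/gpt2-mqa-model"  # placeholder repo id, replace with the real one

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    revision="f0792b2",      # pin to this commit (use the full hash if the short one is rejected)
    trust_remote_code=True,  # required for the custom GPT2CustomConfig / GPT2MQAttention code
)

inputs = tokenizer("def hello_world():", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))
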