ybelkada committed
Commit c289770 · 1 Parent(s): faed813

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py +5 -7
modeling_chatglm.py CHANGED
@@ -220,11 +220,10 @@ class CoreAttention(torch.nn.Module):
 
     def forward(self, query_layer, key_layer, value_layer, attention_mask):
         pytorch_major_version = int(torch.__version__.split('.')[0])
-        if False:
+        if pytorch_major_version >= 2:
             query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
             if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 is_causal=True)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=True)
             else:
                 if attention_mask is not None:
                     attention_mask = ~attention_mask
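The hunk above replaces the dead `if False:` branch with a real version gate, so the fused PyTorch 2 kernel is actually used. Below is a minimal sketch of that path with hypothetical tensor sizes (sq = sequence length, b = batch, np = heads, hn = head dim); only the permute and the `scaled_dot_product_attention` call are taken from the patch itself, everything else is illustrative.

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only
sq, b, np_, hn = 8, 2, 4, 16
query_layer = torch.randn(sq, b, np_, hn)
key_layer = torch.randn(sq, b, np_, hn)
value_layer = torch.randn(sq, b, np_, hn)

pytorch_major_version = int(torch.__version__.split('.')[0])
if pytorch_major_version >= 2:
    # Same permute as the patched forward: [sq, b, np, hn] -> [b, np, sq, hn]
    query_layer, key_layer, value_layer = [
        k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
    ]
    # Fused attention with a causal mask, as in the hunk above
    context_layer = F.scaled_dot_product_attention(
        query_layer, key_layer, value_layer, is_causal=True
    )
    print(context_layer.shape)  # torch.Size([2, 4, 8, 16])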
@@ -237,7 +236,7 @@ class CoreAttention(torch.nn.Module):
             # Raw attention scores
 
             # [b, np, sq, sk]
-            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+            output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(0))
 
             # [sq, b, np, hn] -> [sq, b * np, hn]
             query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
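This hunk reorders which query dimensions feed `output_size` in the non-fused fallback; the exact index order depends on the layout this port feeds in. The surrounding comments document the tensors as [sq, b, np, hn] flattened to [sq, b * np, hn] before the batched matmul. A small sketch of that flattening step alone, with hypothetical sizes:

import torch

# Hypothetical sizes; the in-line comment documents the layout as [sq, b, np, hn]
sq, b, np_, hn = 8, 2, 4, 16
query_layer = torch.randn(sq, b, np_, hn)

# [sq, b, np, hn] -> [sq, b * np, hn], as described in the hunk's comment
query_layer = query_layer.view(sq, b * np_, -1)
print(query_layer.shape)  # torch.Size([8, 8, 16])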
@@ -312,7 +311,6 @@ class CoreAttention(torch.nn.Module):
 
 class SelfAttention(torch.nn.Module):
     """Parallel self-attention layer abstract class.
-
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -448,7 +446,6 @@ class SelfAttention(torch.nn.Module):
 
         return output, kv_cache
 
-
 def _config_to_kwargs(args):
     common_kwargs = {
         "dtype": args.torch_dtype,
@@ -504,7 +501,6 @@ class MLP(torch.nn.Module):
 
 class GLMBlock(torch.nn.Module):
     """A single transformer layer.
-
     Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -862,6 +858,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None
 
         hidden_states = inputs_embeds
+        # To comply with former chat-glm format that expects (seqlen, bs, hd)
+        hidden_states = hidden_states.permute(1, 0, 2)
 
         for index, layer in enumerate(self.layers):
             if output_hidden_states:
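The last hunk is the companion change: `inputs_embeds` arrive batch-first, and the added permute restores the (seqlen, bs, hd) layout the rest of the ChatGLM blocks expect, as the added comment states. A minimal sketch with hypothetical sizes:

import torch

# Hypothetical sizes: batch-first embeddings as fed into the model in this port
bs, seqlen, hd = 2, 8, 32
inputs_embeds = torch.randn(bs, seqlen, hd)

hidden_states = inputs_embeds
# Same permute as the hunk: (bs, seqlen, hd) -> (seqlen, bs, hd)
hidden_states = hidden_states.permute(1, 0, 2)
print(hidden_states.shape)  # torch.Size([8, 2, 32])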
 