Update modeling_chatglm.py
modeling_chatglm.py  CHANGED  (+5 -7)
@@ -220,11 +220,10 @@ class CoreAttention(torch.nn.Module):
 
     def forward(self, query_layer, key_layer, value_layer, attention_mask):
         pytorch_major_version = int(torch.__version__.split('.')[0])
-        if
+        if pytorch_major_version >= 2:
             query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
             if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 is_causal=True)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,is_causal=True)
             else:
                 if attention_mask is not None:
                     attention_mask = ~attention_mask
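For context, the PyTorch >= 2 branch in the hunk above hands the whole attention computation to torch.nn.functional.scaled_dot_product_attention. Below is a minimal, self-contained sketch of that branch under the layout the permute implies; the helper name sdpa_attention and the trailing reshape back to [sq, b, np * hn] are assumptions taken from the surrounding CoreAttention code, not part of this commit.

# Sketch only: inputs assumed to arrive as [seq_len, batch, num_heads, head_dim].
import torch
import torch.nn.functional as F

def sdpa_attention(query_layer, key_layer, value_layer, attention_mask=None):
    # [sq, b, np, hn] -> [b, np, sq, hn], the layout scaled_dot_product_attention expects
    query_layer, key_layer, value_layer = [
        k.permute(1, 2, 0, 3) for k in (query_layer, key_layer, value_layer)
    ]
    if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
        # Equal-length q/k with no explicit mask: let SDPA build the causal mask itself
        context_layer = F.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, is_causal=True
        )
    else:
        if attention_mask is not None:
            # The model's mask marks blocked positions with True, while a boolean
            # attn_mask for SDPA means "may attend", hence the inversion
            attention_mask = ~attention_mask
        context_layer = F.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, attention_mask
        )
    # Back to [sq, b, np, hn], then merge the head dimensions -> [sq, b, np * hn]
    context_layer = context_layer.permute(2, 0, 1, 3)
    return context_layer.reshape(context_layer.size(0), context_layer.size(1), -1)

Passing is_causal=True lets SDPA apply causal masking internally instead of materializing a full [sq, sk] mask tensor on this path.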
@@ -237,7 +236,7 @@ class CoreAttention(torch.nn.Module):
         # Raw attention scores
 
         # [b, np, sq, sk]
-        output_size = (query_layer.size(
+        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(0))
 
         # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
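A toy shape check of the rewritten output_size tuple in the hunk above. The dimension names follow the "# [b, np, sq, sk]" comment; the tensor layouts below are assumptions inferred from which indices the tuple reads, not something the commit states.

# Toy sizes only; the layouts are assumed, see note above.
import torch

b, sq, sk, np_heads, hn = 2, 5, 5, 4, 8
query_layer = torch.randn(b, sq, np_heads, hn)   # assumed [b, sq, np, hn]
key_layer = torch.randn(sk, b, np_heads, hn)     # assumed [sk, b, np, hn]

output_size = (query_layer.size(0), query_layer.size(2),
               query_layer.size(1), key_layer.size(0))
assert output_size == (b, np_heads, sq, sk)      # i.e. [b, np, sq, sk]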
@@ -312,7 +311,6 @@ class CoreAttention(torch.nn.Module):
 
 class SelfAttention(torch.nn.Module):
     """Parallel self-attention layer abstract class.
-
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -448,7 +446,6 @@ class SelfAttention(torch.nn.Module):
 
         return output, kv_cache
 
-
 def _config_to_kwargs(args):
     common_kwargs = {
         "dtype": args.torch_dtype,
@@ -504,7 +501,6 @@ class MLP(torch.nn.Module):
 
 class GLMBlock(torch.nn.Module):
     """A single transformer layer.
-
     Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -862,6 +858,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None
 
         hidden_states = inputs_embeds
+        # To comply with former chat-glm format that expects (seqlen, bs, hd)
+        hidden_states = hidden_states.permute(1, 0, 2)
 
         for index, layer in enumerate(self.layers):
             if output_hidden_states:
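The two lines added in the last hunk convert inputs_embeds from (batch, seqlen, hidden) into the (seqlen, batch, hidden) layout the GLM transformer blocks consume, as the new comment says. A small illustration with toy sizes (the concrete numbers are illustrative only):

import torch

batch, seq_len, hidden = 2, 16, 32
inputs_embeds = torch.randn(batch, seq_len, hidden)   # (bs, seqlen, hd)

# Same permute as in the hunk: -> (seqlen, bs, hd)
hidden_states = inputs_embeds.permute(1, 0, 2)
assert hidden_states.shape == (seq_len, batch, hidden)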