Update modeling_chatglm.py
Browse files- modeling_chatglm.py +2 -2
modeling_chatglm.py
CHANGED
@@ -743,7 +743,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
743 |
|
744 |
self.embed_tokens = nn.Embedding(
|
745 |
config.padded_vocab_size,
|
746 |
-
|
747 |
dtype=config.torch_dtype,
|
748 |
device=device
|
749 |
)
|
@@ -825,7 +825,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
825 |
batch_size, seq_length = input_ids.shape
|
826 |
|
827 |
if inputs_embeds is None:
|
828 |
-
inputs_embeds = self.
|
829 |
|
830 |
if self.pre_seq_len is not None:
|
831 |
if past_key_values is None:
|
|
|
743 |
|
744 |
self.embed_tokens = nn.Embedding(
|
745 |
config.padded_vocab_size,
|
746 |
+
config.hidden_size,
|
747 |
dtype=config.torch_dtype,
|
748 |
device=device
|
749 |
)
|
|
|
825 |
batch_size, seq_length = input_ids.shape
|
826 |
|
827 |
if inputs_embeds is None:
|
828 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
829 |
|
830 |
if self.pre_seq_len is not None:
|
831 |
if past_key_values is None:
|