Hyungtae Kim committed
Commit: 27675b5
Parent(s): 8fb2de8

Remove custom embedding.

1 file changed: modeling_mpt.py (+2, -3)
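For context, the SharedEmbedding being removed is (judging from the import deleted below and the replacement code) an nn.Embedding variant that can also "unembed" hidden states with its own weight matrix. A minimal sketch of that pattern, assumed rather than copied from the deleted custom_embedding.py:

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class SharedEmbedding(nn.Embedding):
    """Embedding whose weight doubles as the output projection ('unembedding')."""

    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
        if unembed:
            # Project hidden states back onto the vocabulary with the tied weight.
            return F.linear(input, self.weight)
        return super().forward(input)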
modeling_mpt.py CHANGED

@@ -12,7 +12,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .attention import attn_bias_shape, build_attn_bias
 from .blocks import MPTBlock
-from .custom_embedding import SharedEmbedding
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
@@ -56,7 +55,7 @@ class MPTModel(MPTPreTrainedModel):
             raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
         self.embedding_fraction = config.embedding_fraction
-        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
+        self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
         if not self.alibi:
             self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
@@ -322,7 +321,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if inputs_embeds is not None:
             raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
-        logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
+        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
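The replacement keeps the weight tying but does it inline: token lookup goes through a plain nn.Embedding, and logits come from F.linear against that same weight, as in the added line above. A minimal, self-contained sketch of the pattern with toy sizes (names and shapes are illustrative, not taken from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model = 128, 16
wte = nn.Embedding(vocab_size, d_model)            # input embedding, as in the new __init__

input_ids = torch.randint(0, vocab_size, (2, 8))   # (batch, seq_len)
hidden = wte(input_ids)                            # stand-in for outputs.last_hidden_state

# Unembed with the same weight: (2, 8, d_model) x (vocab_size, d_model)^T -> (2, 8, vocab_size)
logits = F.linear(hidden, wte.weight)
print(logits.shape)                                # torch.Size([2, 8, 128])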