lhallee committed
Commit f50c5ec · verified · Parent: afb8987

Update modeling_fastesm.py

Files changed (1): modeling_fastesm.py (+1, -9)
modeling_fastesm.py CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F
 from torch.utils.data import Dataset, DataLoader
 from typing import Optional, Tuple, Union
 from einops import rearrange
-from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
+from transformers import PreTrainedModel, PretrainedConfig
 from transformers.modeling_outputs import (
     MaskedLMOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -145,13 +145,6 @@ class EsmEmbeddings(nn.Module):
     def forward(
         self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
-        if position_ids is None:
-            if input_ids is not None:
-                # Create the position ids from the input token ids. Any padded tokens remain padded.
-                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
-            else:
-                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
-
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
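Note on the hunk above: with this fallback removed, EsmEmbeddings.forward no longer derives position_ids from input_ids or inputs_embeds, so a caller that still needs explicit absolute positions has to build them itself. A minimal caller-side sketch, assuming create_position_ids_from_input_ids is still defined at module level in modeling_fastesm.py with the signature the removed call used, and assuming ESM-2's usual pad token id of 1:

# Hypothetical caller-side sketch; the import path and token ids below are assumptions, not part of this commit.
import torch
from modeling_fastesm import create_position_ids_from_input_ids  # assumption: helper still present at module level

input_ids = torch.tensor([[0, 5, 8, 12, 2, 1, 1]])  # made-up tokens: cls, residues, eos, trailing padding
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1, past_key_values_length=0)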
 
@@ -346,7 +339,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
     config_class = FastEsmConfig
     base_model_prefix = "fastesm"
     supports_gradient_checkpointing = True
-    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
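Since this hunk drops the class-level tokenizer attribute from FastEsmPreTrainedModel, downstream code presumably loads the tokenizer explicitly instead of reading it off the class. A minimal sketch using the same checkpoint name the removed line pointed at:

# Caller-side sketch; the checkpoint name comes from the removed line, everything else is illustrative.
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
batch = tokenizer(["MKTAYIAKQR"], return_tensors="pt", padding=True)  # hypothetical protein sequence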
 