lhallee commited on
Commit
f7f9194
·
verified ·
1 Parent(s): f50c5ec

Update modeling_fastesm.py

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +2 -1
modeling_fastesm.py CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F
4
  from torch.utils.data import Dataset, DataLoader
5
  from typing import Optional, Tuple, Union
6
  from einops import rearrange
7
- from transformers import PreTrainedModel, PretrainedConfig
8
  from transformers.modeling_outputs import (
9
  MaskedLMOutput,
10
  BaseModelOutputWithPastAndCrossAttentions,
@@ -339,6 +339,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
339
  config_class = FastEsmConfig
340
  base_model_prefix = "fastesm"
341
  supports_gradient_checkpointing = True
 
342
  def _init_weights(self, module):
343
  """Initialize the weights"""
344
  if isinstance(module, nn.Linear):
 
4
  from torch.utils.data import Dataset, DataLoader
5
  from typing import Optional, Tuple, Union
6
  from einops import rearrange
7
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
8
  from transformers.modeling_outputs import (
9
  MaskedLMOutput,
10
  BaseModelOutputWithPastAndCrossAttentions,
 
339
  config_class = FastEsmConfig
340
  base_model_prefix = "fastesm"
341
  supports_gradient_checkpointing = True
342
+ tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
343
  def _init_weights(self, module):
344
  """Initialize the weights"""
345
  if isinstance(module, nn.Linear):