lhallee commited on
Commit
b418994
·
verified ·
1 Parent(s): fb1a199

Update modeling_fastesm.py

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +2 -1
modeling_fastesm.py CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F
4
  from torch.utils.data import Dataset, DataLoader
5
  from typing import Optional, Tuple, Union
6
  from einops import rearrange
7
- from transformers import PreTrainedModel, PretrainedConfig
8
  from transformers.modeling_outputs import (
9
  MaskedLMOutput,
10
  BaseModelOutputWithPastAndCrossAttentions,
@@ -346,6 +346,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
346
  config_class = FastEsmConfig
347
  base_model_prefix = "fastesm"
348
  supports_gradient_checkpointing = True
 
349
  def _init_weights(self, module):
350
  """Initialize the weights"""
351
  if isinstance(module, nn.Linear):
 
4
  from torch.utils.data import Dataset, DataLoader
5
  from typing import Optional, Tuple, Union
6
  from einops import rearrange
7
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
8
  from transformers.modeling_outputs import (
9
  MaskedLMOutput,
10
  BaseModelOutputWithPastAndCrossAttentions,
 
346
  config_class = FastEsmConfig
347
  base_model_prefix = "fastesm"
348
  supports_gradient_checkpointing = True
349
+ tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
350
  def _init_weights(self, module):
351
  """Initialize the weights"""
352
  if isinstance(module, nn.Linear):