Update modeling_fastesm.py
Browse files- modeling_fastesm.py +2 -1
modeling_fastesm.py
CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F
|
|
4 |
from torch.utils.data import Dataset, DataLoader
|
5 |
from typing import Optional, Tuple, Union
|
6 |
from einops import rearrange
|
7 |
-
from transformers import PreTrainedModel, PretrainedConfig
|
8 |
from transformers.modeling_outputs import (
|
9 |
MaskedLMOutput,
|
10 |
BaseModelOutputWithPastAndCrossAttentions,
|
@@ -339,6 +339,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
339 |
config_class = FastEsmConfig
|
340 |
base_model_prefix = "fastesm"
|
341 |
supports_gradient_checkpointing = True
|
|
|
342 |
def _init_weights(self, module):
|
343 |
"""Initialize the weights"""
|
344 |
if isinstance(module, nn.Linear):
|
|
|
4 |
from torch.utils.data import Dataset, DataLoader
|
5 |
from typing import Optional, Tuple, Union
|
6 |
from einops import rearrange
|
7 |
+
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
|
8 |
from transformers.modeling_outputs import (
|
9 |
MaskedLMOutput,
|
10 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
|
339 |
config_class = FastEsmConfig
|
340 |
base_model_prefix = "fastesm"
|
341 |
supports_gradient_checkpointing = True
|
342 |
+
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
343 |
def _init_weights(self, module):
|
344 |
"""Initialize the weights"""
|
345 |
if isinstance(module, nn.Linear):
|