Update modeling_fastesm.py
Browse files- modeling_fastesm.py +2 -1
modeling_fastesm.py
CHANGED
@@ -4,7 +4,7 @@ from torch.nn import functional as F
|
|
4 |
from torch.utils.data import Dataset, DataLoader
|
5 |
from typing import Optional, Tuple, Union
|
6 |
from einops import rearrange
|
7 |
-
from transformers import PreTrainedModel, PretrainedConfig
|
8 |
from transformers.modeling_outputs import (
|
9 |
MaskedLMOutput,
|
10 |
BaseModelOutputWithPastAndCrossAttentions,
|
@@ -346,6 +346,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
346 |
config_class = FastEsmConfig
|
347 |
base_model_prefix = "fastesm"
|
348 |
supports_gradient_checkpointing = True
|
|
|
349 |
def _init_weights(self, module):
|
350 |
"""Initialize the weights"""
|
351 |
if isinstance(module, nn.Linear):
|
|
|
4 |
from torch.utils.data import Dataset, DataLoader
|
5 |
from typing import Optional, Tuple, Union
|
6 |
from einops import rearrange
|
7 |
+
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
|
8 |
from transformers.modeling_outputs import (
|
9 |
MaskedLMOutput,
|
10 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
|
346 |
config_class = FastEsmConfig
|
347 |
base_model_prefix = "fastesm"
|
348 |
supports_gradient_checkpointing = True
|
349 |
+
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
350 |
def _init_weights(self, module):
|
351 |
"""Initialize the weights"""
|
352 |
if isinstance(module, nn.Linear):
|