Upload modeling_fastesm.py with huggingface_hub

modeling_fastesm.py CHANGED (+15 -9)
@@ -4,9 +4,10 @@ from torch.nn import functional as F
 from torch.utils.data import Dataset, DataLoader
 from typing import Optional, Tuple, Union
 from einops import rearrange
+from dataclasses import dataclass
 from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 from transformers.modeling_outputs import (
-
+    ModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     SequenceClassifierOutput,
@@ -23,6 +24,15 @@ from transformers.models.esm.modeling_esm import (
 from tqdm.auto import tqdm
 
 
+@dataclass
+class EsmMaskedLMOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 class FastEsmConfig(PretrainedConfig):
     model_type = "fast_esm"
     def __init__(
@@ -656,9 +666,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
             Model outputs including hidden states and optionally attention weights
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -739,9 +747,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             Model outputs including hidden states and optionally attention weights
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -798,7 +804,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, EsmMaskedLMOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
@@ -815,7 +821,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
            labels = labels.to(prediction_scores.device)
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
-        return
+        return EsmMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
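
The new EsmMaskedLMOutput is a standard transformers ModelOutput container, which is what the "# to play nice with HF adjacent packages" comment is getting at: downstream code can read it by attribute name, by key, or by tuple index. Below is a minimal sketch of that behavior; the class body is copied from the diff above, and the tensor shapes are illustrative only.

import torch
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.modeling_outputs import ModelOutput


@dataclass
class EsmMaskedLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# Illustrative shapes only: batch of 2, sequence length 10, vocab size 33.
out = EsmMaskedLMOutput(logits=torch.randn(2, 10, 33))

print(out.logits.shape)     # attribute access, as with a plain dataclass
print(out["logits"].shape)  # dict-style access provided by ModelOutput
print(out[0].shape)         # tuple-style access; fields left as None are skipped
print(list(out.keys()))     # ['logits'] -- only populated fields are listed
print(out.loss)             # None: the field exists but was not set

Because fields left as None are dropped from the key and tuple views, tuple-style consumers only see values that were actually produced, while HF-adjacent packages can look up last_hidden_state, hidden_states, or attentions by name.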
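
For completeness, a hedged usage sketch of the changed FastEsmForMaskedLM.forward as a downstream caller would see it. The checkpoint id and the tokenizer loading path are placeholders and assumptions, not confirmed by this commit, and whether loss and hidden_states are populated depends on the arguments passed and on parts of the file outside this diff.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Placeholder repo id; assumes the checkpoint ships this modeling file and
# its tokenizer, and that it is loaded with trust_remote_code=True.
model_id = "your-namespace/fast-esm-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("MSILVTRPSPA", return_tensors="pt")  # arbitrary example protein sequence
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

print(type(out).__name__)   # EsmMaskedLMOutput, per the return added in this commit
print(out.logits.shape)     # (batch, sequence_length, vocab_size)
print(out.loss)             # expected to be None when no labels are supplied
if out.hidden_states is not None:
    print(len(out.hidden_states))  # per-layer hidden states when requested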