lhallee committed
Commit 44cde8c (verified) · 1 parent: 1247541

Upload modeling_fastesm.py with huggingface_hub
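For context, a commit with this message is typically produced by a call along these lines; the repo id and local path below are assumptions for illustration, not details taken from this page:

```python
# Minimal sketch of uploading a single file with huggingface_hub.
# repo_id and the local path are hypothetical placeholders.
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="modeling_fastesm.py",  # local file to push
    path_in_repo="modeling_fastesm.py",     # destination path in the repo
    repo_id="user/FastESM",                 # hypothetical repo id
    commit_message="Upload modeling_fastesm.py with huggingface_hub",
)
```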

Files changed (1): modeling_fastesm.py (+15 −9)
modeling_fastesm.py CHANGED
@@ -4,9 +4,10 @@ from torch.nn import functional as F
 from torch.utils.data import Dataset, DataLoader
 from typing import Optional, Tuple, Union
 from einops import rearrange
+from dataclasses import dataclass
 from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 from transformers.modeling_outputs import (
-    MaskedLMOutput,
+    ModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     SequenceClassifierOutput,
@@ -23,6 +24,15 @@ from transformers.models.esm.modeling_esm import (
 from tqdm.auto import tqdm
 
 
+@dataclass
+class EsmMaskedLMOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 class FastEsmConfig(PretrainedConfig):
     model_type = "fast_esm"
     def __init__(
@@ -656,9 +666,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
             Model outputs including hidden states and optionally attention weights
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -739,9 +747,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             Model outputs including hidden states and optionally attention weights
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -798,7 +804,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
-    ) -> Union[Tuple, MaskedLMOutput]:
+    ) -> Union[Tuple, EsmMaskedLMOutput]:
         outputs = self.esm(
             input_ids,
             attention_mask=attention_mask,
@@ -815,7 +821,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
             labels = labels.to(prediction_scores.device)
             loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
-        return MaskedLMOutput(
+        return EsmMaskedLMOutput(
             loss=loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
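
The point of swapping `MaskedLMOutput` for the custom `EsmMaskedLMOutput` is the extra `last_hidden_state` field, which the stock masked-LM output class has no slot for. A minimal usage sketch, assuming a hypothetical checkpoint that ships this modeling file via `trust_remote_code`:

```python
# Sketch only: the repo id is a placeholder, not a checkpoint named on this page.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "user/FastESM"  # hypothetical repo hosting modeling_fastesm.py
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("MSKVLLA", return_tensors="pt")  # toy protein sequence
with torch.no_grad():
    out = model(**inputs)

print(out.logits.shape)  # (batch, seq_len, vocab_size), as before
# New in this commit: the output class also exposes last_hidden_state
# (Optional, so it may be None depending on what the forward populates).
print(out.last_hidden_state if out.last_hidden_state is None else out.last_hidden_state.shape)
```

Because `EsmMaskedLMOutput` subclasses `ModelOutput`, existing code that indexes the result like a tuple or reads `out.loss` / `out.logits` keeps working unchanged.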