lhallee committed on
Commit
56fe2cc
·
verified ·
1 Parent(s): 8b1744b

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +6 -19
modeling_fastesm.py CHANGED
@@ -749,35 +749,22 @@ class FastEsmModel(FastEsmPreTrainedModel):
749
  else:
750
  raise ValueError("You have to specify either input_ids or inputs_embeds")
751
 
752
- batch_size, seq_length = input_shape
753
- embedding_output = self.embeddings(
754
- input_ids=input_ids,
755
- position_ids=position_ids,
756
  attention_mask=attention_mask,
 
757
  inputs_embeds=inputs_embeds,
758
- )
759
-
760
- if attention_mask is not None:
761
- extended_attention_mask = attention_mask[:, None, None, :].expand(
762
- batch_size, 1, seq_length, seq_length
763
- ).bool()
764
- else:
765
- extended_attention_mask = None
766
-
767
- encoder_outputs = self.encoder(
768
- embedding_output,
769
- attention_mask=extended_attention_mask,
770
  output_hidden_states=output_hidden_states,
771
  output_attentions=output_attentions,
772
  )
773
- sequence_output = encoder_outputs.last_hidden_state
774
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
775
 
776
  return BaseModelOutputWithPoolingAndCrossAttentions(
777
  last_hidden_state=sequence_output,
778
  pooler_output=pooled_output,
779
- hidden_states=encoder_outputs.hidden_states,
780
- attentions=encoder_outputs.attentions,
781
  )
782
 
783
 
 
749
  else:
750
  raise ValueError("You have to specify either input_ids or inputs_embeds")
751
 
752
+ outputs = self.esm(
753
+ input_ids,
 
 
754
  attention_mask=attention_mask,
755
+ position_ids=position_ids,
756
  inputs_embeds=inputs_embeds,
 
 
 
 
 
 
 
 
 
 
 
 
757
  output_hidden_states=output_hidden_states,
758
  output_attentions=output_attentions,
759
  )
760
+ sequence_output = outputs.last_hidden_state
761
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
762
 
763
  return BaseModelOutputWithPoolingAndCrossAttentions(
764
  last_hidden_state=sequence_output,
765
  pooler_output=pooled_output,
766
+ hidden_states=outputs.hidden_states,
767
+ attentions=outputs.attentions,
768
  )
769
 
770