Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py CHANGED (+6 -19)
@@ -749,35 +749,22 @@ class FastEsmModel(FastEsmPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
+        outputs = self.esm(
+            input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-        )
-
-        if attention_mask is not None:
-            extended_attention_mask = attention_mask[:, None, None, :].expand(
-                batch_size, 1, seq_length, seq_length
-            ).bool()
-        else:
-            extended_attention_mask = None
-
-        encoder_outputs = self.encoder(
-            embedding_output,
-            attention_mask=extended_attention_mask,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
         )
-        sequence_output = encoder_outputs.last_hidden_state
+        sequence_output = outputs.last_hidden_state
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
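Net effect: instead of computing token embeddings, expanding the 2D padding mask into a 4D boolean attention mask, and calling the encoder by hand, the wrapper now forwards everything to the inner self.esm module and re-packages its outputs. Since the method still returns BaseModelOutputWithPoolingAndCrossAttentions, the public contract (last_hidden_state, pooler_output, hidden_states, attentions) is unchanged. For reference, a minimal standalone sketch of the mask expansion the deleted lines performed; the tensor shapes and values are illustrative, not taken from the repository:

import torch

# Toy batch: 2 sequences of length 5, the first padded at positions 3-4.
batch_size, seq_length = 2, 5
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# [batch, seq] -> [batch, 1, seq, seq] boolean mask, as in the removed
# code: every query position shares the same key-padding row.
extended_attention_mask = attention_mask[:, None, None, :].expand(
    batch_size, 1, seq_length, seq_length
).bool()

print(extended_attention_mask.shape)     # torch.Size([2, 1, 5, 5])
print(extended_attention_mask[0, 0, 0])  # tensor([ True,  True,  True, False, False])

Delegating to self.esm presumably keeps this masking logic in one place, so the FastEsmModel wrapper can no longer drift out of sync with the inner model's mask handling.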
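For completeness, a hedged usage sketch, assuming the file is distributed as Hub custom code so the model loads with trust_remote_code (the repo id below is a placeholder, not the checkpoint this commit belongs to):

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "some-org/fastesm-checkpoint"  # placeholder repo id, not from this commit

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# A short protein sequence; ESM-style tokenizers accept raw amino-acid strings.
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

with torch.no_grad():
    out = model(**inputs, output_hidden_states=True, output_attentions=True)

# The refactor does not change these output fields:
print(out.last_hidden_state.shape)
print(out.hidden_states is not None, out.attentions is not None)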