|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
from typing import Optional, Tuple, Union
|
|
from einops import rearrange
|
|
from transformers.modeling_outputs import (
|
|
MaskedLMOutput,
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
SequenceClassifierOutput,
|
|
TokenClassifierOutput
|
|
)
|
|
from transformers.models.esm.modeling_esm import (
|
|
RotaryEmbedding,
|
|
EsmContactPredictionHead,
|
|
EsmIntermediate,
|
|
EsmOutput,
|
|
EsmPooler,
|
|
EsmLMHead,
|
|
EsmSelfOutput,
|
|
EsmClassificationHead,
|
|
EsmPreTrainedModel,
|
|
create_position_ids_from_input_ids,
|
|
gelu
|
|
)
|
|
|
|
|
|
class EsmEmbeddings(nn.Module):
    """Token embeddings with optional absolute position embeddings, LayerNorm and padding mask.

    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing: position ids
    start at ``padding_idx + 1``.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.layer_norm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.emb_layer_norm_before else None
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        # Stored from the config; not read anywhere inside this class.
        self.mask_token_id = config.mask_token_id

    def forward(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Embed tokens (or take ``inputs_embeds`` as-is) and apply position/LayerNorm/mask steps.

        Position ids are derived from ``input_ids`` when given (padding-aware), otherwise they
        are sequential. They only contribute when ``position_embedding_type == "absolute"``.
        When ``attention_mask`` is given, masked positions are multiplied by it (zeroed for 0).
        """
        if position_ids is None:
            position_ids = (
                create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
                if input_ids is not None
                else self.create_position_ids_from_inputs_embeds(inputs_embeds)
            )

        embeddings = self.word_embeddings(input_ids) if inputs_embeds is None else inputs_embeds

        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)

        if attention_mask is not None:
            # Broadcast the per-token mask over the hidden dimension; keep the embedding dtype.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        batch_shape = inputs_embeds.size()[:-1]
        first_position = self.padding_idx + 1
        positions = torch.arange(
            first_position, first_position + batch_shape[1], dtype=torch.long, device=inputs_embeds.device
        )
        return positions.unsqueeze(0).expand(batch_shape)
|
|
|
|
|
|
class EsmSelfAttention(nn.Module):
|
|
def __init__(self, config, position_embedding_type=None):
|
|
super().__init__()
|
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
|
raise ValueError(
|
|
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
|
f"heads ({config.num_attention_heads})"
|
|
)
|
|
|
|
self.num_attention_heads = config.num_attention_heads
|
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
|
|
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
|
self.scale = self.attention_head_size**-0.5
|
|
|
|
self.dropout_prob = config.attention_probs_dropout_prob
|
|
self.position_embedding_type = position_embedding_type or getattr(
|
|
config, "position_embedding_type", "absolute"
|
|
)
|
|
self.rotary_embeddings = None
|
|
if self.position_embedding_type == "rotary":
|
|
self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
|
|
|
|
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
|
return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
attention_mask: Optional[torch.FloatTensor] = None,
|
|
) -> Tuple[torch.Tensor]:
|
|
query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
|
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
|
|
|
if self.position_embedding_type == "rotary":
|
|
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
|
|
|
|
context_layer = F.scaled_dot_product_attention(
|
|
query_layer,
|
|
key_layer,
|
|
value_layer,
|
|
attn_mask=attention_mask,
|
|
dropout_p=self.dropout_prob,
|
|
scale=1.0
|
|
)
|
|
return rearrange(context_layer, 'b h s d -> b s (h d)')
|
|
|
|
|
|
class EsmAttention(nn.Module):
    """Attention sub-block: LayerNorm -> EsmSelfAttention -> EsmSelfOutput.

    The un-normalised input is passed to ``EsmSelfOutput`` as its second argument
    (presumably used there as the residual — defined in transformers, not visible here).
    """

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
    ):
        normed_states = self.LayerNorm(hidden_states)
        context = self.self(normed_states, attention_mask)
        return self.output(context, hidden_states)
|
|
|
|
|
|
class EsmLayer(nn.Module):
    """One encoder layer: an attention sub-block followed by a feed-forward sub-block.

    Both sub-blocks LayerNorm their input before the main computation.
    """

    def __init__(self, config):
        super().__init__()
        # Stored from the config; feed-forward chunking is not applied in this implementation.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
    ):
        return self.feed_forward_chunk(self.attention(hidden_states, attention_mask))

    def feed_forward_chunk(self, attention_output):
        """LayerNorm -> EsmIntermediate -> EsmOutput.

        ``EsmOutput`` also receives the un-normalised ``attention_output``, presumably as the
        residual (its implementation lives in transformers).
        """
        projected = self.intermediate(self.LayerNorm(attention_output))
        return self.output(projected, attention_output)
|
|
|
|
|
|
class EsmEncoder(nn.Module):
    """Stack of ``EsmLayer`` modules followed by a final LayerNorm (``emb_layer_norm_after``)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(EsmLayer(config) for _ in range(config.num_hidden_layers))
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Toggled by the transformers gradient-checkpointing helpers, which also
        # provide ``self._gradient_checkpointing_func``.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_hidden_states=False,
    ):
        """Run all layers.

        When ``output_hidden_states`` is truthy, ``hidden_states`` in the result holds the
        input to every layer plus the final (post-LayerNorm) output; otherwise it is None.
        """
        collected = [] if output_hidden_states else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for block in self.layer:
            if collected is not None:
                collected.append(hidden_states)

            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states, attention_mask)
            else:
                hidden_states = block(hidden_states, attention_mask)

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        all_hidden_states = None if collected is None else tuple(collected) + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
|
|
|
|
|
|
class FastEsmModel(EsmPreTrainedModel):
    """ESM encoder whose attention runs through ``F.scaled_dot_product_attention``.

    Attention maps are never materialised, so ``output_attentions`` is rejected.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Not used by forward(); presumably retained so pretrained checkpoints that
        # include contact-head weights load without key mismatches.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Encode a batch of sequences.

        Args:
            input_ids: Token ids ``(batch, seq)``; mutually exclusive with ``inputs_embeds``.
            attention_mask: Optional 0/1 (or boolean) padding mask of shape ``(batch, seq)``,
                where 1/True marks tokens that may be attended to.
            position_ids: Optional explicit position ids forwarded to the embeddings.
            inputs_embeds: Pre-computed embeddings ``(batch, seq, hidden)``; mutually exclusive
                with ``input_ids``.
            output_hidden_states: Return per-layer hidden states; defaults to
                ``config.output_hidden_states``.
            output_attentions: Must be falsy; attention maps are not computed.

        Returns:
            ``BaseModelOutputWithPoolingAndCrossAttentions`` with ``last_hidden_state``,
            ``pooler_output`` (None without a pooler) and ``hidden_states``.

        Raises:
            ValueError: if attention maps are requested, or if both / neither of
                ``input_ids`` and ``inputs_embeds`` are provided.
        """
        # Only an explicit request for attention maps is an error; None/False are fine.
        if output_attentions:
            raise ValueError("output_attentions is not supported by F.scaled_dot_product_attention")
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        if attention_mask is not None:
            # SDPA accepts a boolean mask (True = may attend) or a floating-point mask that is
            # *added* to the scores. The raw 0/1 padding mask is usually int64, which SDPA
            # rejects, and as a float it would be added rather than masking. Convert to bool
            # and broadcast over query positions: (b, s) -> (b, 1, s, s).
            extended_attention_mask = (
                attention_mask[:, None, None, :]
                .bool()
                .expand(batch_size, 1, seq_length, seq_length)
            )
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
|
|
|
|
|
|
class FastEsmForMaskedLM(EsmPreTrainedModel):
    """``FastEsmModel`` (no pooler) topped with an ``EsmLMHead`` for masked-token prediction."""

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.esm = FastEsmModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        # Default nn.CrossEntropyLoss settings: labels equal to -100 are ignored.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Return vocabulary logits and, when ``labels`` is given, the flattened cross-entropy loss."""
        encoded = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        prediction_scores = self.lm_head(encoded.last_hidden_state)

        if labels is None:
            loss = None
        else:
            targets = labels.to(prediction_scores.device)
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), targets.view(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=encoded.hidden_states,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Always raises: SDPA does not expose the attention maps this would require."""
        raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
|
|
|
|
|
|
class FastEsmForSequenceClassification(EsmPreTrainedModel):
    """``FastEsmModel`` (no pooler) topped with an ``EsmClassificationHead``.

    The loss is chosen from ``config.problem_type``; when that is unset it is inferred on the
    first call with labels and written back into the config.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FastEsmModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        encoded = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(encoded.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                # One output -> regression; several outputs with integer labels -> single-label;
                # anything else -> multi-label.
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem = self.config.problem_type
            if problem == "regression":
                loss = (
                    self.mse(logits.squeeze(), labels.squeeze())
                    if self.num_labels == 1
                    else self.mse(logits, labels)
                )
            elif problem == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
        )
|
|
|
|
|
|
class FastEsmForTokenClassification(EsmPreTrainedModel):
    """``FastEsmModel`` (no pooler) followed by dropout and a per-token linear classifier."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.esm = FastEsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Default nn.CrossEntropyLoss settings: labels equal to -100 are ignored.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Return per-token logits and, when ``labels`` is given, the flattened cross-entropy loss."""
        encoded = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(self.dropout(encoded.last_hidden_state))

        if labels is None:
            loss = None
        else:
            targets = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), targets.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    """
    Compare the last hidden states of FastEsmModel against the Hugging Face EsmModel.

    In full precision the differences are very small but nonzero, due to floating-point
    differences in F.scaled_dot_product_attention.
    With PyTorch 2.5+ on Linux, this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmModel as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        "facebook/esm2_t30_150M_UR50D",
        "facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        """Return 'M' followed by ``length`` random canonical amino acids."""
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        # Explicit eval(): the comparison must not include any dropout.
        fast_model = FastEsmModel.from_pretrained(model_path, token_dropout=False).to(device).eval()
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device).eval()

        counts = [0] * len(tolerances)
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for _ in range(seq_count):
                example_seq = generate_random_sequence(length)
                # Both models share the same tokenizer, so tokenize once.
                tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
                fast_output = fast_model(tokens).last_hidden_state.cpu()
                model_output = model(tokens).last_hidden_state.cpu()

                for i, atol in enumerate(tolerances):
                    if torch.allclose(fast_output, model_output, atol=atol):
                        counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"    tolerance={atol}: {counts[i] / seq_count * 100}%")

        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()
|
|
|