# FastESM2_650 / modeling_fastesm.py
# Author: lhallee - commit 874ce57 ("Update modeling_fastesm.py", verified)
# File size: 25.3 kB
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Optional, Tuple, Union
from einops import rearrange
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import (
MaskedLMOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
SequenceClassifierOutput,
TokenClassifierOutput
)
from transformers.models.esm.modeling_esm import (
EsmIntermediate,
EsmOutput,
EsmPooler,
EsmLMHead,
EsmSelfOutput,
EsmClassificationHead,
create_position_ids_from_input_ids,
)
class FastEsmConfig(PretrainedConfig):
    """Configuration for the FastESM models defined in this file.

    Stores the hyper-parameters read by the embedding, attention, encoder and
    head modules (sizes, layer/head counts, dropout rates, position-embedding
    type and layer-norm settings). Any extra keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size=None,
        mask_token_id=None,
        pad_token_id=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="absolute",
        emb_layer_norm_before=None,
        **kwargs,
    ):
        # Special-token ids and any remaining kwargs are handled by the base class.
        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)

        # Vocabulary / model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Regularisation.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        # Positions, initialisation and normalisation.
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        return super().to_dict()
def rotate_half(x):
    """Rotate the last dimension by half: ``[front, back] -> [-back, front]``.

    The last dimension is split into two chunks with ``torch.chunk``; the second
    chunk is negated and placed before the first.
    """
    front, back = torch.chunk(x, chunks=2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb(x, cos, sin):
    """Apply a rotary position embedding to ``x``.

    ``cos`` and ``sin`` are indexed as 4-D tables and truncated along their
    third axis to ``x.shape[-2]``, so tables built for a longer sequence can be
    reused for a shorter one. Returns ``x * cos + rotate_half(x) * sin``.
    """
    positions = x.shape[-2]
    cos_table = cos[:, :, :positions, :]
    sin_table = sin[:, :, :positions, :]
    rotated = rotate_half(x)
    return x * cos_table + rotated * sin_table
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    The cos/sin tables are cached on the module and rebuilt whenever the
    sequence length, device, or dtype of the incoming tensor changes.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Inverse frequencies 1 / 10000^(2i / dim) for even i in [0, dim); kept as a
        # (non-trainable) buffer so it moves with the module across devices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """Return ``(cos, sin)`` tables of shape (1, 1, seq_len, dim) for ``x``.

        ``seq_len`` is ``x.shape[seq_dimension]``; the tables are cast to ``x.dtype``
        and placed on ``x.device``.
        """
        seq_len = x.shape[seq_dimension]
        # Reset the tables if the sequence length has changed, if we're on a new
        # device (possibly due to tracing for instance), or if the dtype changed
        # (e.g. a full-precision call followed by a half-precision one). Without
        # the dtype check, stale tables would silently up-cast q/k and break
        # attention with mismatched q/k/v dtypes.
        if (
            seq_len != self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``; the tables are sized from ``k``'s second-to-last axis."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Produces word embeddings, optionally adds absolute position embeddings,
    optionally applies a LayerNorm, and zeroes out positions where
    ``attention_mask`` is 0.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.layer_norm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
            if config.emb_layer_norm_before
            else None
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # Non-persistent: not written to / read from checkpoints.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        # Token dropout does not work correctly, so it is intentionally disabled
        # (config.token_dropout is not read).
        self.mask_token_id = config.mask_token_id

    def forward(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if position_ids is None:
            # With token ids, padded tokens keep the padding position; with raw
            # embeddings, padding cannot be detected so positions are sequential.
            position_ids = (
                create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
                if input_ids is not None
                else self.create_position_ids_from_inputs_embeds(inputs_embeds)
            )

        embeddings = self.word_embeddings(input_ids) if inputs_embeds is None else inputs_embeds

        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)

        if attention_mask is not None:
            keep = attention_mask.unsqueeze(-1)
            # Cast back in case the mask's dtype promoted the product.
            embeddings = (embeddings * keep).to(embeddings.dtype)

        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor of ids ``padding_idx + 1 .. padding_idx + seq_len``,
            broadcast to ``inputs_embeds.size()[:-1]``.
        """
        batch_and_seq = inputs_embeds.size()[:-1]
        first_id = self.padding_idx + 1
        ids = torch.arange(
            first_id, first_id + batch_and_seq[1], dtype=torch.long, device=inputs_embeds.device
        )
        return ids.unsqueeze(0).expand(batch_and_seq)
class EsmSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.scale = self.attention_head_size**-0.5
self.dropout_prob = config.attention_probs_dropout_prob
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
self.rotary_embeddings = None
if self.position_embedding_type == "rotary":
self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor]:
query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
context_layer = F.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob,
scale=1.0
)
return rearrange(context_layer, 'b h s d -> b s (h d)')
class EsmAttention(nn.Module):
    """Pre-LayerNorm attention sub-layer.

    The input is normalised, passed through self-attention, and the result is
    handed to ``self.output`` together with the *un-normalised* input
    (presumably used as the residual by ``EsmSelfOutput`` - confirm against HF ESM).
    """

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        normed = self.LayerNorm(hidden_states)
        context = self.self(normed, attention_mask)
        return self.output(context, hidden_states)
class EsmLayer(nn.Module):
    """One transformer block: attention sub-layer followed by a pre-LayerNorm feed-forward."""

    def __init__(self, config):
        super().__init__()
        # Neither attribute is read by this block's forward; kept for attribute
        # compatibility with code that may inspect them.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        return self.feed_forward_chunk(self.attention(hidden_states, attention_mask))

    def feed_forward_chunk(self, attention_output):
        """Normalise, expand with ``intermediate``, then project with ``output``.

        ``output`` also receives the un-normalised ``attention_output``
        (presumably as the residual - confirm against HF ``EsmOutput``).
        """
        expanded = self.intermediate(self.LayerNorm(attention_output))
        return self.output(expanded, attention_output)
class EsmEncoder(nn.Module):
    """Stack of ``EsmLayer`` blocks followed by a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(EsmLayer(config) for _ in range(config.num_hidden_layers))
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(self, hidden_states, attention_mask=None, output_hidden_states=False):
        """Run every layer in order.

        When ``output_hidden_states`` is truthy, the returned ``hidden_states`` is a
        tuple holding the input to each layer plus the final (normalised) output;
        otherwise it is ``None``.
        """
        collected = [] if output_hidden_states else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for layer_module in self.layer:
            if collected is not None:
                collected.append(hidden_states)
            if use_checkpointing:
                # _gradient_checkpointing_func is expected to be installed by the
                # HF gradient-checkpointing machinery.
                hidden_states = self._gradient_checkpointing_func(
                    layer_module.__call__, hidden_states, attention_mask
                )
            else:
                hidden_states = layer_module(hidden_states, attention_mask)

        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)
        if collected is not None:
            collected.append(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=tuple(collected) if collected is not None else None,
        )
class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights.

        Linear and Embedding weights ~ N(0, initializer_range); Linear biases and
        the Embedding padding row are zeroed; LayerNorm is reset to identity
        (weight 1, bias 0).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        """Return the word-embedding table, either from ``self`` or from a wrapped ``self.esm``."""
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            # Task heads keep the base model under ``self.esm``.
            return self.esm.embeddings.word_embeddings
class FastEsmModel(FastEsmPreTrainedModel):
    """Base FastESM encoder: embeddings -> encoder -> optional pooler."""

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Encode a batch given either ``input_ids`` or ``inputs_embeds`` (exactly one).

        Raises:
            ValueError: if attention weights are requested (``output_attentions``
                truthy), or if both / neither of ``input_ids`` and ``inputs_embeds``
                are given.
        """
        # Only reject an actual request for attention weights; an explicit
        # ``output_attentions=False`` asks for nothing unsupported.
        if output_attentions:
            raise ValueError("output_attentions is not supported by F.scaled_dot_product_attention")
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        batch_size, seq_length = input_shape

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        if attention_mask is not None:
            # (batch_size, seq_length) padding mask -> boolean view of shape
            # (batch_size, 1, seq_length, seq_length). Nonzero entries become True;
            # for SDPA's boolean attn_mask, True marks keys that may be attended to.
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
class FastEsmForMaskedLM(FastEsmPreTrainedModel):
    """FastESM encoder with a masked-language-modelling head."""

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.esm = FastEsmModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Return vocabulary logits, plus a cross-entropy loss when ``labels`` is given."""
        base_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        logits = self.lm_head(base_out.last_hidden_state)

        loss = None
        if labels is not None:
            # Flatten every token position into one classification problem.
            flat_logits = logits.view(-1, self.config.vocab_size)
            flat_labels = labels.to(logits.device).view(-1)
            loss = self.loss_fct(flat_logits, flat_labels)

        return MaskedLMOutput(loss=loss, logits=logits, hidden_states=base_out.hidden_states)

    def predict_contacts(self, tokens, attention_mask):
        raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
    """FastESM encoder with a sequence-level classification/regression head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FastEsmModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Return classifier logits and, when ``labels`` is given, a loss.

        If ``config.problem_type`` is unset it is inferred on first use and written
        back to the config: one label -> regression; several labels with integer
        (long/int) targets -> single-label classification; otherwise multi-label.
        """
        base_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(base_out.last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            kind = self.config.problem_type
            if kind == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif kind == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif kind == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=base_out.hidden_states)
class FastEsmForTokenClassification(FastEsmPreTrainedModel):
    """FastESM encoder with a per-token linear classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.esm = FastEsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """Return per-token logits, plus a cross-entropy loss when ``labels`` is given."""
        base_out = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.classifier(self.dropout(base_out.last_hidden_state))

        loss = None
        if labels is not None:
            flat_labels = labels.to(logits.device).view(-1)
            loss = self.loss_fct(logits.view(-1, self.num_labels), flat_labels)

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=base_out.hidden_states)
if __name__ == "__main__":
    # Compare the last hidden state of FastEsmForMaskedLM with the Hugging Face
    # EsmForMaskedLM on random protein sequences. In full precision the differences
    # are very small but nonzero, due to floating-point differences in
    # F.scaled_dot_product_attention. On PyTorch 2.5+ (Linux), this implementation
    # is fast and uses less memory than the HF implementation.
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        # "facebook/esm2_t30_150M_UR50D",
        # "facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(num_residues: int) -> str:
        """Return 'M' followed by ``num_residues`` random canonical amino acids."""
        return 'M' + "".join(random.choices(canonical_amino_acids, k=num_residues))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        config = FastEsmConfig.from_pretrained(model_path)
        # from_pretrained is a classmethod: call it on the class rather than first
        # building (and randomly initialising) a throwaway model instance.
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path, config=config).to(device).eval()
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device).eval()
        counts = [0] * len(tolerances)
        # Pure inference: skip building autograd graphs.
        with torch.no_grad():
            for _ in range(seq_count):
                example_seq = generate_random_sequence(length)
                # Both models share the same tokenizer, so tokenize once.
                tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
                fast_output = fast_model(tokens, output_hidden_states=True).hidden_states[-1].cpu()
                model_output = model(tokens, output_hidden_states=True).hidden_states[-1].cpu()
                for i, atol in enumerate(tolerances):
                    if torch.allclose(fast_output, model_output, atol=atol):
                        counts[i] += 1
        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()