|
import os |
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096' |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
from transformers import EsmModel |
|
import torch |
|
import numpy as np |
|
from lightning.pytorch import seed_everything |
|
from typing import Tuple |
|
import torch |
|
import gc |
|
from torch.optim.lr_scheduler import _LRScheduler |
|
from transformers import EsmModel, PreTrainedModel |
|
from configuration import MetaLATTEConfig |
|
from urllib.parse import urljoin |
|
# Fix the global RNG seed via Lightning's seed_everything as soon as this
# module is imported. This is a module-level side effect that applies to every
# importer of this file, not only to model construction.
seed_everything(seed=42)
|
|
|
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit activation.

    Computes ``0.5 * x * (1 + erf(x / sqrt(2)))`` element-wise.

    For reference, OpenAI GPT uses the tanh approximation instead, which gives
    slightly different results:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    # sqrt(2) computed once; the division below uses the same value as before.
    _SQRT_TWO = math.sqrt(2.0)

    def forward(self, x):
        gate = 1.0 + torch.erf(x / self._SQRT_TWO)
        return 0.5 * x * gate
|
|
|
|
|
def rotate_half(x):
    """Split the last dimension into two chunks ``(a, b)`` and return ``(-b, a)``.

    ``torch.chunk`` is used, so for an odd-sized last dimension the first chunk
    is the larger one; the output always has the same shape as ``x``.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by the angles encoded in the ``cos``/``sin`` tables.

    A singleton axis is inserted at position 2 of ``cos`` and ``sin`` before
    they broadcast against ``x``. With the ``(1, seq_len, dim)`` tables built by
    ``RotaryEmbedding`` and ``x`` laid out as ``(batch, seq_len, heads, dim)``
    (the layout ``PositionalAttentionHead`` uses), the tables broadcast over the
    batch and head axes.
    """
    cos_b = torch.unsqueeze(cos, 2)
    sin_b = torch.unsqueeze(sin, 2)
    rotated = rotate_half(x)
    return x * cos_b + rotated * sin_b
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings (RoPE) from RoFormer (Su et al.).

    Queries and keys are rotated by position-dependent angles, so that their
    dot product depends on the relative offset between positions. Other
    implementations exist in the Rotary Transformer repo and in GPT-NeoX,
    which inspired this one.

    References:
        RoFormer: https://arxiv.org/abs/2104.09864
        Rotary Transformer repo: https://github.com/ZhuiyiTechnology/roformer
        GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    Warning:
        This module is intentionally not registered as an embedding. It is
        transformative (it does not create the embedding dimension) and is
        meant to be imported on an ad-hoc basis.

    Args:
        dim: Size of the last dimension of the tensors being rotated. Any
            extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # One inverse frequency per feature pair: 1 / 10000 ** (2i / dim).
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # Cached cos/sin tables, rebuilt when the sequence length or the
        # device of the input changes.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return ``(cos, sin)`` tables for the length of ``x`` along ``seq_dimension``.

        Each table has shape ``(1, seq_len, 2 * len(inv_freq))``, which is
        ``(1, seq_len, dim)`` for even ``dim``. The tables are cached and
        reused while the sequence length and device stay the same.
        """
        seq_len = x.shape[seq_dimension]
        # On the first call _seq_len_cached is None, so the device comparison
        # (which would dereference the still-None _cos_cached) never runs.
        needs_rebuild = seq_len != self._seq_len_cached or self._cos_cached.device != x.device
        if needs_rebuild:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = positions.unsqueeze(1) * self.inv_freq
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            self._cos_cached = table.cos().unsqueeze(0)
            self._sin_cached = table.sin().unsqueeze(0)
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``; the tables are sized from dim 1 of ``k``."""
        cos, sin = self._update_cos_sin_tables(k)
        q_rot = apply_rotary_pos_emb(q, cos, sin)
        k_rot = apply_rotary_pos_emb(k, cos, sin)
        return q_rot, k_rot
|
|
|
|
|
def macro_f1(y_true, y_pred, thresholds):
    """Macro-averaged F1 score over label columns.

    Scores are binarised with ``y_pred >= thresholds``. True positives, false
    positives and false negatives are then counted per column, summing over
    dim 0. Every ratio carries a 1e-7 epsilon in its denominator, so columns
    with no positives score 0 instead of NaN.

    Returns:
        The unweighted mean of the per-column F1 scores, as a tensor.
    """
    eps = 1e-7
    hits = (y_pred >= thresholds).float()
    misses = 1 - hits

    true_pos = (y_true * hits).sum(dim=0)
    false_pos = ((1 - y_true) * hits).sum(dim=0)
    false_neg = (y_true * misses).sum(dim=0)

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    per_label_f1 = 2 * precision * recall / (precision + recall + eps)
    return per_label_f1.mean()
|
|
|
def safeguard_softmax(logits, dim=-1):
    """Numerically stable softmax that maps fully masked slices to zeros.

    The per-slice maximum is subtracted before exponentiating (the usual
    overflow guard), and 1e-7 is added to the normaliser. A slice whose
    entries are all ``-inf`` (e.g. every key masked out) has a ``-inf``
    maximum, and ``-inf - (-inf)`` is NaN. Such slices are therefore shifted
    by 0 instead, so ``exp`` yields zeros and their probabilities come out as
    all zeros rather than NaN. Slices with at least one finite entry are
    unaffected.

    Args:
        logits: Input tensor.
        dim: Dimension to normalise over.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    max_logits, _ = logits.max(dim=dim, keepdim=True)
    # Only all -inf slices have a -inf maximum. Shift those by 0 so the
    # subtraction stays NaN-free and exp() of the slice is exactly 0.
    max_logits = max_logits.masked_fill(torch.isneginf(max_logits), 0.0)
    exp_logits = torch.exp(logits - max_logits)
    exp_sum = exp_logits.sum(dim=dim, keepdim=True)
    # The epsilon only matters for the all-zero slices produced above (0 / eps = 0).
    return exp_logits / (exp_sum + 1e-7)
|
|
|
class PositionalAttentionHead(nn.Module):
    """Rotary multi-head self-attention used as the model's pooling stage.

    The last dimension of the input (``hidden_dim``) is split into ``n_heads``
    slices of ``head_dim = hidden_dim // n_heads`` features. Every slice goes
    through the same LayerNorm and the same bias-free ``Q``/``K``/``V``
    projections (``head_dim -> head_dim``, shared by all heads). Rotary
    position embeddings are applied to queries and keys, and scaled
    dot-product attention is computed independently per head.

    Args:
        hidden_dim: Size of the input's last dimension.
        n_heads: Number of heads; must be positive and divide ``hidden_dim``.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``hidden_dim``. The reshape in ``forward`` would otherwise fail
            with an opaque shape error.
    """

    def __init__(self, hidden_dim, n_heads):
        super(PositionalAttentionHead, self).__init__()
        if n_heads <= 0 or hidden_dim % n_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be a positive multiple of "
                f"n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads
        self.preattn_ln = nn.LayerNorm(self.head_dim)
        self.Q = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.K = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.V = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.rot_emb = RotaryEmbedding(self.head_dim)

    def forward(self, x, attention_mask):
        """Self-attend over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_dim)``.
            attention_mask: Tensor broadcastable as ``(batch, seq_len)``.
                Zero/False entries mark key positions that must not be
                attended to. ``None`` attends to every position.

        Returns:
            ``(output, attn_probs)``. ``output`` has shape
            ``(batch, seq_len, hidden_dim)``; ``attn_probs`` has shape
            ``(batch, n_heads, seq_len, seq_len)``, indexed (query, key) in
            the last two axes.
        """
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.n_heads, self.head_dim)
        x = self.preattn_ln(x)

        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        q, k = self.rot_emb(q, k)

        attn_scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): the same key mask applies to
            # every head and query position. Masked keys get -inf, i.e. zero
            # probability after the softmax.
            key_is_masked = torch.logical_not(attention_mask.unsqueeze(1).unsqueeze(1))
            attn_scores = attn_scores.masked_fill(key_is_masked, float("-inf"))

        attn_probs = safeguard_softmax(attn_scores, dim=-1)

        x = torch.einsum('bhqk,bkhd->bqhd', attn_probs, v)
        x = x.reshape(batch_size, seq_len, self.hidden_dim)
        return x, attn_probs
|
|
|
class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warm-up followed by cosine decay to ``eta_ratio * base_lr``.

    For step ``t`` (``self.last_epoch``):

    * ``t < warmup_steps``: ``lr = base_lr * t / warmup_steps``, starting at 0.
    * ``t >= warmup_steps``: ``lr = base_lr * ((1 - eta_ratio) * c + eta_ratio)``,
      where ``c = 0.5 * (1 + cos(pi * p))`` and
      ``p = min((t - warmup_steps) / (total_steps - warmup_steps), 1)``.

    Clamping ``p`` at 1 holds the learning rate at its floor once
    ``total_steps`` is reached. Without the clamp, the cosine would swing back
    up for ``t > total_steps``. If ``total_steps <= warmup_steps`` there is no
    decay phase, and the floor is used right after warm-up instead of
    dividing by zero.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of linear warm-up steps.
        total_steps: Step at which the cosine decay reaches its floor.
        eta_ratio: Final learning rate as a fraction of each base learning rate.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        # Attributes must exist before super().__init__: the base constructor
        # takes an initial step, which calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio
        super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # No decay phase configured: go straight to the floor.
            progress = 1.0
        else:
            progress = min((step - self.warmup_steps) / decay_steps, 1.0)

        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [decayed_lr * base_lr for base_lr in self.base_lrs]
|
|
|
class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Computes ``LayerNorm(GELU(dense(features)))`` and projects the result with
    the externally supplied ``weight`` matrix (``F.linear`` convention, i.e.
    ``(output_dim, embed_dim)``), then adds a learned bias initialised to zeros.

    Args:
        embed_dim: Feature size of the input and of the hidden projection.
        output_dim: Number of output logits (the size of ``bias``).
        weight: Projection matrix. It is stored as ``self.weight`` without
            copying, so an ``nn.Parameter`` passed in is shared with its owner.
            ``MultitaskProteinModel`` passes the ESM word-embedding weight.
    """

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        # Keep this registration order (dense, layer_norm, weight, gelu, bias)
        # so parameter order and the state_dict layout stay stable.
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.weight = weight
        self.gelu = GELU()
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        hidden = self.layer_norm(self.gelu(self.dense(features)))
        logits = F.linear(hidden, self.weight)
        return logits + self.bias
|
|
|
|
|
class MultitaskProteinModel(PreTrainedModel):
    """ESM backbone with an attention-pooled multi-label classifier and an MLM head.

    The ESM model named by ``config.esm_model_name`` is loaded and frozen,
    except for its last ``config.num_layers_to_finetune`` encoder layers. Its
    last hidden state flows through the following stages:

    * ``PositionalAttentionHead``, then ``attn_ln(x + attn_skip(x))``.
    * A stack of residual SiLU linear layers.
    * An attention-weighted, head-averaged combination, then ``clf_ln``.

    That representation feeds two heads:

    * A masked-language head whose projection is tied to the ESM token
      embeddings.
    * After masked mean pooling, a Linear-GELU-Linear classifier that emits
      ``config.num_labels`` logits.

    Per-label decision thresholds live in ``classification_thresholds``.
    """

    config_class = MetaLATTEConfig
    base_model_prefix = "metalatte"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.esm_model = EsmModel.from_pretrained(self.config.esm_model_name)

        # Freeze the whole backbone, then unfreeze the top
        # ``num_layers_to_finetune`` encoder layers.
        for param in self.esm_model.parameters():
            param.requires_grad = False
        for i in range(config.num_layers_to_finetune):
            for param in self.esm_model.encoder.layer[-i - 1].parameters():
                param.requires_grad = True

        # The MLM head projects onto the ESM vocabulary with the tied
        # token-embedding matrix, shaped (vocab_size, embed_dim). Its sizes are
        # read from that matrix so they always match the loaded backbone.
        token_embeddings = self.esm_model.embeddings.word_embeddings.weight
        vocab_size, embed_dim = token_embeddings.shape
        self.lm_head = RobertaLMHead(embed_dim=embed_dim, output_dim=vocab_size, weight=token_embeddings)

        self.attn_head = PositionalAttentionHead(self.config.hidden_size, self.config.num_attention_heads)
        self.attn_ln = nn.LayerNorm(self.config.hidden_size)
        self.attn_skip = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.linear_layers = nn.ModuleList()
        for _ in range(self.config.num_linear_layers):
            self.linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size))
        self.reduction_layers = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_dim),
            GELU(),
            nn.Linear(self.config.hidden_dim, self.config.num_labels),
        )
        self.clf_ln = nn.LayerNorm(self.config.hidden_size)
        self.classification_thresholds = nn.Parameter(torch.tensor([0.5] * self.config.num_labels))

        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model and load task weights from a Hugging Face Hub repo.

        The config is taken from ``kwargs["config"]`` when given, otherwise
        from ``MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)``.
        Weights are downloaded from
        ``https://huggingface.co/<repo>/resolve/main/pytorch_model.bin``. The
        tensors under its ``"state_dict"`` key are loaded with
        ``strict=False``, and a warning is emitted when keys do not line up.
        Other ``model_args``/``kwargs`` are accepted for signature
        compatibility but are ignored.

        Raises:
            RuntimeError: If downloading or loading the checkpoint fails. The
                underlying exception is chained as ``__cause__``.
        """
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        try:
            state_dict_url = urljoin(
                f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/",
                "pytorch_model.bin",
            )
            # SECURITY NOTE(review): load_state_dict_from_url deserialises the
            # downloaded file with torch.load (pickle). Only use this with
            # trusted repositories.
            state_dict = torch.hub.load_state_dict_from_url(
                state_dict_url,
                map_location=torch.device('cpu'),
            )['state_dict']
            load_result = model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            raise RuntimeError(
                f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}"
            ) from e

        # strict=False tolerates missing keys. The backbone (and lm_head.weight,
        # which is the tied ESM embedding) was already populated by
        # EsmModel.from_pretrained in __init__, so only other mismatches are
        # reported. Those would otherwise silently leave task-head weights at
        # their initial values.
        missing = [
            key for key in load_result.missing_keys
            if not key.startswith("esm_model.") and key != "lm_head.weight"
        ]
        unexpected = list(load_result.unexpected_keys)
        if missing or unexpected:
            warnings.warn(
                f"Checkpoint for {pretrained_model_name_or_path} does not match the model: "
                f"{len(missing)} missing keys (e.g. {missing[:5]}), "
                f"{len(unexpected)} unexpected keys (e.g. {unexpected[:5]}).",
                stacklevel=2,
            )

        return model

    def forward(self, input_ids, attention_mask=None):
        """Run the backbone and both heads.

        Args:
            input_ids: Token ids passed to the ESM model.
            attention_mask: Mask used by the ESM model, the attention head and
                the mean pooling; nonzero entries mark real tokens. ``None``
                (the default) is treated as all ones, i.e. no padding.

        Returns:
            ``(x_pred, x_attns, x_combined_masked, mlm_logits)``:

            * ``x_pred``: classifier logits.
            * ``x_attns``: attention probabilities from ``attn_head``.
            * ``x_combined_masked``: the masked mean-pooled representation fed
              to the classifier.
            * ``mlm_logits``: per-token MLM logits.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Only the last hidden state is used, so don't collect every layer's output.
        outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
        )
        embeddings = outputs.last_hidden_state

        x_pool, x_attns = self.attn_head(embeddings, attention_mask)
        x_pool = self.attn_ln(x_pool + self.attn_skip(x_pool))

        # Residual SiLU layers.
        for linear_layer in self.linear_layers:
            x_pool = x_pool + F.silu(linear_layer(x_pool))

        # NOTE(review): this einsum contracts over k, which appears only in
        # x_attns. Each x_pool row is therefore scaled by its attention row-sum
        # (~1 after the softmax) rather than mixed across positions. It is kept
        # as-is to preserve existing checkpoints' behaviour; confirm the
        # intended subscripts.
        x_weighted = torch.einsum('bhlk,bld->bhld', x_attns, x_pool)
        x_combined = self.clf_ln(x_weighted.mean(dim=1))

        mlm_logits = self.lm_head(x_combined)

        # Masked mean over the sequence dimension.
        mask = attention_mask.unsqueeze(-1).float()
        # Clamp so an all-padding row yields a zero vector instead of 0/0 = NaN.
        token_count = mask.sum(dim=1).clamp(min=1e-7)
        x_combined_masked = (x_combined * mask).sum(dim=1) / token_count

        x_pred = self.reduction_layers(x_combined_masked)
        return x_pred, x_attns, x_combined_masked, mlm_logits

    def predict(self, input_ids, attention_mask=None):
        """Return sigmoid probabilities and thresholded multi-label predictions.

        A label is predicted when its probability is ``>=`` its entry in
        ``classification_thresholds``. Rows with no predicted label fall back
        to their most probable label (first on ties), so every sample receives
        at least one label.

        Returns:
            ``(classification_output, predictions)``, both shaped
            ``(batch, num_labels)``; ``predictions`` holds 0.0/1.0 floats.
        """
        x_pred, _, _, _ = self.forward(input_ids, attention_mask)
        classification_output = torch.sigmoid(x_pred)
        predictions = (classification_output >= self.classification_thresholds).float()

        # Vectorised fallback for rows where no label cleared its threshold.
        empty_rows = torch.nonzero(predictions.sum(dim=1) == 0, as_tuple=True)[0]
        if empty_rows.numel() > 0:
            top_labels = classification_output[empty_rows].argmax(dim=1)
            predictions[empty_rows, top_labels] = 1.0

        return classification_output, predictions
|
|
|
|