yinuozhang committed
Commit f2de080
Parent: 4831657

Upload 3 files

Files changed (3)
  1. __init__.py +7 -0
  2. configuration.py +47 -0
  3. model.py +275 -0
__init__.py ADDED
@@ -0,0 +1,7 @@
+ from transformers import AutoConfig, AutoModel
+
+ from .configuration import MetaLATTEConfig
+ from .model import MultitaskProteinModel
+
+ AutoConfig.register("metalatte", MetaLATTEConfig)
+ AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
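These registrations tie the custom classes to the "metalatte" model type so the standard Auto classes can construct them. A minimal sketch of what that enables, assuming the package above has been imported so the registrations have run (the commented checkpoint path is a hypothetical local directory, not part of this commit):

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.for_model("metalatte")   # default MetaLATTEConfig
    model = AutoModel.from_config(config)        # builds a MultitaskProteinModel
                                                 # (downloads the ESM-2 backbone weights)
    # With trained weights saved locally, loading goes through the same Auto API:
    # model = AutoModel.from_pretrained("./metalatte_checkpoint")  # hypothetical path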
configuration.py ADDED
@@ -0,0 +1,47 @@
+ from transformers import PretrainedConfig
+
+ class MetaLATTEConfig(PretrainedConfig):
+     model_type = "metalatte"
+
+     def __init__(
+         self,
+         num_labels=15,
+         hidden_size=1280,
+         num_hidden_layers=33,
+         num_attention_heads=20,
+         intermediate_size=5120,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.0,
+         attention_probs_dropout_prob=0.0,
+         max_position_embeddings=1026,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         esm_model_name="facebook/esm2_t33_650M_UR50D",
+         num_layers_to_finetune=2,
+         num_linear_layers=3,
+         hidden_dim=512,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.num_labels = num_labels
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.esm_model_name = esm_model_name
+         self.num_layers_to_finetune = num_layers_to_finetune
+         self.num_linear_layers = num_linear_layers
+         self.hidden_dim = hidden_dim
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+     def save_pretrained(self, save_directory):
+         super().save_pretrained(save_directory)
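For reference, the configuration round-trips through the usual PretrainedConfig machinery; a minimal sketch, assuming configuration.py is importable from the current path and using a hypothetical local directory:

    from configuration import MetaLATTEConfig  # inside the package: from .configuration import ...

    config = MetaLATTEConfig(num_labels=15, hidden_dim=512)
    config.save_pretrained("./metalatte_config")              # writes config.json
    reloaded = MetaLATTEConfig.from_pretrained("./metalatte_config")
    assert reloaded.esm_model_name == "facebook/esm2_t33_650M_UR50D"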
model.py ADDED
@@ -0,0 +1,275 @@
+ import os
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'  # set before importing torch
+
+ import gc
+ import math
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim.lr_scheduler import _LRScheduler
+ from lightning.pytorch import seed_everything
+ from transformers import EsmModel, PreTrainedModel
+
+ from .configuration import MetaLATTEConfig
+
+ seed_everything(42)
+
+ class GELU(nn.Module):
+     """Implementation of the gelu activation function.
+
+     For information: OpenAI GPT's gelu is slightly different
+     (and gives slightly different results):
+     0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+     """
+     def forward(self, x):
+         return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)  # x: (B, L, H, d); x1, x2: (B, L, H, d // 2)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(x, cos, sin):
+     # x has shape (B, L, H, head_dim);
+     # cos and sin have shape (1, L, head_dim)
+     cos = cos.unsqueeze(2)  # (1, L, 1, head_dim)
+     sin = sin.unsqueeze(2)  # (1, L, 1, head_dim)
+     return (x * cos) + (rotate_half(x) * sin)
+
+
+ class RotaryEmbedding(torch.nn.Module):
+     """
+     The rotary position embeddings from RoFormer_ (Su et al.).
+     A crucial insight from the method is that the queries and keys are
+     transformed by rotation matrices which depend on their relative positions.
+     Other implementations are available in the Rotary Transformer repo_ and in
+     GPT-NeoX_, which was an inspiration.
+     .. _RoFormer: https://arxiv.org/abs/2104.09864
+     .. _repo: https://github.com/ZhuiyiTechnology/roformer
+     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+     .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+     (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
+     """
+
+     def __init__(self, dim: int, *_, **__):
+         super().__init__()
+         # Generate and save the inverse frequency buffer (non-trainable)
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+
+         self._seq_len_cached = None
+         self._cos_cached = None
+         self._sin_cached = None
+
+     def _update_cos_sin_tables(self, x, seq_dimension=1):
+         seq_len = x.shape[seq_dimension]
+
+         # Reset the tables if the sequence length has changed,
+         # or if we're on a new device (possibly due to tracing, for instance)
+         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+             self._seq_len_cached = seq_len
+             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+             freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (L, dim // 2)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)  # (L, dim)
+
+             self._cos_cached = emb.cos()[None, :, :]  # (1, L, dim)
+             self._sin_cached = emb.sin()[None, :, :]  # (1, L, dim)
+
+         return self._cos_cached, self._sin_cached
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
+
+         return (
+             apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),  # (B, L, H, head_dim)
+             apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+         )
+
+
+ def macro_f1(y_true, y_pred, thresholds):
+     y_pred_binary = (y_pred >= thresholds).float()
+     tp = (y_true * y_pred_binary).sum(dim=0)
+     fp = ((1 - y_true) * y_pred_binary).sum(dim=0)
+     fn = (y_true * (1 - y_pred_binary)).sum(dim=0)
+     precision = tp / (tp + fp + 1e-7)
+     recall = tp / (tp + fn + 1e-7)
+     f1 = 2 * precision * recall / (precision + recall + 1e-7)
+     macro_f1 = f1.mean()
+     return macro_f1
+
+ def safeguard_softmax(logits, dim=-1):
+     # Subtract the per-row max so exp() cannot overflow to inf
+     max_logits, _ = logits.max(dim=dim, keepdim=True)
+     exp_logits = torch.exp(logits - max_logits)
+     exp_sum = exp_logits.sum(dim=dim, keepdim=True)
+     probs = exp_logits / (exp_sum + 1e-7)  # small epsilon prevents division by zero
+     return probs
+
+ class PositionalAttentionHead(nn.Module):
+     def __init__(self, hidden_dim, n_heads):
+         super(PositionalAttentionHead, self).__init__()
+         self.n_heads = n_heads
+         self.hidden_dim = hidden_dim
+         self.head_dim = hidden_dim // n_heads
+         self.preattn_ln = nn.LayerNorm(self.head_dim)
+         self.Q = nn.Linear(self.head_dim, self.head_dim, bias=False)
+         self.K = nn.Linear(self.head_dim, self.head_dim, bias=False)
+         self.V = nn.Linear(self.head_dim, self.head_dim, bias=False)
+         self.rot_emb = RotaryEmbedding(self.head_dim)
+
+     def forward(self, x, attention_mask):
+         batch_size, seq_len, _ = x.size()  # x: (B, L, hidden_dim)
+         x = x.view(batch_size, seq_len, self.n_heads, self.head_dim)
+         x = self.preattn_ln(x)
+
+         q = self.Q(x)
+         k = self.K(x)
+         v = self.V(x)
+
+         q, k = self.rot_emb(q, k)
+         gc.collect()
+         torch.cuda.empty_cache()
+
+         attn_scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(self.head_dim)
+         # attention_mask: (B, L); broadcast to (B, 1, 1, L) so padded key positions
+         # are masked out for every head and every query position
+         attn_scores = attn_scores.masked_fill(torch.logical_not(attention_mask.unsqueeze(1).unsqueeze(1)), float("-inf"))  # (B, H, L, L)
+
+         attn_probs = safeguard_softmax(attn_scores, dim=-1)
+
+         x = torch.einsum('bhqk,bkhd->bqhd', attn_probs, v)
+         x = x.reshape(batch_size, seq_len, self.hidden_dim)  # (B, L, hidden_dim)
+         gc.collect()
+         torch.cuda.empty_cache()
+         return x, attn_probs
+
+ class CosineAnnealingWithWarmup(_LRScheduler):
+     # Linear warmup followed by cosine decay, based on the schedule
+     # described in the LLaMA paper: https://arxiv.org/abs/2302.13971
+     def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.eta_ratio = eta_ratio  # ratio of minimum to maximum learning rate
+         super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
+
+         progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+         cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
+         decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
+
+         return [decayed_lr * base_lr for base_lr in self.base_lrs]
+
+ class RobertaLMHead(nn.Module):
+     """Head for masked language modeling."""
+     def __init__(self, embed_dim, output_dim, weight):
+         super().__init__()
+         self.dense = nn.Linear(embed_dim, embed_dim)
+         self.layer_norm = nn.LayerNorm(embed_dim)
+         self.weight = weight
+         self.gelu = GELU()
+         self.bias = nn.Parameter(torch.zeros(output_dim))
+     def forward(self, features):
+         x = self.dense(features)
+         x = self.gelu(x)
+         x = self.layer_norm(x)
+         # project back to the size of the vocabulary, with a bias
+         x = F.linear(x, self.weight) + self.bias
+         return x
+
+
+ class MultitaskProteinModel(PreTrainedModel):
+     config_class = MetaLATTEConfig
+     base_model_prefix = "metalatte"
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.esm_model = EsmModel.from_pretrained(self.config.esm_model_name)
+         # Layer freezing for the pretrained ESM backbone:
+         # first freeze everything,
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+         # then unfreeze the last num_layers_to_finetune encoder layers
+         for i in range(config.num_layers_to_finetune):
+             for param in self.esm_model.encoder.layer[-i-1].parameters():
+                 param.requires_grad = True
+         self.lm_head = RobertaLMHead(embed_dim=1280, output_dim=33, weight=self.esm_model.embeddings.word_embeddings.weight)
+         # ESM-2 650M has embedding dim 1280 and a vocabulary of 33 tokens
+         self.attn_head = PositionalAttentionHead(self.config.hidden_size, self.config.num_attention_heads)
+         self.attn_ln = nn.LayerNorm(self.config.hidden_size)
+         self.attn_skip = nn.Linear(self.config.hidden_size, self.config.hidden_size)
+         self.linear_layers = nn.ModuleList()
+         # Add linear layers after the attention head
+         for _ in range(self.config.num_linear_layers):
+             self.linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size))
+         self.reduction_layers = nn.Sequential(
+             nn.Linear(self.config.hidden_size, self.config.hidden_dim),
+             GELU(),
+             nn.Linear(self.config.hidden_dim, self.config.num_labels)
+         )
+         self.clf_ln = nn.LayerNorm(self.config.hidden_size)
+         self.classification_thresholds = nn.Parameter(torch.tensor([0.5] * self.config.num_labels))
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         config = kwargs.pop("config", None)
+         if config is None:
+             config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
+
+         model = cls(config)
+         state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
+         model.load_state_dict(state_dict, strict=False)
+         return model
+
+     def forward(self, input_ids, attention_mask=None):
+         outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
+         embeddings = outputs.last_hidden_state
+         attention_masks = attention_mask
+
+         x_pool, x_attns = self.attn_head(embeddings, attention_masks)
+         x_pool = self.attn_ln(x_pool + self.attn_skip(x_pool))  # skip connection around the attention output
+
+         for linear_layer in self.linear_layers:
+             residue = x_pool
+             x_pool = linear_layer(x_pool)  # hidden_size -> hidden_size
+             x_pool = F.silu(x_pool)
+             x_pool = x_pool + residue  # skip connection
+
+         x_weighted = torch.einsum('bhlk,bld->bhld', x_attns, x_pool)  # (B, H, L, hidden_size)
+         x_combined = x_weighted.mean(dim=1)  # average over heads: (B, L, hidden_size)
+         x_combined = self.clf_ln(x_combined)
+
+         mlm_logits = self.lm_head(x_combined)
+         attention_masks = attention_masks.unsqueeze(-1).float()  # (B, L, 1)
+         attention_sum = attention_masks.sum(dim=1, keepdim=True)  # (B, 1, 1)
+         x_combined_masked = (x_combined * attention_masks).sum(dim=1) / attention_sum.squeeze(1)  # (B, hidden_size)
+
+         # Compute classification logits
+         x_pred = self.reduction_layers(x_combined_masked)
+         gc.collect()
+         torch.cuda.empty_cache()
+         return x_pred, x_attns, x_combined_masked, mlm_logits
+
+     def predict(self, input_ids, attention_mask=None):
+         x_pred, _, _, _ = self.forward(input_ids, attention_mask)
+         classification_output = torch.sigmoid(x_pred)
+         predictions = (classification_output >= self.classification_thresholds).float()
+
+         # Guarantee at least one positive label per sequence: if nothing crosses
+         # its threshold, keep the single most probable class
+         for i, pred in enumerate(predictions):
+             if pred.sum() == 0:
+                 weighted_probs = classification_output[i]
+                 max_class = torch.argmax(weighted_probs)
+                 predictions[i, max_class] = 1.0
+
+         return classification_output, predictions
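Taken together, the three files are meant to be used as a package: the config builds the model around a partially frozen ESM-2 backbone, the ESM tokenizer supplies input_ids and attention_mask, and predict returns per-label sigmoid probabilities plus thresholded multi-label predictions. A minimal usage sketch, assuming the files live in a package hypothetically named metalatte and that trained weights sit in a hypothetical local directory containing config.json and a Lightning-style pytorch_model.bin, as from_pretrained above expects:

    import torch
    from transformers import AutoTokenizer

    from metalatte import MultitaskProteinModel  # hypothetical package name for this repo

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = MultitaskProteinModel.from_pretrained("./metalatte_checkpoint")  # hypothetical path
    model.eval()

    sequence = "MKTAYIAKQRQISFVKSHFSRQ"  # toy protein sequence
    inputs = tokenizer(sequence, return_tensors="pt")

    with torch.no_grad():
        probs, preds = model.predict(inputs["input_ids"], inputs["attention_mask"])

    # probs: (1, num_labels) sigmoid probabilities
    # preds: (1, num_labels) 0/1 labels, with at least one label set per sequence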