Niksa Praljak committed on
Commit
0655b48
1 Parent(s): 14fddb7

BioM3-PenCL push with no weights

Stage1_source/PL_wrapper.py ADDED
@@ -0,0 +1,1613 @@
1
+ # PyTorch functions
2
+ import torch
3
+ from torch import nn, optim
4
+ from torch.nn import functional as F
5
+ import torch.distributed as dist
6
+
7
+ # PL functions
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning import Trainer, seed_everything
10
+
11
+ # misc functions
12
+ import itertools
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import sys
16
+ from tqdm import tqdm
17
+ import time
18
+
19
+ # other learning packages
20
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
21
+
22
+ # our packages
23
+ import Stage1_source.helper_funcs as helper_tools
24
+ import Stage1_source.preprocess as prep
25
+ import Stage1_source.model as mod
26
+
27
+
28
+ ######################
29
+ # Default PL wrapper #
30
+ ######################
31
+
32
+ class PL_PEN_CL(pl.LightningModule):
33
+
34
+
35
+ def __init__(
36
+ self,
37
+ args: any,
38
+ model: nn.Module,
39
+ text_tokenizer: any,
40
+ sequence_tokenizer: any
41
+ ):
42
+
43
+ super().__init__()
44
+ # arguments
45
+ self.script_args = args
46
+
47
+ # model components
48
+ self.model = model
49
+
50
+ # tokenizers
51
+ self.text_tokenizer = text_tokenizer
52
+ self.sequence_tokenizer = sequence_tokenizer
53
+
54
+ # validation tracker for outputs
55
+ self.val_text_joint_latents = []
56
+ self.val_seq_joint_latents = []
57
+
58
+ # prediction tracker for outputs
59
+ self.predict_text_joint_latents = []
60
+ self.predict_seq_joint_latents = []
61
+
62
+ def forward(
63
+ self,
64
+ x_t: torch.Tensor,
65
+ x_s: torch.Tensor
66
+ ) -> (
67
+ torch.Tensor,
+ torch.Tensor
70
+ ):
71
+
72
+ outputs = self.model(
73
+ x_t=x_t,
74
+ x_s=x_s
75
+ )
76
+
77
+ return (
78
+ outputs['text_joint_latent'],
79
+ outputs['seq_joint_latent'],
80
+ )
81
+
82
+ def training_step(
83
+ self,
84
+ batch: torch.Tensor,
85
+ batch_idx: any,
86
+ ) -> dict:
87
+
88
+ if isinstance(batch, list):
89
+ # split the batch into text and protein parts
90
+ text_batch, protein_batch = batch
91
+
92
+ # forward pass
93
+ z_t, z_s = self(
94
+ x_t=text_batch,
95
+ x_s=protein_batch
96
+ )
97
+ dist.barrier()
98
+
99
+ # gather all tensors
100
+ z_t_all = self.all_gather(z_t, sync_grads=True)
101
+ dist.barrier()
102
+ z_s_all = self.all_gather(z_s, sync_grads=True)
103
+
104
+ # stack the embeddings
105
+ z_t_all = z_t_all.view(-1, z_t.shape[-1])
106
+ z_s_all = z_s_all.view(-1, z_s.shape[-1])
107
+
108
+ # compute loss values
109
+ loss, logits = self.model.compute_loss(
110
+ protein_embeddings=z_s_all,
111
+ text_embeddings=z_t_all
112
+ )
113
+
114
+ # track loss ...
115
+ self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
116
+
117
+ # track metrics
118
+ metric_dict = self.performance_metrics(logits=logits)
119
+ for key in metric_dict:
120
+ values = metric_dict[key]
121
+
122
+ final_key = 'train_' + key
123
+ self.log(final_key, metric_dict[key], prog_bar=True if 'f1' in key else False, on_step=True, on_epoch=True, sync_dist=True)
124
+
125
+ if batch_idx == 0:
126
+ gpu_memory_usage = helper_tools.print_gpu_initialization()
127
+ self.log(f'gpu_memory_usage', gpu_memory_usage, sync_dist=True)
128
+
129
+ return {'loss': loss}
130
+
131
+
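The alignment objective itself lives in Stage1_source/model.py (compute_loss); as orientation only, the sketch below shows the standard CLIP-style symmetric contrastive loss that this kind of text-protein setup typically uses. The function name and temperature value are illustrative assumptions, not the repository's implementation.

import torch
import torch.nn.functional as F

def clip_style_loss(protein_embeddings: torch.Tensor,
                    text_embeddings: torch.Tensor,
                    temperature: float = 0.07):
    # Normalize both modalities, score all pairs, and apply symmetric cross-entropy
    # with the diagonal (i-th caption with i-th protein) as the positive match.
    z_p = F.normalize(protein_embeddings, dim=-1)
    z_t = F.normalize(text_embeddings, dim=-1)
    logits = z_t @ z_p.T / temperature                      # (N, N) similarity matrix
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels))
    return loss, logits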
132
+ def validation_step(
133
+ self,
134
+ batch: list,
135
+ batch_idx: any
136
+ ) -> dict:
137
+
138
+ # split the batch
139
+ if isinstance(batch, list):
140
+ # unpack the text and protein batches
141
+ text_batch, protein_batch = batch
142
+
143
+ # forward pass
144
+ z_t, z_s = self(
145
+ x_t=text_batch,
146
+ x_s=protein_batch
147
+ )
148
+
149
+ dist.barrier()
150
+ # gather all tensors
151
+ z_t_all = self.all_gather(z_t, sync_grads=True).view(-1, z_t.shape[-1])
152
+ dist.barrier()
153
+ z_s_all = self.all_gather(z_s, sync_grads=True).view(-1, z_s.shape[-1])
154
+
155
+ # stack the embeddings
156
+ z_t_all = z_t_all.view(-1, z_t.shape[-1])
157
+ z_s_all = z_s_all.view(-1, z_s.shape[-1])
158
+
159
+ # compute loss values
160
+ loss, logits = self.model.compute_loss(
161
+ protein_embeddings=z_s_all,
162
+ text_embeddings=z_t_all
163
+ )
164
+
165
+
166
+ # track validation loss ...
167
+ self.log('valid_loss', loss, prog_bar=True, sync_dist=True)
168
+
169
+ # compute validation metrics
170
+ metric_dict = self.performance_metrics(logits=logits.detach().cpu())
171
+
172
+ for key in metric_dict:
173
+ values = metric_dict[key]
174
+ final_key = 'valid_' + key
175
+ self.log(final_key, metric_dict[key], prog_bar=True if 'f1' in key else False, sync_dist=True)
176
+
177
+ # collect joint embedding
178
+ self.val_text_joint_latents.append(z_t_all.detach().cpu())
179
+ self.val_seq_joint_latents.append(z_s_all.detach().cpu())
180
+
181
+ return {'valid_loss': loss}
182
+
183
+ def on_validation_epoch_end(self):
184
+
185
+ # collect and aggregate outputs from all validation steps
186
+ val_z_t_joint = torch.cat(self.val_text_joint_latents, dim=0)
187
+ val_z_s_joint = torch.cat(self.val_seq_joint_latents, dim=0)
188
+
189
+ # compute singular values
190
+ text_log_sigma_k, S_text = self.compute_singular(val_z_t_joint.detach().cpu())
191
+ protein_log_sigma_k, S_protein = self.compute_singular(val_z_s_joint.detach().cpu())
192
+
193
+ # save image pngs for tracking dimensionality collapse
194
+ self.save_png_to_tensorboard(
195
+ data=text_log_sigma_k.numpy(),
196
+ title='text',
197
+ )
198
+ self.save_png_to_tensorboard(
199
+ data=protein_log_sigma_k.numpy(),
200
+ title='protein'
201
+ )
202
+
203
+ # free memory
204
+ self.val_text_joint_latents.clear()
205
+ self.val_seq_joint_latents.clear()
206
+
207
+
208
+ # compute effective rank (RankME):
209
+ erank_text = self.compute_effective_rank(sigma_ks=S_text)
210
+ erank_protein = self.compute_effective_rank(sigma_ks=S_protein)
211
+
212
+ # log erank metrics
213
+ self.log('valid_erank_text', erank_text, sync_dist=True)
214
+ self.log('valid_erank_protein', erank_protein, sync_dist=True)
215
+
216
+
217
+ def configure_optimizers(self,):
218
+
219
+ params = [
220
+ {"params": self.model.protein_encoder.parameters(), "lr": self.script_args.protein_encoder_lr},
221
+ {"params": self.model.text_encoder.parameters(), "lr": self.script_args.text_encoder_lr},
222
+ {"params": itertools.chain(
223
+ self.model.protein_projection.parameters(),
224
+ self.model.text_projection.parameters()
225
+ ),
226
+ "lr": self.script_args.head_lr,
227
+ "weight_decay": self.script_args.weight_decay}
228
+ ]
229
+
230
+ optimizer = torch.optim.AdamW(params, weight_decay=self.script_args.weight_decay)
231
+
232
+ return {
233
+ "optimizer": optimizer,
234
+ }
235
+
236
+ @torch.no_grad()
237
+ def compute_class_metrics(
238
+ self,
239
+ outputs: torch.Tensor,
240
+ targets: torch.Tensor,
241
+ source: str
242
+ ) -> dict:
243
+
244
+ # convert torch tensors to numpy array
245
+ outputs_np = outputs.numpy()
246
+ targets_np = targets.numpy()
247
+
248
+ # compute the metrics
249
+ accuracy = accuracy_score(targets_np, outputs_np.round())
250
+ precision = precision_score(targets_np, outputs_np.round(), average='micro')
251
+ recall = recall_score(targets_np, outputs_np.round(), average='micro')
252
+ f1 = f1_score(targets_np, outputs_np.round(), average='micro')
253
+
254
+ return {
255
+ f'{source}_accuracy': accuracy,
256
+ f'{source}_precision': precision,
257
+ f'{source}_recall': recall,
258
+ f'{source}_f1': f1
259
+ }
260
+
261
+ @torch.no_grad()
262
+ def performance_metrics(self, logits: torch.Tensor) -> tuple:
263
+
264
+ logits = logits.cpu().float()
265
+
266
+ # get probs
267
+ p_text = F.softmax(logits, dim=-1) # prob of a given text captions aligning well with seq. pairs
268
+ p_seq = F.softmax(logits.T, dim=-1) # prob of a given seq aligning well with text pairs
269
+ p_tot = (p_seq + p_text) / 2 # total prob
270
+
271
+ # get class labels
272
+ y_pred_text = torch.argmax(p_text, dim=-1)
273
+ y_pred_seq = torch.argmax(p_seq, dim=-1)
274
+ y_pred = torch.argmax(p_tot, dim=-1)
275
+ y_true = torch.arange(y_pred_text.shape[0])
276
+
277
+ # compute class metrics
278
+ text_metrics = self.compute_class_metrics(
279
+ outputs=y_pred_text,
280
+ targets=y_true,
281
+ source='text'
282
+ )
283
+ seq_metrics = self.compute_class_metrics(
284
+ outputs=y_pred_seq,
285
+ targets=y_true,
286
+ source='seq'
287
+ )
288
+ total_metrics = self.compute_class_metrics(
289
+ outputs=y_pred,
290
+ targets=y_true,
291
+ source='total'
292
+ )
293
+
294
+ # combine dicts into one
295
+ combined_dict = {}
296
+ combined_dict.update(text_metrics)
297
+ combined_dict.update(seq_metrics)
298
+ combined_dict.update(total_metrics)
299
+
300
+ return combined_dict
301
+
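To make the metric computation above concrete: for a batch of N aligned pairs the similarity matrix is N x N and the correct match for row i is column i, so torch.arange(N) serves as the ground-truth label vector. A tiny self-contained example with dummy logits (illustrative only):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 4)                 # dummy 4-pair text-vs-protein similarity matrix
p_text = F.softmax(logits, dim=-1)         # text -> protein retrieval probabilities
y_pred = torch.argmax(p_text, dim=-1)      # predicted protein index for each caption
y_true = torch.arange(4)                   # pair i belongs with pair i
accuracy = (y_pred == y_true).float().mean()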
302
+ @torch.no_grad()
303
+ def compute_singular(self, inputs: torch.Tensor) -> (
304
+ torch.Tensor,
305
+ torch.Tensor
306
+ ):
307
+
308
+ # goal of this function: track for dimensionality collapse
309
+ # inputs dim: (batch_size, emb_dim)
310
+
311
+ mean_inputs = torch.mean(inputs, dim=0) # average over batch dimension
312
+ norm_inputs = inputs - mean_inputs # normalize vectors
313
+
314
+ # compute correlation matrix #TODO: double check work...
315
+ C = torch.zeros((norm_inputs.shape[-1], norm_inputs.shape[-1]))
316
+ for sample_idx in range(norm_inputs.shape[0]):
317
+ norm_vector = norm_inputs[sample_idx, :].unsqueeze(0)
318
+ C += norm_vector.T @ norm_vector
319
+ C *= 1/norm_inputs.shape[0] # average over the batch dimension
320
+
321
+ _, S, _ = torch.linalg.svd(C, full_matrices=False)
322
+
323
+ # return singular value indexes
324
+ log_sigma_k, _ = torch.sort(torch.log(S), descending=True)
325
+ return (
326
+ log_sigma_k,
327
+ S
328
+ )
329
+
330
+ def compute_effective_rank(self, sigma_ks: torch.Tensor) -> torch.Tensor:
331
+ """
332
+ references:
333
+ - Roy et al. The effective rank: a measure of effective dimensionality
334
+ - Garrido et al. RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank.
335
+ """
336
+ # sort the singular values
337
+ sigma_ks, _ = torch.sort(sigma_ks, descending=True)
338
+
339
+ # compute the L1 norm of the singular values
340
+ l1_norm_sigma = torch.norm(sigma_ks, p=1)
341
+
342
+ # compute singular value distribution
343
+ p_k = sigma_ks / l1_norm_sigma + torch.finfo(torch.float).eps
344
+
345
+ # compute Shannon entropy
346
+ entropy = - torch.sum(p_k * torch.log(p_k))
347
+
348
+ # get effective rank (RankME):
349
+ erank = torch.exp(entropy)
350
+
351
+ return erank
352
+
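Written out, the quantity returned by compute_effective_rank is the effective-rank / RankMe estimator of the singular-value spectrum:

p_k = \frac{\sigma_k}{\lVert \sigma \rVert_1} + \epsilon, \qquad
\operatorname{erank}(\sigma) = \exp\!\Big(-\sum_k p_k \log p_k\Big)

so a flat spectrum of d equal singular values gives an effective rank of d, while a collapsed embedding with one dominant direction drives it toward 1.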
353
+ def save_png_to_tensorboard(
354
+ self,
355
+ data: np.single,
356
+ title: str,
357
+ x_axis_label: str='Singular Value Rank Index',
358
+ y_axis_label: str='Log of singular values',
359
+ ):
360
+
361
+ current_epoch = self.trainer.current_epoch
362
+
363
+ # Plot the line
364
+ fig, ax = plt.subplots(dpi=300)
365
+ ax.plot(data)
366
+ ax.set_xlabel(x_axis_label)
367
+ ax.set_ylabel(y_axis_label)
368
+ ax.set_title(title)
369
+ ax.set_ylim([-25,3])
370
+
371
+ # Log the plot in TensorBoard
372
+ self.logger.experiment.add_figure(f'{title}_SingularValues_{current_epoch}', fig, current_epoch)
373
+
374
+ def predict_step(
375
+ self,
376
+ batch: torch.Tensor,
377
+ batch_idx: int,
+ dataloader_idx: int = 0
379
+ ) -> (
380
+ torch.Tensor,
381
+ torch.Tensor
382
+ ):
383
+
384
+
385
+ if isinstance(batch, list):
386
+ # unpack the text and protein batches
387
+ text_batch, protein_batch = batch
388
+ outputs = self(
389
+ x_t=text_batch,
390
+ x_s=protein_batch,
391
+ )
392
+
393
+ z_t_joint, z_p_joint = outputs
394
+
395
+ self.predict_text_joint_latents.append(z_t_joint.detach().cpu())
396
+ self.predict_seq_joint_latents.append(z_p_joint.detach().cpu())
397
+
398
+ return outputs
399
+
400
+ def on_predict_epoch_end(self, outputs=None):
401
+
402
+ self.predict_text_joint_latents = torch.cat(self.predict_text_joint_latents).cpu()
403
+ self.predict_seq_joint_latents = torch.cat(self.predict_seq_joint_latents).cpu()
404
+
405
+
406
+
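For orientation, a minimal sketch of how this wrapper is typically driven with a Lightning Trainer; the config object, model instance, tokenizers, and dataloaders below are placeholders rather than names defined in this commit:

import pytorch_lightning as pl

# hypothetical objects: config, pen_cl_model, text_tokenizer, sequence_tokenizer,
# train_loader and valid_loader are assumed to be built elsewhere
# (e.g. from Stage1_source.model and Stage1_source.preprocess).
pl_module = PL_PEN_CL(
    args=config,
    model=pen_cl_model,
    text_tokenizer=text_tokenizer,
    sequence_tokenizer=sequence_tokenizer,
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",   # the training/validation steps assume a distributed run (dist.barrier, all_gather)
    max_epochs=10,
)
trainer.fit(pl_module, train_dataloaders=train_loader, val_dataloaders=valid_loader)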
407
+ ##########################
408
+ # Masked-task PL wrapper #
409
+ ##########################
410
+
411
+ class mask_PL_PEN_CL(pl.LightningModule):
412
+
413
+
414
+ def __init__(
415
+ self,
416
+ args: any,
417
+ model: nn.Module,
418
+ text_tokenizer: any,
419
+ sequence_tokenizer: any
420
+ ):
421
+
422
+ super().__init__()
423
+ # arguments
424
+ self.script_args = args
425
+
426
+ # model components
427
+ self.model = model
428
+
429
+ # tokenizers
430
+ self.text_tokenizer = text_tokenizer
431
+ self.sequence_tokenizer = sequence_tokenizer
432
+
433
+ # validation tracker for outputs
434
+ self.val_text_joint_latents = []
435
+ self.val_seq_joint_latents = []
436
+
437
+ # prediction tracker for outputs
438
+ self.predict_text_joint_latents = []
439
+ self.predict_seq_joint_latents = []
440
+
441
+ def forward(
442
+ self,
443
+ x_t: torch.Tensor,
444
+ x_s: torch.Tensor,
445
+ compute_masked_logits: bool=False
446
+ ) -> (
447
+ torch.Tensor,
+ torch.Tensor
450
+ ):
451
+
452
+ outputs = self.model(
453
+ x_t=x_t,
454
+ x_s=x_s,
455
+ compute_masked_logits=compute_masked_logits
456
+ )
457
+
458
+ if compute_masked_logits:
459
+ # forward pass for computing logits for masked language objective
460
+ return (
461
+ outputs['text_masked_logits'],
462
+ outputs['protein_masked_logits']
463
+ )
464
+ else:
465
+ # forward pass for computing latent embeddings in the joint space
466
+ return (
467
+ outputs['text_joint_latent'],
468
+ outputs['seq_joint_latent'],
469
+ )
470
+
471
+ def training_step(
472
+ self,
473
+ batch: torch.Tensor,
474
+ batch_idx: any,
475
+ ) -> dict:
476
+
477
+ if isinstance(batch, list):
478
+ # split the data
479
+ text_batch, protein_batch, text_mask_batch, protein_mask_batch = batch
480
+
481
+ # forward pass
482
+ z_t, z_s = self(
483
+ x_t=text_batch,
484
+ x_s=protein_batch,
485
+ compute_masked_logits=False
486
+ )
487
+ dist.barrier()
488
+
489
+ # gather all tensors
490
+ z_t_all = self.all_gather(z_t, sync_grads=True)
491
+ dist.barrier()
492
+ z_s_all = self.all_gather(z_s, sync_grads=True)
493
+
494
+ # stack the embeddings
495
+ z_t_all = z_t_all.view(-1, z_t.shape[-1])
496
+ z_s_all = z_s_all.view(-1, z_s.shape[-1])
497
+
498
+ # compute loss values
499
+ loss_align, logits = self.model.compute_loss(
500
+ protein_embeddings=z_s_all,
501
+ text_embeddings=z_t_all
502
+ )
503
+
504
+ # compute mask language model logits
505
+ logits_t_mask, logits_s_mask = self(
506
+ x_t=text_mask_batch,
507
+ x_s=protein_mask_batch,
508
+ compute_masked_logits=True
509
+ )
510
+
511
+ # compute mask language loss for biomedical expert model
512
+ loss_text_mask = self.model.compute_masked_lang_loss(
513
+ logits_masked=logits_t_mask,
514
+ targets=text_batch,
515
+ targets_masked=text_mask_batch,
516
+ mask_token_id=self.text_tokenizer.mask_token_id
517
+ )
518
+
519
+ # compute mask language loss for protein expert model
520
+ loss_sequence_mask = self.model.compute_masked_lang_loss(
521
+ logits_masked=logits_s_mask,
522
+ targets=protein_batch,
523
+ targets_masked=protein_mask_batch,
524
+ mask_token_id=self.sequence_tokenizer.mask_idx
525
+ )
526
+
527
+
528
+ # total loss
529
+ loss = loss_align + loss_text_mask + loss_sequence_mask
530
+
531
+ # track loss ...
532
+ self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
533
+ self.log('train_loss_align', loss_align, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
534
+ self.log('train_loss_text_mask', loss_text_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
535
+ self.log('train_loss_seq_mask', loss_sequence_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
536
+
537
+ # track metrics
538
+ metric_dict = self.performance_metrics(logits=logits)
539
+ for key in metric_dict:
540
+ values = metric_dict[key]
541
+
542
+ final_key = 'train_' + key
543
+ self.log(final_key, metric_dict[key], prog_bar=True if 'f1' in key else False, on_step=True, on_epoch=True, sync_dist=True)
544
+
545
+ if batch_idx == 0:
546
+ gpu_memory_usage = helper_tools.print_gpu_initialization()
547
+ self.log(f'gpu_memory_usage', gpu_memory_usage, sync_dist=True)
548
+
549
+ return {'loss': loss}
550
+
551
+
552
+ def validation_step(
553
+ self,
554
+ batch: list,
555
+ batch_idx: any
556
+ ) -> dict:
557
+
558
+ # split the batch
559
+ if isinstance(batch, list):
560
+ # unpack the text/protein batches and their masked versions
561
+ text_batch, protein_batch, text_mask_batch, protein_mask_batch = batch
562
+
563
+ # forward pass
564
+ z_t, z_s = self(
565
+ x_t=text_batch,
566
+ x_s=protein_batch
567
+ )
568
+
569
+ dist.barrier()
570
+ # gather all tensors
571
+ z_t_all = self.all_gather(z_t, sync_grads=True).view(-1, z_t.shape[-1])
572
+ dist.barrier()
573
+ z_s_all = self.all_gather(z_s, sync_grads=True).view(-1, z_s.shape[-1])
574
+
575
+ # stack the embeddings
576
+ z_t_all = z_t_all.view(-1, z_t.shape[-1])
577
+ z_s_all = z_s_all.view(-1, z_s.shape[-1])
578
+
579
+ # compute loss values
580
+ loss_align, logits = self.model.compute_loss(
581
+ protein_embeddings=z_s_all,
582
+ text_embeddings=z_t_all
583
+ )
584
+
585
+ # compute mask language model logits
586
+ logits_t_mask, logits_s_mask = self(
587
+ x_t=text_mask_batch,
588
+ x_s=protein_mask_batch,
589
+ compute_masked_logits=True
590
+ )
591
+
592
+ # compute mask language loss for biomedical expert model
593
+ loss_text_mask = self.model.compute_masked_lang_loss(
594
+ logits_masked=logits_t_mask,
595
+ targets=text_batch,
596
+ targets_masked=text_mask_batch,
597
+ mask_token_id=self.text_tokenizer.mask_token_id
598
+ )
599
+
600
+ # compute mask language loss for protein expert model
601
+ loss_sequence_mask = self.model.compute_masked_lang_loss(
602
+ logits_masked=logits_s_mask,
603
+ targets=protein_batch,
604
+ targets_masked=protein_mask_batch,
605
+ mask_token_id=self.sequence_tokenizer.mask_idx
606
+ )
607
+
608
+ # total loss
609
+ loss = loss_align + loss_text_mask + loss_sequence_mask
610
+
611
+ # track validation loss ...
612
+ self.log('valid_loss', loss, prog_bar=True, sync_dist=True)
613
+ self.log('valid_loss_align', loss_align, prog_bar=True, sync_dist=True)
614
+ self.log('valid_loss_text_mask', loss_text_mask, prog_bar=False, sync_dist=True)
615
+ self.log('valid_loss_seq_mask', loss_sequence_mask, prog_bar=False, sync_dist=True)
616
+
617
+ # compute validation metrics
618
+ metric_dict = self.performance_metrics(logits=logits.detach().cpu())
619
+
620
+ for key in metric_dict:
621
+ values = metric_dict[key]
622
+ final_key = 'valid_' + key
623
+ self.log(final_key, metric_dict[key], prog_bar=True if 'f1' in key else False, sync_dist=True)
624
+
625
+ # collect joint embedding
626
+ self.val_text_joint_latents.append(z_t_all.detach().cpu())
627
+ self.val_seq_joint_latents.append(z_s_all.detach().cpu())
628
+
629
+ return {'valid_loss': loss}
630
+
631
+ def on_validation_epoch_end(self):
632
+
633
+ # # collect and aggregate outputs from all validation steps
634
+ # val_z_t_joint = torch.cat(self.val_text_joint_latents, dim=0)
635
+ # val_z_s_joint = torch.cat(self.val_seq_joint_latents, dim=0)
636
+
637
+ # compute singular values
638
+ # text_log_sigma_k, S_text = self.compute_singular(val_z_t_joint.detach().cpu())
639
+ # protein_log_sigma_k, S_protein = self.compute_singular(val_z_s_joint.detach().cpu())
640
+
641
+ # save image pngs for tracking dimensionality collapse
642
+ # self.save_png_to_tensorboard(
643
+ # data=text_log_sigma_k.numpy(),
644
+ # title='text',
645
+ # )
646
+ # self.save_png_to_tensorboard(
647
+ # data=protein_log_sigma_k.numpy(),
648
+ # title='protein'
649
+ # )
650
+
651
+ # free memory
652
+ self.val_text_joint_latents.clear()
653
+ self.val_seq_joint_latents.clear()
654
+
655
+
656
+ # compute effective rank (RankME):
657
+ # erank_text = self.compute_effective_rank(sigma_ks=S_text)
658
+ # erank_protein = self.compute_effective_rank(sigma_ks=S_protein)
659
+
660
+ # log erank metrics
661
+ # self.log('valid_erank_text', erank_text, sync_dist=True)
662
+ # self.log('valid_erank_protein', erank_protein, sync_dist=True)
663
+
664
+
665
+ def configure_optimizers(self,):
666
+
667
+ params = [
668
+ {"params": self.model.protein_encoder.parameters(), "lr": self.script_args.protein_encoder_lr},
669
+ {"params": self.model.text_encoder.parameters(), "lr": self.script_args.text_encoder_lr},
670
+ {"params": itertools.chain(
671
+ self.model.protein_projection.parameters(),
672
+ self.model.text_projection.parameters()
673
+ ),
674
+ "lr": self.script_args.head_lr,
675
+ "weight_decay": self.script_args.weight_decay}
676
+ ]
677
+
678
+ optimizer = torch.optim.AdamW(params, weight_decay=self.script_args.weight_decay)
679
+
680
+ return {
681
+ "optimizer": optimizer,
682
+ }
683
+
684
+ @torch.no_grad()
685
+ def compute_class_metrics(
686
+ self,
687
+ outputs: torch.Tensor,
688
+ targets: torch.Tensor,
689
+ source: str
690
+ ) -> dict:
691
+
692
+ # convert torch tensors to numpy array
693
+ outputs_np = outputs.numpy()
694
+ targets_np = targets.numpy()
695
+
696
+ # compute the metrics
697
+ accuracy = accuracy_score(targets_np, outputs_np.round())
698
+ precision = precision_score(targets_np, outputs_np.round(), average='micro')
699
+ recall = recall_score(targets_np, outputs_np.round(), average='micro')
700
+ f1 = f1_score(targets_np, outputs_np.round(), average='micro')
701
+
702
+ return {
703
+ f'{source}_accuracy': accuracy,
704
+ f'{source}_precision': precision,
705
+ f'{source}_recall': recall,
706
+ f'{source}_f1': f1
707
+ }
708
+
709
+ @torch.no_grad()
710
+ def performance_metrics(self, logits: torch.Tensor) -> tuple:
711
+
712
+ logits = logits.cpu().float()
713
+
714
+ # get probs
715
+ p_text = F.softmax(logits, dim=-1) # prob of a given text captions aligning well with seq. pairs
716
+ p_seq = F.softmax(logits.T, dim=-1) # prob of a given seq aligning well with text pairs
717
+ p_tot = (p_seq + p_text) / 2 # total prob
718
+
719
+ # get class labels
720
+ y_pred_text = torch.argmax(p_text, dim=-1)
721
+ y_pred_seq = torch.argmax(p_seq, dim=-1)
722
+ y_pred = torch.argmax(p_tot, dim=-1)
723
+ y_true = torch.arange(y_pred_text.shape[0])
724
+
725
+ # compute class metrics
726
+ text_metrics = self.compute_class_metrics(
727
+ outputs=y_pred_text,
728
+ targets=y_true,
729
+ source='text'
730
+ )
731
+ seq_metrics = self.compute_class_metrics(
732
+ outputs=y_pred_seq,
733
+ targets=y_true,
734
+ source='seq'
735
+ )
736
+ total_metrics = self.compute_class_metrics(
737
+ outputs=y_pred,
738
+ targets=y_true,
739
+ source='total'
740
+ )
741
+
742
+ # combine dicts into one
743
+ combined_dict = {}
744
+ combined_dict.update(text_metrics)
745
+ combined_dict.update(seq_metrics)
746
+ combined_dict.update(total_metrics)
747
+
748
+ return combined_dict
749
+
750
+ @torch.no_grad()
751
+ def compute_singular(self, inputs: torch.Tensor) -> (
752
+ torch.Tensor,
753
+ torch.Tensor
754
+ ):
755
+
756
+ # goal of this function: track for dimensionality collapse
757
+ # inputs dim: (batch_size, emb_dim)
758
+
759
+ mean_inputs = torch.mean(inputs, dim=0) # average over batch dimension
760
+ norm_inputs = inputs - mean_inputs # normalize vectors
761
+
762
+ # compute correlation matrix #TODO: double check work...
763
+ C = torch.zeros((norm_inputs.shape[-1], norm_inputs.shape[-1]))
764
+ for sample_idx in tqdm(range(norm_inputs.shape[0])):
765
+ norm_vector = norm_inputs[sample_idx, :].unsqueeze(0)
766
+ C += norm_vector.T @ norm_vector
767
+ C *= 1/norm_inputs.shape[0] # average over the batch dimension
768
+
769
+ _, S, _ = torch.linalg.svd(C, full_matrices=False)
770
+
771
+ # return singular value indexes
772
+ log_sigma_k, _ = torch.sort(torch.log(S), descending=True)
773
+ return (
774
+ log_sigma_k,
775
+ S
776
+ )
777
+
778
+ def compute_effective_rank(self, sigma_ks: torch.Tensor) -> torch.Tensor:
779
+ """
780
+ references:
781
+ - Roy et al. The effective rank: a measure of effective dimensionality
782
+ - Garrido et al. RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank.
783
+ """
784
+ # sort the singular values
785
+ sigma_ks, _ = torch.sort(sigma_ks, descending=True)
786
+
787
+ # compute the L1 norm of the singular values
788
+ l1_norm_sigma = torch.norm(sigma_ks, p=1)
789
+
790
+ # compute singular value distribution
791
+ p_k = sigma_ks / l1_norm_sigma + torch.finfo(torch.float).eps
792
+
793
+ # compute Shannon entropy
794
+ entropy = - torch.sum(p_k * torch.log(p_k))
795
+
796
+ # get effective rank (RankME):
797
+ erank = torch.exp(entropy)
798
+
799
+ return erank
800
+
801
+ def save_png_to_tensorboard(
802
+ self,
803
+ data: np.single,
804
+ title: str,
805
+ x_axis_label: str='Singular Value Rank Index',
806
+ y_axis_label: str='Log of singular values',
807
+ ):
808
+
809
+ current_epoch = self.trainer.current_epoch
810
+
811
+ # Plot the line
812
+ fig, ax = plt.subplots(dpi=300)
813
+ ax.plot(data)
814
+ ax.set_xlabel(x_axis_label)
815
+ ax.set_ylabel(y_axis_label)
816
+ ax.set_title(title)
817
+ ax.set_ylim([-25,3])
818
+
819
+ # Log the plot in TensorBoard
820
+ self.logger.experiment.add_figure(f'{title}_SingularValues_{current_epoch}', fig, current_epoch)
821
+
822
+ def predict_step(
823
+ self,
824
+ batch: torch.Tensor,
825
+ batch_idx: int,
+ dataloader_idx: int = 0
827
+ ) -> (
828
+ torch.Tensor,
829
+ torch.Tensor
830
+ ):
831
+
832
+
833
+ if isinstance(batch, list):
834
+ # unpack the text and protein batches
835
+ text_batch, protein_batch = batch
836
+ outputs = self(
837
+ x_t=text_batch,
838
+ x_s=protein_batch,
839
+ compute_masked_logits=False
840
+ )
841
+
842
+ z_t_joint, z_p_joint = outputs
843
+
844
+ self.predict_text_joint_latents.append(z_t_joint.detach().cpu())
845
+ self.predict_seq_joint_latents.append(z_p_joint.detach().cpu())
846
+
847
+ return outputs
848
+
849
+ def on_predict_epoch_end(self, outputs=None):
850
+
851
+ self.predict_text_joint_latents = torch.cat(self.predict_text_joint_latents).cpu()
852
+ self.predict_seq_joint_latents = torch.cat(self.predict_seq_joint_latents).cpu()
853
+
854
+
855
+
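The masked-language terms above are delegated to model.compute_masked_lang_loss in Stage1_source/model.py. As a rough sketch of the usual recipe (cross-entropy evaluated only at the positions replaced by the mask token), not the repository's exact implementation:

import torch
import torch.nn.functional as F

def masked_lm_loss(logits_masked: torch.Tensor,    # (batch, seq_len, vocab)
                   targets: torch.Tensor,          # original token ids, (batch, seq_len)
                   targets_masked: torch.Tensor,   # corrupted ids fed to the encoder
                   mask_token_id: int) -> torch.Tensor:
    mask = targets_masked == mask_token_id          # positions that were masked out
    if mask.sum() == 0:
        return logits_masked.new_zeros(())
    return F.cross_entropy(logits_masked[mask], targets[mask])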
856
+ ########################
857
+ # Pfam-task PL wrapper #
858
+ ########################
859
+
860
+
861
+ class pfam_PL_PEN_CL(pl.LightningModule):
862
+
863
+
864
+ def __init__(
865
+ self,
866
+ args: any,
867
+ model: nn.Module,
868
+ text_tokenizer: any,
869
+ sequence_tokenizer: any
870
+ ):
871
+
872
+ super().__init__()
873
+ # arguments
874
+ self.script_args = args
875
+
876
+ # model components
877
+ self.model = model
878
+
879
+ # tokenizers
880
+ self.text_tokenizer = text_tokenizer
881
+ self.sequence_tokenizer = sequence_tokenizer
882
+
883
+ # validation tracker for outputs
884
+ self.val_text_joint_latents = []
885
+ self.val_seq_joint_latents = []
886
+
887
+ # predictions...
888
+ self.predict_text_joint_latents = []
889
+ self.predict_seq_joint_latents = []
890
+
891
+ def forward(
892
+ self,
893
+ x_t: torch.Tensor,
894
+ x_p: torch.Tensor,
895
+ compute_masked_logits: bool=False
896
+ ) -> (
897
+ torch.Tensor,
+ torch.Tensor
900
+ ):
901
+
902
+ outputs = self.model(
903
+ x_t=x_t,
904
+ x_s=x_p,
905
+ compute_masked_logits=compute_masked_logits
906
+ )
907
+
908
+ if compute_masked_logits:
909
+ # forward pass for computing logits for masked language objective
910
+ return (
911
+ outputs['text_masked_logits'],
912
+ outputs['protein_masked_logits']
913
+ )
914
+ else:
915
+ # forward pass for computing latent embeddings in the joint space
916
+ return (
917
+ outputs['text_joint_latent'],
918
+ outputs['seq_joint_latent'],
919
+ )
920
+
921
+
922
+ def on_train_batch_start(self, batch, batch_idx):
923
+ self.batch_start_time = time.time()
924
+
925
+ def on_train_batch_end(self, outputs, batch, batch_idx):
926
+ batch_end_time = time.time()
927
+ batch_time = batch_end_time - self.batch_start_time
928
+ #print(f'Rank={dist.get_rank()}: time to process batch is {batch_time}')
929
+ #self.log(f'batch_time_rank_{dist.get_rank()}', batch_time, on_step=True, on_epoch=False)
930
+
931
+ def training_step(self, batch: torch.Tensor, batch_idx: any) -> dict:
932
+ """
933
+ Execute a single training step.
934
+
935
+ Given a batch of data, this function processes both Swiss-Prot and Pfam data through the model, computes
936
+ various loss values including inter-modal, intra-modal, and masked language model losses for both text
937
+ and protein sequences. This function also computes and logs various metrics and GPU memory usage.
938
+
939
+ Parameters:
940
+ - batch: The input data batch. This can include multiple types of data.
941
+ - batch_idx: Index of the current batch.
942
+
943
+ Steps:
944
+ 1. Split the data into Swiss-Prot and Pfam batches, if the batch is a list.
945
+ 2. Forward pass the Swiss-Prot data through the model.
946
+ 3. Synchronize and gather embeddings from all GPUs.
947
+ 4. Forward pass the Pfam data through the model.
948
+ 5. Synchronize and gather Pfam embeddings from all GPUs.
949
+ 6. Concatenate Swiss-Prot and Pfam embeddings.
950
+ 7. Compute inter-modal and intra-modal loss values.
951
+ 8. Compute masked language model logits for the concatenated batch.
952
+ 9. Compute masked language loss for both text and protein sequences.
953
+ 10. Compute and log the total loss and individual loss components.
954
+ 11. Compute and log performance metrics.
955
+ 12. Log GPU memory usage at the start of training.
956
+
957
+ Returns:
958
+ - Dictionary containing the total loss value.
959
+
960
+ Note:
961
+ This function is intended to be used within a distributed (multi-GPU) training context, as evident
962
+ from the use of barriers and gathering operations. It's designed to handle batches that contain both
963
+ Swiss-Prot and Pfam data, both being biological datasets used in multi-modal protein embeddings.
964
+ The function utilizes both inter-modal (between modalities) and intra-modal (within the same modality)
965
+ contrastive losses, as well as masked language modeling objectives similar to BERT's MLM objective.
966
+ """
967
+
968
+ # Check if the batch is a list and split data if so.
969
+ if isinstance(batch, list):
970
+ text_batch, protein_batch, text_mask_batch, protein_mask_batch, \
971
+ pfam_text_batch, pfam_protein_batch, pfam_text_mask_batch, pfam_protein_mask_batch, \
972
+ bool_pfam_vector = batch
973
+
974
+
975
+ #print(f'rank={dist.get_rank()}: text size {text_batch.shape}')
976
+
977
+ #start_time_forward_pass = time.time()
978
+ # Forward pass with Swiss-Prot data.
979
+ z_t_swiss, z_p_swiss = self(
980
+ x_t=text_batch,
981
+ x_p=protein_batch,
982
+ compute_masked_logits=False
983
+ )
984
+ # Timer end and log
985
+ #end_time_forward_pass = time.time()
986
+ #print(f"Rank={dist.get_rank()}: Time taken for Swiss-Prot forward pass: {end_time_forward_pass - start_time_forward_pass} seconds.")
987
+
988
+ # Ensure all GPUs are synchronized.
989
+ dist.barrier()
990
+
991
+ # Forward pass with Pfam data.
992
+ z_t_pfam, z_p_pfam = self(
993
+ x_t=pfam_text_batch,
994
+ x_p=pfam_protein_batch,
995
+ compute_masked_logits=False
996
+ )
997
+ dist.barrier()
998
+
999
+ #Gather tensors from all GPUs.
1000
+ z_t_swiss_all = self.all_gather(z_t_swiss, sync_grads=True)
1001
+ dist.barrier()
1002
+ z_p_swiss_all = self.all_gather(z_p_swiss, sync_grads=True)
1003
+
1004
+ # Reshape the embeddings.
1005
+ z_t_swiss_all = z_t_swiss_all.view(-1, z_t_swiss.shape[-1])
1006
+ z_p_swiss_all = z_p_swiss_all.view(-1, z_p_swiss.shape[-1])
1007
+
1008
+
1009
+ # Gather tensors from all GPUs.
1010
+ z_t_pfam_all = self.all_gather(z_t_pfam, sync_grads=True)
1011
+ dist.barrier()
1012
+ z_p_pfam_all = self.all_gather(z_p_pfam, sync_grads=True)
1013
+
1014
+ # Reshape the embeddings.
1015
+ z_t_pfam_all = z_t_pfam_all.view(-1, z_t_pfam.shape[-1])
1016
+ z_p_pfam_all = z_p_pfam_all.view(-1, z_p_pfam.shape[-1])
1017
+
1018
+ # Concatenate Swiss-Prot and Pfam embeddings.
1019
+ z_t_all = torch.cat((z_t_swiss_all, z_t_pfam_all), dim=0)
1020
+ z_p_all = torch.cat((z_p_swiss_all, z_p_pfam_all), dim=0)
1021
+
1022
+ # Timer start
1023
+ #start_time_loss_computation = time.time()
1024
+
1025
+ # Compute inter-modal loss.
1026
+ loss_align, logits = self.model.compute_inter_loss(
1027
+ protein_embeddings=z_p_all,
1028
+ text_embeddings=z_t_all,
1029
+ batch_size=z_p_all.shape[0] // 2
1030
+ )
1031
+ # Timer end and log
1032
+ #end_time_loss_computation = time.time()
1033
+ #print(f"Rank={dist.get_rank()}: Time taken for loss computation: {end_time_loss_computation - start_time_loss_computation} seconds.")
1034
+
1035
+
1036
+ # Compute intra-modal loss.
1037
+ loss_intra, cosine_similarity = self.model.compute_intra_loss(
1038
+ protein_embeddings=z_p_all,
1039
+ batch_size=z_p_all.shape[0] // 2
1040
+ )
1041
+
1042
+ # Concatenate batches for masked language modeling.
1043
+ all_text_batch = torch.cat((text_batch, pfam_text_batch), dim=0)
1044
+ all_protein_batch = torch.cat((protein_batch, pfam_protein_batch), dim=0)
1045
+ all_text_mask_batch = torch.cat((text_mask_batch, pfam_text_mask_batch), dim=0)
1046
+ all_protein_mask_batch = torch.cat((protein_mask_batch, pfam_protein_mask_batch), dim=0)
1047
+
1048
+ #TODO: timer start
1049
+ #start_time_mask_comp = time.time()
1050
+
1051
+ # Compute masked language model logits.
1052
+ logits_t_mask, logits_s_mask = self(
1053
+ x_t=all_text_mask_batch,
1054
+ x_p=all_protein_mask_batch,
1055
+ compute_masked_logits=True
1056
+ )
1057
+ #end_time_mask_comp = time.time()
1058
+ #print(f"Rank={dist.get_rank()}: Time taken for mask predictions: {end_time_mask_comp - start_time_mask_comp} seconds.")
1059
+
1060
+
1061
+ # Compute masked language model loss for text data.
1062
+ loss_text_mask = self.model.compute_masked_lang_loss(
1063
+ logits_masked=logits_t_mask,
1064
+ targets=all_text_batch,
1065
+ targets_masked=all_text_mask_batch,
1066
+ mask_token_id=self.text_tokenizer.mask_token_id
1067
+ )
1068
+
1069
+ # Compute masked language model loss for protein data.
1070
+ loss_sequence_mask = self.model.compute_masked_lang_loss(
1071
+ logits_masked=logits_s_mask,
1072
+ targets=all_protein_batch,
1073
+ targets_masked=all_protein_mask_batch,
1074
+ mask_token_id=self.sequence_tokenizer.mask_idx
1075
+ )
1076
+
1077
+
1078
+ if self.script_args.dataset_type == 'pfam':
1079
+ # Aggregate all computed losses.
1080
+ loss = loss_align + loss_intra + loss_text_mask + loss_sequence_mask
1081
+
1082
+ elif self.script_args.dataset_type == 'pfam_ablated':
1083
+ # Aggregate all losses besides PFC.
1084
+ loss = loss_align + loss_text_mask + loss_sequence_mask
1085
+ else:
1086
+ # Add an assertion here
1087
+ assert self.script_args.dataset_type in ['pfam', 'pfam_ablated'], "Unexpected dataset_type value"
1088
+ sys.stderr.write("Unexpected dataset_type value\n")
1089
+ sys.exit(1)
1090
+
1091
+ # Log the individual and total loss values.
1092
+ self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1093
+ self.log('train_loss_align', loss_align, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1094
+ self.log('train_loss_intra', loss_intra, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1095
+ self.log('train_loss_text_mask', loss_text_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
1096
+ self.log('train_loss_seq_mask', loss_sequence_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
1097
+
1098
+ # Compute and log additional performance metrics.
1099
+ metric_dict = self.performance_metrics(logits=logits)
1100
+ for key in metric_dict:
1101
+ values = metric_dict[key]
1102
+ final_key = 'train_' + key
1103
+ self.log(final_key, metric_dict[key], prog_bar=True if 'f1' in key else False, on_step=True, on_epoch=True, sync_dist=True)
1104
+
1105
+ # Log GPU memory usage at the beginning of the training.
1106
+ if batch_idx == 0:
1107
+ gpu_memory_usage = helper_tools.print_gpu_initialization()
1108
+ self.log(f'gpu_memory_usage', gpu_memory_usage, sync_dist=True)
1109
+
1110
+ # log CPU memory
1111
+ memory_usage = helper_tools.print_memory_usage()
1112
+ self.log(f'memory_usage', memory_usage, sync_dist=True)
1113
+
1114
+ return {'loss': loss}
1115
+
1116
+
1117
+ def validation_step(
1118
+ self,
1119
+ batch: torch.Tensor,
1120
+ batch_idx: any,
1121
+ ) -> dict:
1122
+
1123
+ """
1124
+ `validation_step()`: Validates a single batch of data and computes loss and performance metrics.
1125
+
1126
+ Parameters:
1127
+ - `self`: Reference to the current instance of the model or module.
1128
+ - `batch`: Input data, which might contain text and protein sequences, their corresponding masks, and additional data from both Swiss-Prot and Pfam datasets.
1129
+ - `batch_idx`: Identifier for the current batch.
1130
+
1131
+ Functionality:
1132
+ 1. Extracts and processes data from the given batch.
1133
+ 2. Computes embeddings for Swiss-Prot and Pfam datasets.
1134
+ 3. Concatenates these embeddings to form a unified representation.
1135
+ 4. Computes various loss values: inter-modal, intra-modal, and masked language losses for both biomedical texts and protein sequences.
1136
+ 5. Logs the computed loss values and other performance metrics, highlighting metrics such as F1-score.
1137
+ 6. Collects and appends the joint embeddings of the batch for potential future use.
1138
+
1139
+ Returns:
1140
+ - A dictionary with the total validation loss for the current batch.
1141
+ """
1142
+
1143
+ if isinstance(batch, list):
1144
+ # split the data
1145
+ text_batch, protein_batch, text_mask_batch, protein_mask_batch, \
1146
+ pfam_text_batch, pfam_protein_batch, pfam_text_mask_batch, pfam_protein_mask_batch, \
1147
+ bool_pfam_vector = batch
1148
+
1149
+
1150
+ # forward pass over the swiss-prot data
1151
+ z_t_swiss, z_p_swiss = self(
1152
+ x_t=text_batch,
1153
+ x_p=protein_batch,
1154
+ compute_masked_logits=False
1155
+ )
1156
+ dist.barrier() # wait till all GPUs catch up...
1157
+
1158
+ # gather all tensors
1159
+ z_t_swiss_all = self.all_gather(z_t_swiss, sync_grads=True)
1160
+ dist.barrier()
1161
+ z_p_swiss_all = self.all_gather(z_p_swiss, sync_grads=True)
1162
+
1163
+ # stack the embeddings
1164
+ z_t_swiss_all = z_t_swiss_all.view(-1, z_t_swiss.shape[-1])
1165
+ z_p_swiss_all = z_p_swiss_all.view(-1, z_p_swiss.shape[-1])
1166
+
1167
+ # forward pass over the Pfam data
1168
+ z_t_pfam, z_p_pfam = self(
1169
+ x_t=pfam_text_batch,
1170
+ x_p=pfam_protein_batch,
1171
+ compute_masked_logits=False
1172
+ )
1173
+ dist.barrier() # wait till all GPUs catch up...
1174
+
1175
+ # gather all tensors
1176
+ z_t_pfam_all = self.all_gather(z_t_pfam, sync_grads=True)
1177
+ dist.barrier()
1178
+ z_p_pfam_all = self.all_gather(z_p_pfam, sync_grads=True)
1179
+
1180
+ # stack the embeddings
1181
+ z_t_pfam_all = z_t_pfam_all.view(-1, z_t_pfam.shape[-1])
1182
+ z_p_pfam_all = z_p_pfam_all.view(-1, z_p_pfam.shape[-1])
1183
+
1184
+ # concatenate swiss-prot <> pfam embeddings
1185
+ z_t_all = torch.cat((z_t_swiss_all, z_t_pfam_all), dim=0)
1186
+ z_p_all = torch.cat((z_p_swiss_all, z_p_pfam_all), dim=0)
1187
+
1188
+ # compute inter-modal loss values
1189
+ loss_align, logits = self.model.compute_inter_loss(
1190
+ protein_embeddings=z_p_all,
1191
+ text_embeddings=z_t_all,
1192
+ batch_size=z_p_all.shape[0] // 2
1193
+ )
1194
+
1195
+ # compute intra-modal loss values
1196
+ loss_intra, cosine_similarity = self.model.compute_intra_loss(
1197
+ protein_embeddings=z_p_all,
1198
+ batch_size=z_p_all.shape[0] // 2
1199
+ )
1200
+
1201
+ # concatenate batch samples
1202
+ all_text_batch = torch.cat((text_batch, pfam_text_batch), dim=0)
1203
+ all_protein_batch = torch.cat((protein_batch, pfam_protein_batch), dim=0)
1204
+ all_text_mask_batch = torch.cat((text_mask_batch, pfam_text_mask_batch), dim=0)
1205
+ all_protein_mask_batch = torch.cat((protein_mask_batch, pfam_protein_mask_batch), dim=0)
1206
+
1207
+ # compute mask language model logits
1208
+ logits_t_mask, logits_s_mask = self(
1209
+ x_t=all_text_mask_batch,
1210
+ x_p=all_protein_mask_batch,
1211
+ compute_masked_logits=True
1212
+ )
1213
+
1214
+ # compute mask language loss for biomedical expert model
1215
+ loss_text_mask = self.model.compute_masked_lang_loss(
1216
+ logits_masked=logits_t_mask,
1217
+ targets=all_text_batch,
1218
+ targets_masked=all_text_mask_batch,
1219
+ mask_token_id=self.text_tokenizer.mask_token_id
1220
+ )
1221
+
1222
+ # compute mask language loss for protein expert model
1223
+ loss_sequence_mask = self.model.compute_masked_lang_loss(
1224
+ logits_masked=logits_s_mask,
1225
+ targets=all_protein_batch,
1226
+ targets_masked=all_protein_mask_batch,
1227
+ mask_token_id=self.sequence_tokenizer.mask_idx
1228
+ )
1229
+
1230
+
1231
+ # total loss
1232
+ #loss = loss_align + loss_intra + loss_text_mask + loss_sequence_mask
1233
+
1234
+ if self.script_args.dataset_type == 'pfam':
1235
+ # Aggregate all computed losses.
1236
+ loss = loss_align + loss_intra + loss_text_mask + loss_sequence_mask
1237
+
1238
+ elif self.script_args.dataset_type == 'pfam_ablated':
1239
+ # Aggregate all losses besides PFC.
1240
+ loss = loss_align + loss_text_mask + loss_sequence_mask
1241
+ else:
1242
+ # Add an assertion here
1243
+ assert self.script_args.dataset_type in ['pfam', 'pfam_ablated'], "Unexpected dataset_type value"
1244
+ sys.stderr.write("Unexpected dataset_type value\n")
1245
+ sys.exit(1)
1246
+
1247
+
1248
+ # track loss ...
1249
+ self.log('valid_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1250
+ self.log('valid_loss_align', loss_align, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1251
+ self.log('valid_loss_intra', loss_intra, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1252
+ self.log('valid_loss_text_mask', loss_text_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
1253
+ self.log('valid_loss_seq_mask', loss_sequence_mask, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
1254
+ # log CPU memory
1255
+ memory_usage = helper_tools.print_memory_usage()
1256
+ self.log(f'memory_usage', memory_usage, sync_dist=True)
1257
+
1258
+ # track metrics
1259
+ metric_dict = self.performance_metrics(logits=logits.detach().cpu())
1260
+ for key in metric_dict:
1261
+ values = metric_dict[key]
1262
+ final_key = 'valid_' + key
1263
+ self.log(final_key, metric_dict[key], prog_bar=True if 'f1' in key else False, on_step=True, on_epoch=True, sync_dist=True)
1264
+
1265
+
1266
+ # collect joint embedding
1267
+ #self.val_text_joint_latents.append(z_t_all.detach().cpu())
1268
+ #self.val_seq_joint_latents.append(z_p_all.detach().cpu())
1269
+
1270
+ return {'valid_loss': loss}
1271
+
1272
+
1273
+ # def on_validation_epoch_end(self):
1274
+ # print('Enter validation end of epoch analysis...')
1275
+ #
1276
+ # # collect and aggregate outputs from all validation steps
1277
+ # val_z_t_joint = torch.cat(self.val_text_joint_latents, dim=0)
1278
+ # val_z_s_joint = torch.cat(self.val_seq_joint_latents, dim=0)
1279
+ #
1280
+ # # compute singular values
1281
+ # print('Compute singular values...')
1282
+ # text_log_sigma_k, S_text = self.compute_singular(val_z_t_joint.detach().cpu())
1283
+ # protein_log_sigma_k, S_protein = self.compute_singular(val_z_s_joint.detach().cpu())
1284
+ #
1285
+ # # save image pngs for tracking dimensionality collapse
1286
+ # self.save_png_to_tensorboard(
1287
+ # data=text_log_sigma_k.numpy(),
1288
+ # title='text',
1289
+ # )
1290
+ # self.save_png_to_tensorboard(
1291
+ # data=protein_log_sigma_k.numpy(),
1292
+ # title='protein'
1293
+ # )
1294
+ #
1295
+ # # free memory
1296
+ # self.val_text_joint_latents.clear()
1297
+ # self.val_seq_joint_latents.clear()
1298
+ #
1299
+ #
1300
+ # # compute effective rank (RankME):
1301
+ # print('Compute eranks')
1302
+ # erank_text = self.compute_effective_rank(sigma_ks=S_text)
1303
+ # erank_protein = self.compute_effective_rank(sigma_ks=S_protein)
1304
+ #
1305
+ # # log erank metrics
1306
+ # self.log('valid_erank_text', erank_text, sync_dist=True)
1307
+ # self.log('valid_erank_protein', erank_protein, sync_dist=True)
1308
+
1309
+ def configure_optimizers(self,):
1310
+
1311
+ params = [
1312
+ {"params": self.model.protein_encoder.parameters(), "lr": self.script_args.protein_encoder_lr},
1313
+ {"params": self.model.text_encoder.parameters(), "lr": self.script_args.text_encoder_lr},
1314
+ {"params": itertools.chain(
1315
+ self.model.protein_projection.parameters(),
1316
+ self.model.text_projection.parameters()
1317
+ ),
1318
+ "lr": self.script_args.head_lr,
1319
+ "weight_decay": self.script_args.weight_decay}
1320
+ ]
1321
+
1322
+ optimizer = torch.optim.AdamW(params, weight_decay=self.script_args.weight_decay)
1323
+
1324
+ return {
1325
+ "optimizer": optimizer,
1326
+ }
1327
+
1328
+ @torch.no_grad()
1329
+ def compute_class_metrics(
1330
+ self,
1331
+ outputs: torch.Tensor,
1332
+ targets: torch.Tensor,
1333
+ source: str
1334
+ ) -> dict:
1335
+
1336
+ # convert torch tensors to numpy array
1337
+ outputs_np = outputs.numpy()
1338
+ targets_np = targets.numpy()
1339
+
1340
+ # compute the metrics
1341
+ accuracy = accuracy_score(targets_np, outputs_np.round())
1342
+ precision = precision_score(targets_np, outputs_np.round(), average='micro')
1343
+ recall = recall_score(targets_np, outputs_np.round(), average='micro')
1344
+ f1 = f1_score(targets_np, outputs_np.round(), average='micro')
1345
+
1346
+ return {
1347
+ f'{source}_accuracy': accuracy,
1348
+ f'{source}_precision': precision,
1349
+ f'{source}_recall': recall,
1350
+ f'{source}_f1': f1
1351
+ }
1352
+
1353
+ @torch.no_grad()
1354
+ def performance_metrics(self, logits: torch.Tensor) -> tuple:
1355
+
1356
+ logits = logits.cpu().float()
1357
+
1358
+ # get probs
1359
+ p_text = F.softmax(logits, dim=-1) # prob of a given text captions aligning well with seq. pairs
1360
+ p_seq = F.softmax(logits.T, dim=-1) # prob of a given seq aligning well with text pairs
1361
+ p_tot = (p_seq + p_text) / 2 # total prob
1362
+
1363
+ # get class labels
1364
+ y_pred_text = torch.argmax(p_text, dim=-1)
1365
+ y_pred_seq = torch.argmax(p_seq, dim=-1)
1366
+ y_pred = torch.argmax(p_tot, dim=-1)
1367
+ y_true = torch.arange(y_pred_text.shape[0])
1368
+
1369
+ # compute class metrics
1370
+ text_metrics = self.compute_class_metrics(
1371
+ outputs=y_pred_text,
1372
+ targets=y_true,
1373
+ source='text'
1374
+ )
1375
+ seq_metrics = self.compute_class_metrics(
1376
+ outputs=y_pred_seq,
1377
+ targets=y_true,
1378
+ source='seq'
1379
+ )
1380
+ total_metrics = self.compute_class_metrics(
1381
+ outputs=y_pred,
1382
+ targets=y_true,
1383
+ source='total'
1384
+ )
1385
+
1386
+ # combine dicts into one
1387
+ combined_dict = {}
1388
+ combined_dict.update(text_metrics)
1389
+ combined_dict.update(seq_metrics)
1390
+ combined_dict.update(total_metrics)
1391
+
1392
+ return combined_dict
1393
+
1394
+ @torch.no_grad()
1395
+ def compute_singular(self, inputs: torch.Tensor) -> (
1396
+ torch.Tensor,
1397
+ torch.Tensor
1398
+ ):
1399
+
1400
+ # goal of this function: track for dimensionality collapse
1401
+ # inputs dim: (batch_size, emb_dim)
1402
+
1403
+ mean_inputs = torch.mean(inputs, dim=0) # average over batch dimension
1404
+ norm_inputs = inputs - mean_inputs # normalize vectors
1405
+
1406
+ # compute correlation matrix #TODO: double check work...
1407
+ C = torch.zeros((norm_inputs.shape[-1], norm_inputs.shape[-1]))
1408
+ for sample_idx in range(norm_inputs.shape[0]):
1409
+ norm_vector = norm_inputs[sample_idx, :].unsqueeze(0)
1410
+ C += norm_vector.T @ norm_vector
1411
+ C *= 1/norm_inputs.shape[0] # average over the batch dimension
1412
+
1413
+ _, S, _ = torch.linalg.svd(C, full_matrices=False)
1414
+
1415
+ # return singular value indexes
1416
+ log_sigma_k, _ = torch.sort(torch.log(S), descending=True)
1417
+ return (
1418
+ log_sigma_k,
1419
+ S
1420
+ )
1421
+
1422
+ def compute_effective_rank(self, sigma_ks: torch.Tensor) -> torch.Tensor:
1423
+ """
1424
+ references:
1425
+ - Roy et al. The effective rank: a measure of effective dimensionality
1426
+ - Garrido et al. RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank.
1427
+ """
1428
+ # sort the singular values
1429
+ sigma_ks, _ = torch.sort(sigma_ks, descending=True)
1430
+
1431
+ # compute the L1 norm of the singular values
1432
+ l1_norm_sigma = torch.norm(sigma_ks, p=1)
1433
+
1434
+ # compute singular value distribution
1435
+ p_k = sigma_ks / l1_norm_sigma + torch.finfo(torch.float).eps
1436
+
1437
+ # compute Shannon entropy
1438
+ entropy = - torch.sum(p_k * torch.log(p_k))
1439
+
1440
+ # get effective rank (RankME):
1441
+ erank = torch.exp(entropy)
1442
+
1443
+ return erank
1444
+
1445
+ def save_png_to_tensorboard(
1446
+ self,
1447
+ data: np.single,
1448
+ title: str,
1449
+ x_axis_label: str='Singular Value Rank Index',
1450
+ y_axis_label: str='Log of singular values',
1451
+ ):
1452
+
1453
+ current_epoch = self.trainer.current_epoch
1454
+
1455
+ # Plot the line
1456
+ fig, ax = plt.subplots(dpi=300)
1457
+ ax.plot(data)
1458
+ ax.set_xlabel(x_axis_label)
1459
+ ax.set_ylabel(y_axis_label)
1460
+ ax.set_title(title)
1461
+ ax.set_ylim([-25,3])
1462
+
1463
+ # Log the plot in TensorBoard
1464
+ self.logger.experiment.add_figure(f'{title}_SingularValues_{current_epoch}', fig, current_epoch)
1465
+
1466
+ # Close the figure to free up memory
1467
+ plt.close(fig)
1468
+
1469
+ def predict_step(
1470
+ self,
1471
+ batch: torch.Tensor,
1472
+ batch_idx: int,
+ dataloader_idx: int = 0
1474
+ ) -> (
1475
+ torch.Tensor,
1476
+ torch.Tensor
1477
+ ):
1478
+
1479
+
1480
+ if isinstance(batch, list):
1481
+ # unpack the text and protein batches
1482
+ text_batch, protein_batch = batch
1483
+ outputs = self(
1484
+ x_t=text_batch,
1485
+ x_p=protein_batch,
1486
+ compute_masked_logits=False
1487
+ )
1488
+
1489
+ z_t_joint, z_p_joint = outputs
1490
+
1491
+ self.predict_text_joint_latents.append(z_t_joint.detach().cpu())
1492
+ self.predict_seq_joint_latents.append(z_p_joint.detach().cpu())
1493
+
1494
+ return outputs
1495
+
1496
+ def on_predict_epoch_end(self, outputs=None):
1497
+
1498
+ self.predict_text_joint_latents = torch.cat(self.predict_text_joint_latents).cpu()
1499
+ self.predict_seq_joint_latents = torch.cat(self.predict_seq_joint_latents).cpu()
1500
+
1501
+
1502
+ ##########################
1503
+ # Facilitator PL wrapper #
1504
+ ##########################
1505
+
1506
+ class PL_Facilitator(pl.LightningModule):
1507
+
1508
+ def __init__(
1509
+ self,
1510
+ args: any
1511
+ ):
1512
+
1513
+ super().__init__()
1514
+
1515
+ # arguments
1516
+ self.args = args
1517
+
1518
+ # model
1519
+ self.model = mod.Facilitator(
1520
+ in_dim=self.args.emb_dim,
1521
+ hid_dim=self.args.hid_dim,
1522
+ out_dim=self.args.emb_dim,
1523
+ dropout=self.args.dropout
1524
+ )
1525
+
1526
+ self.text_to_protein_joint_embeddings = []
1527
+
1528
+ def forward(
1529
+ self,
1530
+ z_t: torch.Tensor,
1531
+ ) -> torch.Tensor:
1532
+
1533
+ # reconfigure z_t to z_p (additional alignment)
1534
+ z_t_to_p = self.model(z_t)
1535
+
1536
+ return z_t_to_p
1537
+
1538
+
1539
+
1540
+ def training_step(self, batch: torch.Tensor, batch_id: any) -> dict:
1541
+
1542
+ # check if the batch is a list and split data if so
1543
+ if isinstance(batch, list):
1544
+ text_embeddings, protein_embeddings = batch
1545
+
1546
+ # forward pass with the model
1547
+ z_t_to_p = self(z_t=text_embeddings)
1548
+
1549
+ # compute loss
1550
+ loss = self.model.compute_loss(
1551
+ output=z_t_to_p,
1552
+ target=protein_embeddings,
1553
+ loss_option=self.args.loss_type
1554
+ )
1555
+
1556
+ # log the total loss
1557
+ self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1558
+
1559
+ return {'loss': loss}
1560
+
1561
+
1562
+ def validation_step(self, batch: torch.Tensor, batch_id: any) -> dict:
1563
+
1564
+ # check if the batch is a list and split data if so
1565
+ if isinstance(batch, list):
1566
+ text_embeddings, protein_embeddings = batch
1567
+
1568
+ # forward pass with the model
1569
+ z_t_to_p = self(z_t=text_embeddings)
1570
+
1571
+ # compute loss
1572
+ loss = self.model.compute_loss(
1573
+ output=z_t_to_p,
1574
+ target=protein_embeddings,
1575
+ loss_option=self.args.loss_type
1576
+ )
1577
+
1578
+ # log the total loss
1579
+ self.log('valid_loss', loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
1580
+
1581
+ return {'loss': loss}
1582
+
1583
+
1584
+ def configure_optimizers(self,):
1585
+
1586
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
1587
+
1588
+ return {
1589
+ "optimizer": optimizer
1590
+ }
1591
+
1592
+
1593
+ def predict_step(self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = None) -> torch.Tensor:
1594
+ """
1595
+ Defines a single prediction (inference) step.
1596
+ """
1597
+
1598
+ # Unpack the batch if it comes in a list format.
1599
+ # Here, we only take text embeddings for prediction as an example.
1600
+ if isinstance(batch, list):
1601
+ text_embeddings, _ = batch # We ignore the second element (protein_embeddings)
1602
+ else:
1603
+ text_embeddings = batch
1604
+
1605
+ # Perform forward pass to get transformed text embeddings (z_t_to_p)
1606
+ z_t_to_p = self(z_t=text_embeddings)
1607
+ self.text_to_protein_joint_embeddings.append(z_t_to_p.detach().cpu())
1608
+
1609
+ return z_t_to_p
1610
+
1611
+ def on_predict_epoch_end(self, outputs=None):
1612
+
1613
+ self.text_to_protein_joint_embeddings = torch.cat(self.text_to_protein_joint_embeddings).cpu()
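A minimal inference sketch for the facilitator; the checkpoint path, config object, and embedding dataloader below are placeholders:

import pytorch_lightning as pl

facilitator = PL_Facilitator.load_from_checkpoint("path/to/facilitator.ckpt", args=config)  # placeholder path and config
trainer = pl.Trainer(accelerator="gpu", devices=1)
trainer.predict(facilitator, dataloaders=embedding_loader)       # placeholder dataloader of text embeddings
z_t_to_p = facilitator.text_to_protein_joint_embeddings          # gathered in on_predict_epoch_end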
Stage1_source/helper_funcs.py ADDED
@@ -0,0 +1,37 @@
1
+ from pynvml import *
2
+ import psutil
+ import os
3
+
4
+ """
5
+ Track GPU memory allocation from Python using the nvidia-ml-py3 (pynvml) package.
6
+
7
+ ref: https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one
8
+ """
9
+
10
+
11
+ def print_gpu_initialization():
12
+ nvmlInit()
13
+ handle = nvmlDeviceGetHandleByIndex(0)
14
+ info = nvmlDeviceGetMemoryInfo(handle)
15
+ print(f"GPU memory occupied: {info.used//1024**2} MB.")
16
+ return info.used // 1024**2
17
+
18
+
19
+ def print_summary(result):
20
+ print(f"Time: {result.metrics['train_runtime']:.2f}")
21
+ print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
22
+ print_gpu_initialization()
23
+
24
+
25
+ def print_memory_usage():
26
+ process = psutil.Process(os.getpid())
27
+ memory_in_bytes = process.memory_info().rss
28
+ memory_in_megabytes = memory_in_bytes / (1024 ** 2)
29
+ #print(f"Memory used by this script: {memory_in_megabytes:.2f} MB")
30
+
31
+ return memory_in_megabytes
32
+
33
+
34
+
35
+
36
+
37
+
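A quick usage sketch for these helpers (illustrative only; the pynvml call requires an NVIDIA GPU):

gpu_mb = print_gpu_initialization()   # prints and returns GPU memory in use (MB)
cpu_mb = print_memory_usage()         # returns this process's resident memory (MB)
print(f"process RSS: {cpu_mb:.2f} MB, GPU used: {gpu_mb} MB")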
Stage1_source/model.py ADDED
@@ -0,0 +1,556 @@
1
+ import os
2
+ import numpy as np
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from transformers import AutoTokenizer, AutoModel, BertTokenizer, BertForMaskedLM
8
+ import torch.distributed as dist
9
+ import esm
10
+ from torch.nn.utils.weight_norm import weight_norm
11
+
12
+
13
+ """
14
+ functions and classes adapted from the following:
15
+ 1. https://keras.io/examples/vision/nl_image_search/
16
+ 2. https://colab.research.google.com/drive/1hYHb0FTdKQCXZs3qCwVZnSuVGrZU2Z1w?usp=sharing
17
+ """
18
+
19
+ class ProteinEncoder(nn.Module):
20
+ """
21
+ Encoder for protein sequence to a fixed size vector --> z_s
22
+ """
23
+
24
+ def __init__(self, args: any):
25
+ super().__init__()
26
+
27
+ #self.script_args = args
28
+ self.seq_model_path = args.seq_model_path
29
+ self.pretrained = args.pretrained_seq
30
+ self.trainable = args.trainable_seq
31
+ self.n_layers_to_finetune = args.pLM_n_layers_to_finetune
32
+ self.rep_layer = args.rep_layer
33
+ self.model, self.alphabet = self.get_ESM_model() # get model and alphabet (ESM)
34
+
35
+ for p in self.model.parameters():
36
+ if self.trainable and self.n_layers_to_finetune == 0:
37
+ p.requires_grad = True
38
+ else:
39
+ p.requires_grad = False
40
+
41
+ # Make the last n_layers_to_finetune layers trainable
42
+ if self.trainable and self.n_layers_to_finetune != 0:
43
+ for layer in self.model.layers[-self.n_layers_to_finetune:]:
44
+ for p in layer.parameters():
45
+ p.requires_grad = True
46
+
47
+ # Use the [CLS] token hidden representation as the sentence's embedding
48
+ # for the downstream latent alignment.
49
+ self.target_token_idx = 0
50
+
51
+ def get_ESM_model(self):
52
+
53
+ return esm.pretrained.load_model_and_alphabet(
54
+ os.path.expanduser(
55
+ self.seq_model_path
56
+ )
57
+ )
58
+
59
+ def forward(self, x_s: torch.Tensor, compute_logits: bool=False):
60
+ # drop channel depth
61
+ x_s = x_s.squeeze(1)
62
+
63
+ outputs = self.model(
64
+ x_s,
65
+ repr_layers=[self.rep_layer],
66
+ return_contacts=False
67
+ )
68
+
69
+ # masked language model objective
70
+ if compute_logits:
71
+ logits = outputs['logits']
72
+ return logits
73
+
74
+ # fine-tuning cls token for protein sequence alignment with biomedical text
75
+ cls_hidden = outputs['representations'][self.rep_layer][:,self.target_token_idx,:]
76
+ return cls_hidden
77
+
78
+ class TextEncoder(nn.Module):
79
+
80
+ """
81
+ Encoder for protein's natural text to a fixed size vector --> z_t
82
+ """
83
+
84
+ def __init__(self, args: any):
85
+ super().__init__()
86
+
87
+ self.model_name = args.text_model_path
88
+ self.pretrained = args.pretrained_text
89
+ self.trainable = args.trainable_text
90
+ self.n_layers_to_finetune = args.bLM_n_layers_to_finetune
91
+ self.tokenizer = AutoTokenizer.from_pretrained(args.text_model_path)
92
+
93
+ if self.pretrained:
94
+ #self.model = AutoModel.from_pretrained(self.model_name)
95
+ self.model = BertForMaskedLM.from_pretrained(self.model_name)
96
+
97
+ else:
98
+ #self.model = AutoModel.from_config(self.model_name)
99
+ self.model = BertForMaskedLM.from_config(self.model_name)
100
+
101
+ for p in self.model.parameters():
102
+ if self.trainable and self.n_layers_to_finetune == 0:
103
+ p.requires_grad = True
104
+ else:
105
+ p.requires_grad = False
106
+
107
+ # Make the last n_layers_to_finetune layers trainable
108
+ if self.trainable and self.n_layers_to_finetune != 0:
109
+ for layer in self.model.bert.encoder.layer[-self.n_layers_to_finetune:]:
110
+ for p in layer.parameters():
111
+ p.requires_grad = True
112
+
113
+ # Use the [CLS] token hidden representation as the sentence's embedding
114
+ # for the downstream latent alignment.
115
+ self.target_token_idx = 0
116
+
117
+ def forward(self, inputs: torch.Tensor, compute_logits: bool=False) -> torch.Tensor:
118
+ # drop channel depth
119
+ inputs = inputs.squeeze(1)
120
+
121
+ if compute_logits:
122
+ # compute the masked language model logits
123
+ #sequence_output = outputs.last_hidden_state
124
+ outputs = self.model(inputs)
125
+ logits = outputs.logits
126
+ return logits
127
+
128
+ else:
129
+ outputs = self.model(inputs, output_hidden_states=True)
130
+ # use the token representations...
131
+ last_hidden_state = outputs.hidden_states[-1]
132
+ return last_hidden_state[:, self.target_token_idx, :] # return [cls] token
133
+
134
+
135
+
136
+ class ProjectionHead(nn.Module):
137
+ """
138
+ g(.) which maps z_t --> h_t or z_s --> h_s
139
+
140
+ Note: h is the joint embedding representation, h_t
141
+ is the joint embedding for the text caption, and
142
+ h_s is the joint embedding for the protein sequence.
143
+ """
144
+
145
+ def __init__(self, embedding_dim: int, args: any):
146
+
147
+ super().__init__()
148
+ self.projection_dim = args.proj_embedding_dim
149
+ self.dropout = args.dropout
150
+ self.embedding_dim = embedding_dim
151
+
152
+ # model graph
153
+ self.projection = nn.Linear(self.embedding_dim, self.projection_dim)
154
+ self.gelu = nn.GELU()
155
+ self.fc = nn.Linear(self.projection_dim, self.projection_dim)
156
+ self.dropout = nn.Dropout(self.dropout)
157
+ self.layer_norm = nn.LayerNorm(self.projection_dim)
158
+
159
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
160
+
161
+ projection = self.projection(z)
162
+ h = self.gelu(projection)
163
+ h = self.fc(h)
164
+ h = self.dropout(h)
165
+ h = h + projection
166
+ h = self.layer_norm(h)
167
+ return h
168
+
169
+
170
+
171
+
172
+ #####################
173
+ # Pfam architecture #
174
+ #####################
175
+
176
+
177
+
178
+ class pfam_PEN_CL(nn.Module):
179
+
180
+ """
181
+ Protein Embeddings with Natural language using Contrastive Learning (PEN-CL), including Pfam contrastive learning.
182
+ """
183
+
184
+ def __init__(self, args: any):
185
+
186
+ super().__init__()
187
+
188
+ self.protein_embedding = args.protein_encoder_embedding
189
+ self.text_embedding = args.text_encoder_embedding
190
+ self.temperature = args.temperature
191
+
192
+ # protein sequence expert
193
+ self.protein_encoder = ProteinEncoder(args=args)
194
+ # natural text expert
195
+ self.text_encoder = TextEncoder(args=args)
196
+
197
+ # projection heads g_seq( . ) --> joint embedding space
198
+ self.protein_projection = ProjectionHead(
199
+ embedding_dim=self.protein_embedding,
200
+ args=args
201
+ )
202
+
203
+ # projection heads g_text( . ) --> joint embedding space
204
+ self.text_projection = ProjectionHead(
205
+ embedding_dim=self.text_embedding,
206
+ args=args
207
+ )
208
+
209
+ def forward(
210
+ self,
211
+ x_t: torch.Tensor,
212
+ x_s: torch.Tensor,
213
+ compute_masked_logits: bool=False
214
+ ) -> dict:
215
+
216
+ if compute_masked_logits:
217
+ # forward pass for computing logits for the masked language objective
218
+ protein_logits = self.protein_encoder(x_s, compute_logits=True)
219
+ text_logits = self.text_encoder(x_t, compute_logits=True)
220
+
221
+ return {
222
+ 'text_masked_logits': text_logits,
223
+ 'protein_masked_logits': protein_logits
224
+ }
225
+
226
+ else:
227
+ # split the tuple into 2 dicts...
228
+ # getting protein sequence and text inputs ...
229
+ z_t = self.text_encoder(x_t, compute_logits=False)
230
+ z_s = self.protein_encoder(x_s, compute_logits=False)
231
+
232
+ # "joint" sequence and text embedding (with same dimension)
233
+ z_t_joint = self.text_projection(z_t)
234
+ z_s_joint = self.protein_projection(z_s)
235
+
236
+ return {
237
+ 'text_joint_latent': z_t_joint,
238
+ 'seq_joint_latent': z_s_joint,
239
+ }
240
+
241
+ def compute_inter_loss(
242
+ self,
243
+ protein_embeddings: torch.Tensor,
244
+ text_embeddings: torch.Tensor,
245
+ batch_size: int
246
+ ) -> (
247
+ torch.Tensor,
248
+ torch.Tensor
249
+ ):
250
+
251
+ """
252
+ Compute the inter-modal contrastive InfoNCE loss between protein and text embeddings.
253
+
254
+ Parameters:
255
+ - protein_embeddings: A tensor representing the embeddings of the protein sequences.
256
+ - text_embeddings: A tensor representing the embeddings of the text descriptions.
257
+ - batch_size: The number of samples in the batch.
258
+
259
+ Steps:
260
+ 1. Generate a masking matrix to identify off-diagonal elements.
261
+ 2. Compute cosine similarities (i.e., logits) between text and protein embeddings.
262
+ 3. Compute self-similarities for both protein and text embeddings.
263
+ 4. Mask off-diagonal elements between swiss-prot and pfam in the similarity matrices.
264
+ 5. Define ground truth by averaging the masked protein and text similarity matrices.
265
+ 6. Compute the contrastive loss for the protein and text embeddings using the ground truth.
266
+
267
+ Returns:
268
+ - Mean contrastive loss for the given batch of protein and text embeddings.
269
+ - The logits (cosine similarity matrix between text and protein embeddings).
270
+
271
+ Note: This function assumes a specific structure in the input batches, where corresponding positive samples
272
+ in the protein and text embeddings are arranged in a particular way, allowing for masking and contrastive loss calculation.
273
+ """
274
+
275
+ # get off-diagonal masking matrix
276
+ mask = torch.zeros((2*batch_size, 2*batch_size))
277
+ # mask the bottom left quadrant diagonal
278
+ mask[batch_size:, :batch_size] = torch.eye(batch_size)
279
+ # mask the top right quadrant
280
+ mask[:batch_size, batch_size:] = torch.eye(batch_size)
281
+ # convert to correct device and convert to boolean
282
+ mask = mask.to(protein_embeddings.device).bool()
283
+
284
+ # matrix multiplication between model embeddings
285
+ logits = (text_embeddings @ protein_embeddings.T) / self.temperature
286
+ protein_similarity = protein_embeddings @ protein_embeddings.T
287
+ text_similarity = text_embeddings @ text_embeddings.T
288
+
289
+ # mask the off-diagonal between swiss-prot and pfam
290
+ mask_protein_similarity = self.set_inf(protein_similarity, mask)
291
+ mask_text_similarity = self.set_inf(text_similarity, mask)
292
+ mask_logits = self.set_inf(logits, mask)
293
+
294
+ # ground truth
295
+ targets = F.softmax(
296
+ (mask_protein_similarity + mask_text_similarity) / (2 * self.temperature), dim=-1
297
+ )
298
+
299
+ # compute loss
300
+ text_loss = self.cross_entropy(mask_logits, targets, reduction='none')
301
+ protein_loss = self.cross_entropy(mask_logits.T, targets.T, reduction='none')
302
+ loss = (protein_loss + text_loss) / 2.0
303
+
304
+ return (
305
+ loss.mean(),
306
+ mask_logits.detach().cpu()
307
+ )
308
+
309
+
310
+ def compute_intra_loss(
311
+ self,
312
+ protein_embeddings,
313
+ batch_size
314
+ ) -> (
315
+ torch.Tensor,
316
+ torch.Tensor,
317
+ ):
318
+ """
319
+ Compute the intra-modal contrastive InfoNCE loss for protein embeddings.
320
+
321
+ Parameters:
322
+ - protein_embeddings: A tensor representing the embeddings of the protein sequences.
323
+ - batch_size: Batch size used for training.
324
+
325
+ Steps:
326
+ 1. Normalize the protein embeddings using L2 normalization.
327
+ 2. Compute the cosine similarity between the normalized embeddings.
328
+ 3. Mask the diagonal of the cosine similarity matrix to avoid using a protein's similarity with itself.
329
+ 4. Define positive examples by rolling the mask. The positive example for a given protein embedding is determined by an embedding half the batch size away.
330
+ 5. Compute the InfoNCE loss using the masked cosine similarity matrix.
331
+
332
+ Returns:
333
+ - Mean InfoNCE loss for the given batch of protein embeddings.
334
+ - The cosine similarity matrix.
335
+
336
+ Note: The underlying assumption is that in each batch, corresponding positive samples for a given protein embedding
337
+ lie half the batch size away. The function computes the negative log likelihood loss between these positive samples
338
+ and the entire batch.
339
+ """
340
+
341
+ # l2 normalization (currently disabled; the embeddings are used as-is)
342
+ #norm_protein_embeddings = F.normalize(protein_embeddings, p=2, dim=1)
343
+ norm_protein_embeddings = protein_embeddings
344
+
345
+ # cosine similarity
346
+ cosine_similarity = (norm_protein_embeddings @ norm_protein_embeddings.T) / self.temperature
347
+
348
+ # mask cosine similarity matrix
349
+ sample_size = protein_embeddings.shape[0]
350
+ mask = torch.eye(sample_size, device=cosine_similarity.device, dtype=torch.bool)
351
+ #cosine_similarity.masked_fill_(mask, float(-9e15))
352
+ cosine_similarity = self.set_inf(cosine_similarity, mask)
353
+
354
+ # Find positive example -> batch_size //2 away from the original example (swiss-prot<>pfam)
355
+ pos_mask = mask.roll(shifts=mask.shape[0]//2, dims=0)
356
+
357
+ # InfoNCE loss
358
+ nll = -cosine_similarity[pos_mask] + torch.logsumexp(cosine_similarity, dim=-1)
359
+
360
+ return (
361
+ nll.mean(),
362
+ cosine_similarity.cpu(),
363
+ )
364
+
365
+ def set_inf(
366
+ self,
367
+ tensor: torch.Tensor,
368
+ mask: torch.Tensor
369
+ ) -> torch.Tensor:
370
+ # Determine replacement value based on tensor dtype
371
+ if tensor.dtype == torch.float32:
372
+ replace_value = -9e15
373
+ elif tensor.dtype == torch.float16:
374
+ replace_value = -1e4
375
+ else:
376
+ raise ValueError("Unsupported tensor dtype for this operation.")
377
+
378
+ # Use masked_fill_ to replace positions in tensor where mask is True with the specified value
379
+ tensor.masked_fill_(mask, replace_value)
380
+
381
+ return tensor
382
+
383
+ def cross_entropy(
384
+ self,
385
+ preds: torch.Tensor,
386
+ targets: torch.Tensor,
387
+ reduction: str='none'
388
+ ) -> torch.Tensor:
389
+
390
+ # compute categorical cross entropy
391
+ log_softmax = nn.LogSoftmax(dim=-1)
392
+ loss = (-targets * log_softmax(preds)).sum(1)
393
+
394
+ if reduction == 'none':
395
+ return loss
396
+ elif reduction == 'mean':
397
+ return loss.mean()
398
+ else:
399
+ raise ValueError('Choose either "none" or "mean" for reduction argument')
400
+
401
+ def compute_masked_lang_loss(
402
+ self,
403
+ logits_masked: torch.Tensor,
404
+ targets: torch.Tensor,
405
+ targets_masked: torch.Tensor,
406
+ mask_token_id: torch.Tensor
407
+ ) -> torch.Tensor:
408
+
409
+ """
410
+ Compute the masked language model loss for BERT-like architectures.
411
+
412
+ Given a batch of logits predicted for masked positions and their corresponding target tokens, this function
413
+ computes the cross-entropy loss between the predicted logits and the true labels, but only for positions
414
+ that have been masked in the input.
415
+
416
+ Parameters:
417
+ - logits_masked: Predicted token logits for masked positions from the model.
418
+ Shape: (batch_size, seq_len, vocab_size).
419
+ - targets: True token IDs for each position in the input sequence.
420
+ Shape: (batch_size, seq_len).
421
+ - targets_masked: Token IDs for the input sequence, including masked positions.
422
+ Shape: (batch_size, seq_len).
423
+ - mask_token_id: The ID corresponding to the [MASK] token in the vocabulary.
424
+
425
+ Steps:
426
+ 1. Compute the cross-entropy loss between predicted logits and true labels across all positions.
427
+ 2. For each sample in the batch, locate the positions that were masked.
428
+ 3. Extract the loss values corresponding to these masked positions.
429
+ 4. Compute and return the mean of these extracted loss values across the batch.
430
+
431
+ Returns:
432
+ - Mean cross-entropy loss for masked positions across the batch.
433
+
434
+ Note: This function focuses exclusively on masked positions in the input, as is typical for the MLM objective
435
+ in BERT-like models. It disregards unmasked positions.
436
+ """
437
+
438
+ # compute the masked language objective loss for the masked logits
439
+ loss_func = nn.CrossEntropyLoss(reduction='none')
440
+ loss_mask = loss_func(
441
+ logits_masked.permute(0, 2, 1), # (batch_size, vocab_size, seq_len)
442
+ targets.squeeze(1) # (batch_size, seq_len)
443
+ )
444
+
445
+ # list to append loss values
446
+ batch_loss = []
447
+
448
+ for ii, target_mask_sample in enumerate(targets_masked):
449
+
450
+ # locate mask positions
451
+ masked_positions = (target_mask_sample == mask_token_id).tolist()
452
+ # extract the loss values at those masked positions
453
+ loss_mask_sample = loss_mask[ii][masked_positions]
454
+
455
+ # append mean loss value for a given batch sample
456
+ if loss_mask_sample.numel() > 0:
457
+ batch_loss.append(torch.mean(loss_mask_sample).unsqueeze(0))
458
+
459
+ if len(batch_loss) > 0:  # at least one sample contained masked positions
460
+ loss_mask_mean = torch.mean(torch.cat(batch_loss))
461
+ else:
462
+ # handle the case where there are no masked positions in any sample
463
+ loss_mask_mean = torch.tensor(0.0, device=logits_masked.device)
464
+
465
+ return loss_mask_mean
466
+
467
+
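To make the masking in compute_inter_loss above concrete, here is a small illustrative sketch. It assumes batch_size=2 and that the 2*batch_size embedding rows stack the Swiss-Prot entries followed by their Pfam counterparts, which is how the docstring describes the batch layout; the printed matrix is exactly what the code above builds:

import torch

batch_size = 2
mask = torch.zeros((2 * batch_size, 2 * batch_size))
mask[batch_size:, :batch_size] = torch.eye(batch_size)  # Pfam rows vs. Swiss-Prot columns
mask[:batch_size, batch_size:] = torch.eye(batch_size)  # Swiss-Prot rows vs. Pfam columns
print(mask.bool())
# tensor([[False, False,  True, False],
#         [False, False, False,  True],
#         [ True, False, False, False],
#         [False,  True, False, False]])
# set_inf fills these positions with a large negative value, so an entry and its
# Swiss-Prot/Pfam counterpart are excluded from both the logits and the soft targets.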
468
+ ###############
469
+ # Facilitator #
470
+ ###############
471
+
472
+
473
+ class Facilitator(nn.Module):
474
+
475
+ def __init__(self,
476
+ in_dim: int, # Input dimension
477
+ hid_dim: int, # Hidden layer dimension
478
+ out_dim: int, # Output dimension
479
+ dropout: float = 0. # Dropout rate
480
+ ):
481
+ super().__init__()
482
+
483
+ # Main neural network structure
484
+ self.main = nn.Sequential(
485
+ weight_norm(nn.Linear(in_dim, hid_dim), dim=None), # Weight-normalized linear layer
486
+ nn.GELU(), # GELU activation function
487
+ nn.Dropout(dropout, inplace=True), # Dropout layer
488
+ weight_norm(nn.Linear(hid_dim, out_dim), dim=None) # Weight-normalized output layer
489
+ )
490
+
491
+ def forward(self, x):
492
+ # Forward pass through the network
493
+ return self.main(x)
494
+
495
+ def compute_loss(self, output: torch.Tensor, target: torch.Tensor, loss_option='MSE') -> torch.Tensor:
496
+ # Compute loss based on the chosen loss_option ('MSE' or 'MMD')
497
+ if loss_option == 'MSE':
498
+ return Facilitator.compute_MSE(output, target)
499
+ elif loss_option == 'MMD':
500
+ return Facilitator.compute_mmd(output, target)
501
+ else:
502
+ raise ValueError("Invalid loss option")
503
+
504
+ @staticmethod
505
+ def compute_MSE(output, target):
506
+ # Compute Mean Squared Error between output and target
507
+ mse_loss = nn.MSELoss()
508
+ loss = mse_loss(output, target)
509
+ return loss
510
+
511
+ @staticmethod
512
+ def compute_kernel(
513
+ x: torch.FloatTensor,
514
+ y: torch.FloatTensor
515
+ ) -> torch.FloatTensor:
516
+ """
517
+ Compute the Gaussian RBF kernel between tensors x and y
518
+ """
519
+
520
+ # Get the sizes of each mini-batch
521
+ x_size, y_size = x.shape[0], y.shape[0]
522
+
523
+ # Dimension based on z size
524
+ dim = x.shape[1]
525
+
526
+ x = x.view(x_size, 1, dim)
527
+ y = y.view(1, y_size, dim)
528
+
529
+ x_core = x.expand(x_size, y_size, dim)
530
+ y_core = y.expand(x_size, y_size, dim)
531
+
532
+ # Gaussian RBF kernel computation
533
+ return torch.exp(-(x_core - y_core).pow(2).mean(2) / dim)
534
+
535
+ @staticmethod
536
+ def compute_mmd(
537
+ x: torch.FloatTensor,
538
+ y: torch.FloatTensor
539
+ ) -> torch.FloatTensor:
540
+ """
541
+ Compute the Maximum Mean Discrepancy (MMD) between two distributions.
542
+ Args:
543
+ x: Samples from first distribution (z_t_to_p ~ q(z_p))
544
+ y: Samples from second distribution (z_p ~ p(z_p))
545
+ Returns:
546
+ MMD_loss: The MMD loss between the sampled distributions
547
+ """
548
+
549
+ x_kernel = Facilitator.compute_kernel(x, x)
550
+ y_kernel = Facilitator.compute_kernel(y, y)
551
+ xy_kernel = Facilitator.compute_kernel(x, y)
552
+
553
+ # Calculate MMD loss
554
+ return x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
555
+
556
+
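As a quick, illustrative sanity check of the MMD loss defined above (exact values depend on the random draw, but the ordering should hold):

import torch

torch.manual_seed(0)
z_a = torch.randn(256, 512)          # stand-in for facilitator outputs z_t_to_p
z_b = torch.randn(256, 512)          # samples from the same distribution
z_c = torch.randn(256, 512) + 1.0    # samples from a shifted distribution

print(Facilitator.compute_mmd(z_a, z_b))  # close to zero
print(Facilitator.compute_mmd(z_a, z_c))  # clearly larger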
Stage1_source/preprocess.py ADDED
@@ -0,0 +1,410 @@
1
+ import torch
2
+ from torch.utils.data import random_split, Dataset, DataLoader, Subset, ConcatDataset
3
+ import pandas as pd
4
+ import random
5
+ import ast
6
+ import dask.dataframe as dd
7
+ import os
8
+ from sklearn.model_selection import train_test_split
9
+ from pytorch_lightning import LightningDataModule
10
+ from tqdm import tqdm
11
+ import gc
12
+ import psutil
13
+ import time
14
+ import copy
15
+
16
+ import esm
17
+ from esm import pretrained
18
+ from transformers import AutoTokenizer, AutoModel
19
+
20
+
21
+ ########################################
22
+ # Dataset iterator with masking tokens #
23
+ ########################################
24
+
25
+ class TextSeqPairing_Dataset(Dataset):
26
+
27
+ def __init__(self, args: any, df: pd.Series):
28
+
29
+ # dataframe
30
+ self.df = df
31
+ self.length = self.df.shape[0]
32
+ self.df_column_names = self.df.columns.tolist()
33
+ self.protein_sequence_list = self.df[args.sequence_keyword].tolist()
34
+ self.text_captions_list = self.df['[final]text_caption'].tolist()
35
+ self.accession_id_list = self.df[args.id_keyword].tolist()
36
+
37
+ # parameters
38
+ self.text_max_length = args.text_max_length # max BERT sequence tokenization length
39
+ self.seq_max_length = 1024 # max ESM model
40
+
41
+ # tokenizers
42
+ self.text_tokenizer = AutoTokenizer.from_pretrained(args.text_model_path) # for text encoder
43
+ _, self.sequence_tokenizer = pretrained.load_model_and_alphabet(args.seq_model_path) # for protein encoder
44
+
45
+ def caption_tokenizer(self, batch_captions: list) -> dict:
46
+
47
+ # transform input text tokens
48
+ text_inputs = self.text_tokenizer.batch_encode_plus(
49
+ batch_captions,
50
+ truncation=True,
51
+ max_length=self.text_max_length,
52
+ padding='max_length',
53
+ return_tensors='pt',
54
+ return_attention_mask=True,
55
+ return_token_type_ids=False
56
+ )
57
+
58
+ # track the original natural language captions
59
+ text_inputs['orig_captions'] = batch_captions
60
+
61
+ return text_inputs
62
+
63
+ def protein_tokenizer(self, batch_sequences: list) -> dict:
64
+
65
+ # prepare data for ESM
66
+ batch_converter = self.sequence_tokenizer.get_batch_converter()
67
+ batch_labels, batch_str, batch_tokens = batch_converter(batch_sequences)
68
+
69
+ # pad sequences
70
+ batch_tokens = torch.cat((
71
+ batch_tokens,
72
+ torch.ones((1,1024-batch_tokens.shape[1])),
73
+ ), dim=-1
74
+ )
75
+
76
+ sequence_inputs = {
77
+ 'protein_sequence_labels': batch_labels, # UniProtKB id
78
+ 'protein_sequence_str': batch_str, # original protein sequence (in amino acids)
79
+ 'protein_sequence_tokens': batch_tokens.long() # training data
80
+ }
81
+
82
+ return sequence_inputs
83
+
84
+
85
+ def __getitem__(self, idx: torch.Tensor) -> (
86
+ dict,
87
+ dict
88
+ ):
89
+
90
+ protein_sequence = self.protein_sequence_list[idx]
91
+ text_captions = self.text_captions_list[idx]
92
+ accession_id = self.accession_id_list[idx]
93
+
94
+ # prepare protein sequence in ESM format (e.g. tuple: (header, sequence)):
95
+ batch_sequences = [
96
+ (accession_id, protein_sequence)
97
+ ]
98
+
99
+ text_data = self.caption_tokenizer(batch_captions=[text_captions])
100
+ protein_data = self.protein_tokenizer(batch_sequences=batch_sequences)
101
+
102
+ return (
103
+ text_data['input_ids'],
104
+ protein_data['protein_sequence_tokens']
105
+ )
106
+
107
+ def __len__(self):
108
+ return self.length
109
+
110
+
111
+ ######################
112
+ # Default DataModule #
113
+ ######################
114
+
115
+
116
+ class Default_DataModule(LightningDataModule):
117
+ def __init__(self, args):
118
+ super().__init__()
119
+ self.args = args
120
+
121
+ # construct dataset iterator
122
+ dataset_options = {
123
+ 'default': TextSeqPairing_Dataset,
124
+ 'masked': MaskTextSeqPairing_Dataset,
125
+ 'pfam': Pfam_TextSeqPairing_Dataset,
126
+ 'pfam_ablated': Pfam_TextSeqPairing_Dataset
127
+ }
128
+
129
+ self.dataset_class = dataset_options.get(args.dataset_type, TextSeqPairing_Dataset)
130
+
131
+ def prepare_data(self):
132
+ pass
133
+
134
+ def setup(self, stage=None):
135
+
136
+ if self.trainer is not None:
137
+ print(f"Number of GPUs: {self.trainer.world_size}")
138
+ print(f"Current GPU index: {self.trainer.local_rank}")
139
+
140
+ # Load Swiss-Prot data
141
+ df = self.load_swiss_prot()
142
+
143
+ # Split the dataframe into train and valid sets
144
+ train_df, valid_df = train_test_split(
145
+ df,
146
+ test_size=self.args.valid_size,
147
+ random_state=self.args.seed
148
+ )
149
+
150
+ print(f"Available memory after loading data: {psutil.virtual_memory().available / (1024 ** 3):.2f} GB")
151
+
152
+ # Define datasets and dataloaders
153
+ self.train_dataset = self.dataset_class(args=self.args, df=train_df)
154
+ self.valid_dataset = self.dataset_class(args=self.args, df=valid_df)
155
+
156
+ def load_swiss_prot(self) -> pd.Series:
157
+ # Load and preprocess data (called on each GPU/TPU in DDP)
158
+ print(f'Load Swiss-Prot data...')
159
+
160
+ # Load Swiss-Prot data
161
+ df = pd.read_csv(os.path.expanduser(self.args.data_path))
162
+ df = df[df['protein_sequence'].apply(lambda seq: len(seq) <= 1022)]
163
+
164
+ return df
165
+
166
+ def train_dataloader(self):
167
+ return DataLoader(
168
+ self.train_dataset,
169
+ batch_size=self.args.batch_size,
170
+ num_workers=self.args.num_workers,
171
+ shuffle=True,
172
+ pin_memory=True
173
+ )
174
+
175
+ def val_dataloader(self):
176
+ return DataLoader(
177
+ self.valid_dataset,
178
+ batch_size=self.args.batch_size,
179
+ num_workers=self.args.num_workers,
180
+ pin_memory=True
181
+ )
182
+
183
+ def test_dataloader(self):
184
+ # Define test dataloader if needed
185
+ pass
186
+
187
+
188
+
189
+ ################################
190
+ # Facilitator Dataset Iterator #
191
+ ################################
192
+
193
+
194
+ class Facilitator_Dataset(Dataset):
195
+
196
+ def __init__(self, args: any, dataset: dict):
197
+
198
+ # Determine the device based on the number of GPUs
199
+ device = 'cuda' if args.num_gpus >= 1 else 'cpu'
200
+
201
+ # Check if text_embeddings is a list and convert to a tensor
202
+ if isinstance(dataset['text_embedding'], list):
203
+ # Convert list elements to tensors if they are not already
204
+ text_emb_tensors = [torch.tensor(emb).to(device) if not isinstance(emb, torch.Tensor) else emb.to(device) for emb in dataset['text_embedding']]
205
+ # Stack the list of tensors
206
+ self.text_embeddings = torch.stack(text_emb_tensors)
207
+ else:
208
+ self.text_embeddings = dataset['text_embedding'].to(device)
209
+
210
+ # Check if protein_embeddings is a list and convert to a tensor
211
+ if isinstance(dataset['protein_embedding'], list):
212
+ # Convert list elements to tensors if they are not already
213
+ protein_emb_tensors = [torch.tensor(emb).to(device) if not isinstance(emb, torch.Tensor) else emb.to(device) for emb in dataset['protein_embedding']]
214
+ # Stack the list of tensors
215
+ self.protein_embeddings = torch.stack(protein_emb_tensors)
216
+ else:
217
+ self.protein_embeddings = dataset['protein_embedding'].to(device)
218
+
219
+
220
+ def __getitem__(self, idx: torch.Tensor) -> (
221
+ torch.Tensor,
222
+ torch.Tensor
223
+ ):
224
+
225
+
226
+ z_t = self.text_embeddings[idx]
227
+ z_p = self.protein_embeddings[idx]
228
+
229
+ return (
230
+ z_t,
231
+ z_p
232
+ )
233
+
234
+
235
+ def __len__(self):
236
+ return len(self.text_embeddings)
237
+
238
+ ###########################
239
+ # Facilitator Data Module #
240
+ ###########################
241
+
242
+
243
+
244
+ class Facilitator_DataModule(LightningDataModule):
245
+ def __init__(self, args):
246
+ super().__init__()
247
+
248
+ self.args = args
249
+
250
+ self.OOD_pfam_labels = [
251
+ 'PF18369', # Polyketide synthase dimerisation element domain
252
+ 'PF04680', # Opioid growth factor receptor repeat
253
+ 'PF17988', # VEGFR-2 Transmembrane domain
254
+ 'PF12325', # TATA element modulatory factor 1 TATA binding
255
+ 'PF03272', # Putative mucin or carbohydrate-binding module
256
+ 'PF03938', # Outer membrane protein (OmpH-like)
257
+ 'PF17724', # Family of unknown function (DUF5568)
258
+ 'PF10696', # Protein of unknown function
259
+ 'PF11968', # 25S rRNA (adenine(2142)-N(1))-methyltransferase, Bmt2
260
+ 'PF04153' # NOT2/NOT3/NOT5 C-terminal
261
+ ]
262
+
263
+
264
+ # prepare embeddings
265
+ #self.embedding_data = torch.load(args.swissprot_data_path)
266
+ # dataset iterator
267
+ #dataset = Facilitator_Dataset(args=args, dataset=self.embedding_data)
268
+ # create a clone of the dataset
269
+ #cloned_dataset = copy.deepcopy(dataset)
270
+
271
+ # Get indices and split them
272
+ #indices = list(range(len(dataset)))
273
+ #train_indices, valid_indices = train_test_split(indices, test_size=args.valid_size, random_state=args.seed)
274
+
275
+ # create full dataloader
276
+ #self.all_dataloader = DataLoader(cloned_dataset, batch_size=args.batch_size, shuffle=False)
277
+
278
+ # Create PyTorch DataLoader using the indices
279
+ #self.train_sampler = Subset(dataset, train_indices)
280
+ #self.valid_sampler = Subset(dataset, valid_indices)
281
+ #train_dataloader = DataLoader(train_sampler, batch_size=args.batch_size, shuffle=True)
282
+ #valid_dataloader = DataLoader(test_sampler, batch_size=args.batch_size, shuffle=False)
283
+
284
+ ##########################################
285
+ # Load Stage 1 SwissProt+Pfam Embeddings #
286
+ ##########################################
287
+
288
+ # initialize the embedding data to None
289
+ self.swissprot_data, self.pfam_data = None, None
290
+
291
+ # get both the swissprot and pfam dataset iterator in one
292
+ if (args.swissprot_data_path != 'None') and (args.pfam_data_path != 'None'):
293
+ print('Load both SwissProt and Pfam dataset...')
294
+ self.train_dataset, self.valid_dataset, self.all_swiss_dataloader, self.all_pfam_dataloader = self.load_both()
295
+
296
+ # get the swissprot dataset iterator
297
+ elif args.pfam_data_path == 'None':
298
+ print('Load SwissProt dataset...')
299
+ self.train_dataset, self.valid_dataset, self.all_swiss_dataloader = self.load_swissprot()
300
+ self.all_pfam_dataloader = None
301
+
302
+ # get the pfam dataset iterator
303
+ elif args.swissprot_data_path == 'None':
304
+ print('Load Pfam dataset...')
305
+ self.train_dataset, self.valid_dataset, self.all_pfam_dataloader = self.load_pfam()
306
+ self.all_swiss_dataloader = None
307
+
308
+
309
+
310
+ def load_swissprot(self):
311
+
312
+ # prepare embeddings
313
+ self.swissprot_data = torch.load(self.args.swissprot_data_path)
314
+
315
+ # dataset iterator
316
+ swiss_dataset = Facilitator_Dataset(args=self.args, dataset=self.swissprot_data)
317
+ # create a clone of the dataset
318
+ cloned_swiss_dataset = copy.deepcopy(swiss_dataset)
319
+
320
+ # Get indices and split them
321
+ indices = list(range(len(swiss_dataset)))
322
+ train_indices, valid_indices = train_test_split(indices, test_size=self.args.valid_size, random_state=self.args.seed)
323
+
324
+ # Create Pytorch iterator using the indices
325
+ swiss_train_subset = Subset(swiss_dataset, train_indices)
326
+ swiss_valid_subset = Subset(swiss_dataset, valid_indices)
327
+
328
+ # Create Pytorch dataloader on all samples
329
+ swiss_all_dataloader = DataLoader(cloned_swiss_dataset, batch_size=self.args.batch_size, shuffle=False)
330
+
331
+
332
+ return (
333
+ swiss_train_subset,
334
+ swiss_valid_subset,
335
+ swiss_all_dataloader
336
+ )
337
+
338
+
339
+ def load_pfam(self):
340
+
341
+ # prepare embeddings
342
+ self.pfam_data = torch.load(self.args.pfam_data_path)
343
+
344
+ # dataset iterator
345
+ pfam_dataset = Facilitator_Dataset(args=self.args, dataset=self.pfam_data)
346
+ # create a clone of the dataset
347
+ cloned_pfam_dataset = copy.deepcopy(pfam_dataset)
348
+
349
+ # Get indices and split them
350
+ indices = list(range(len(pfam_dataset)))
351
+ train_indices, valid_indices = train_test_split(indices, test_size=self.args.valid_size, random_state=self.args.seed)
352
+
353
+ # Create Pytorch Dataloader using the indices
354
+ pfam_train_subset = Subset(pfam_dataset, train_indices)
355
+ pfam_valid_subset = Subset(pfam_dataset, valid_indices)
356
+
357
+ # Create Pytorch dataloader on all samples
358
+ pfam_all_dataloader = DataLoader(cloned_pfam_dataset, batch_size=self.args.batch_size, shuffle=False)
359
+
360
+ return (
361
+ pfam_train_subset,
362
+ pfam_valid_subset,
363
+ pfam_all_dataloader
364
+ )
365
+
366
+
367
+ def load_both(self):
368
+
369
+ # get swissprot
370
+ swissprot_train_subset, swissprot_valid_subset, swissprot_all_dataloader = self.load_swissprot()
371
+
372
+ # get pfam
373
+ pfam_train_subset, pfam_valid_subset, pfam_all_dataloader = self.load_pfam()
374
+
375
+ # combined subsets
376
+ combined_train_subset = ConcatDataset([swissprot_train_subset, pfam_train_subset])
377
+ combined_valid_subset = ConcatDataset([swissprot_valid_subset, pfam_valid_subset])
378
+
379
+ return (
380
+ combined_train_subset,
381
+ combined_valid_subset,
382
+ swissprot_all_dataloader,
383
+ pfam_all_dataloader
384
+ )
385
+
386
+
387
+ def train_dataloader(self):
388
+ return DataLoader(
389
+ self.train_dataset,
390
+ #self.train_sampler,
391
+ batch_size=self.args.batch_size,
392
+ #num_workers=self.args.num_workers,
393
+ shuffle=True,
394
+ #pin_memory=True
395
+ )
396
+
397
+ def val_dataloader(self):
398
+ return DataLoader(
399
+ self.valid_dataset,
400
+ #self.valid_sampler,
401
+ batch_size=self.args.batch_size,
402
+ #num_workers=self.args.num_workers,
403
+ #pin_memory=True
404
+ )
405
+
406
+ def test_dataloader(self):
407
+ # Define test dataloader if needed
408
+ pass
409
+
410
+
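A minimal wiring sketch for the facilitator data module; the path and values below are placeholders mirroring the fields Facilitator_Dataset and Facilitator_DataModule read above, and the .pt file is assumed to hold a dict with 'text_embedding' and 'protein_embedding' entries:

from argparse import Namespace

args = Namespace(
    swissprot_data_path='swissprot_embeddings.pt',  # placeholder embedding dump
    pfam_data_path='None',                          # skip the Pfam embeddings
    valid_size=0.2, seed=42, batch_size=80, num_gpus=0,
)

dm = Facilitator_DataModule(args)
z_t, z_p = next(iter(dm.train_dataloader()))   # paired text / protein embeddings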
stage1_config.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "data_path": "None",
3
+ "pfam_data_path": "None",
4
+ "tb_logger_path": "None",
5
+ "tb_logger_folder": "None",
6
+ "version_name": "None",
7
+ "model_checkpoint_path": "/project/andrewferguson/niksapraljak/Project_ProtARDM/logs/Stage1_final_models/checkpoints/Pretraining_PENCiL_45M/epoch=19-step=116600.ckpt",
8
+ "output_dict_path": "/project/ranganathanr/niksapraljak/BioM3_PDZ/outputs/output_dict.pt",
9
+ "valid_size": 0.2,
10
+ "epochs": 10,
11
+ "acc_grad_batches": 1,
12
+ "batch_size": 80,
13
+ "num_workers": 12,
14
+ "weight_decay": "5e-7",
15
+ "patience": 1,
16
+ "factor": 0.8,
17
+ "temperature": 0.8,
18
+ "seed": 42,
19
+ "num_gpus": 1,
20
+ "precision": "16",
21
+ "dataset_type": "default",
22
+ "model_type": "pfam",
23
+ "fast_dev_run": 0,
24
+ "sequence_keyword": "protein_sequence",
25
+ "id_keyword": "primary_Accession",
26
+ "dataset_source": "swissprot",
27
+ "pfam_data_split_label": "0",
28
+ "base_lr": 0.0016,
29
+ "global_batch_size": 80,
30
+ "lr": 0.0005,
31
+ "seq_model_path": "/project/ranganathanr/niksapraljak/TextDiff_model_weights/Stage_1/pretrained_models/esm2_t33_650M_UR50D.pt",
32
+ "pretrained_seq": true,
33
+ "trainable_seq": true,
34
+ "rep_layer": 33,
35
+ "protein_encoder_embedding": 1280,
36
+ "protein_encoder_lr": 0.0005,
37
+ "pLM_n_layers_to_finetune": 1,
38
+ "text_model_path": "/project/ranganathanr/niksapraljak/TextDiff_model_weights/Stage_1/pretrained_models/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
39
+ "pretrained_text": true,
40
+ "trainable_text": true,
41
+ "text_encoder_embedding": 768,
42
+ "text_encoder_lr": 0.0005,
43
+ "text_max_length": 512,
44
+ "bLM_n_layers_to_finetune": 1,
45
+ "proj_embedding_dim": 512,
46
+ "dropout": 0.1,
47
+ "head_lr": 0.0005,
48
+ "inference_data_path": "/project/ranganathanr/niksapraljak/BioM3_PDZ/data/test_prompts_PDZ_swissprot_pfam_dataset.csv",
49
+ "inference_output_path": "/project/ranganathanr/niksapraljak/BioM3_PDZ/outputs/Stage1_test_prompts_PDZ.pt"
50
+ }
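For reference, a config like this is typically loaded into an attribute-style namespace before being handed to the modules above; a minimal loading sketch (the Namespace round-trip is an assumed convenience, not an interface defined in this commit):

import json
from argparse import Namespace

with open('stage1_config.json') as f:
    args = Namespace(**json.load(f))

print(args.seq_model_path, args.text_model_path, args.proj_embedding_dim)
# note: a few numeric fields are stored as strings here (e.g. "weight_decay": "5e-7",
# "precision": "16"), so they may need casting before use.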