omarperacha committed
Commit 16bd580 · Parent(s): 6af3edb

work towards gen emb

Files changed (3)
  1. .gitignore +2 -1
  2. app.py +9 -3
  3. ps4_data/get_embeddings.py +110 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
 .DS_Store
-.idea/
+.idea/
+ps4_data/__pycache__/
app.py CHANGED
@@ -1,11 +1,17 @@
 import gradio as gr
 from ps4_models.classifiers import *
+from ps4_data.get_embeddings import generate_embedings
 
 
-def pred(seq):
+def pred(residue_seq):
+    generate_embedings(residue_seq)
     model = PS4_Mega()
-    return "Hello " + seq + "!!"
+    return "Hello " + residue_seq + "!!"
 
 
-iface = gr.Interface(fn=pred, inputs="text", outputs="text")
+iface = gr.Interface(fn=pred, title="Protein Secondary Structure Prediction with PS4-Mega",
+                     inputs="text", outputs="text", examples=[
+    ["HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA"],
+    ["AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA"]
+])
 iface.launch()
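
For a quick local check of what the new pred() wires together, the same steps can be run outside Gradio. A minimal sketch, assuming the repo root is the working directory, that PS4_Mega is importable by name from ps4_models.classifiers, and that the ProtT5 weights can be downloaded on first run:

# Hypothetical local smoke test mirroring pred(); app.py itself is not imported
# here because it calls iface.launch() at module level.
from ps4_models.classifiers import PS4_Mega
from ps4_data.get_embeddings import generate_embedings

seq = "HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA"
generate_embedings(seq)   # writes per-residue ProtT5 embeddings under ps4_data/data/protT5/
model = PS4_Mega()        # instantiated but not yet fed the embeddings in this commit
print("Hello " + seq + "!!")  # pred() still returns this placeholder string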
ps4_data/get_embeddings.py ADDED
@@ -0,0 +1,110 @@
+from transformers import T5EncoderModel, T5Tokenizer
+import torch
+import numpy as np
+import time
+import os
+
+
+def generate_embedings(input_seq, output_path=None):
+
+    # Create directories
+    protT5_path = "ps4_data/data/protT5"
+    # where to store the embeddings
+    per_residue_path = "ps4_data/data/protT5/output/per_residue_embeddings" if output_path is None else output_path
+    for dir_path in [protT5_path, per_residue_path]:
+        __create_dir(dir_path)
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    print("Using {}".format(device))
+
+    # Load the encoder part of ProtT5-XL-U50 in half-precision (recommended)
+    model, tokenizer = __get_T5_model(device)
+
+    # Load fasta.
+    all_seqs = {"0": input_seq}
+
+    chunk_size = 1000
+
+    # Compute embeddings and/or secondary structure predictions
+    for i in range(0, len(all_seqs), chunk_size):
+        keys = list(all_seqs.keys())[i: chunk_size + i]
+        seqs = {k: all_seqs[k] for k in keys}
+        results = __get_embeddings(model, tokenizer, seqs, device)
+
+        # Store per-residue embeddings
+        __save_embeddings(results["residue_embs"], per_residue_path + f"{i}.npz")
+
+
+def __get_T5_model(device):
+
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
+    model = model.to(device)  # move model to GPU
+    model = model.eval()  # set model to evaluation mode
+    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
+
+    return model, tokenizer
+
+
+def __save_embeddings(emb_dict, out_path):
+    np.savez_compressed(out_path, **emb_dict)
+
+
+def __get_embeddings(model, tokenizer, seqs, device, per_residue=True,
+                     max_residues=4000, max_seq_len=1000, max_batch=100):
+
+    results = {"residue_embs": dict(),
+               "protein_embs": dict(),
+               "sec_structs": dict()
+               }
+
+    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
+    seq_dict = sorted(seqs.items(), key=lambda kv: len(seqs[kv[0]]), reverse=True)
+    start = time.time()
+    batch = list()
+    for seq_idx, (pdb_id, seq) in enumerate(seq_dict, 1):
+        seq = seq
+        seq_len = len(seq)
+        seq = ' '.join(list(seq))
+        batch.append((pdb_id, seq, seq_len))
+
+        # count residues in current batch and add the last sequence length to
+        # avoid that batches with (n_res_batch > max_residues) get processed
+        n_res_batch = sum([s_len for _, _, s_len in batch]) + seq_len
+        if len(batch) >= max_batch or n_res_batch >= max_residues or seq_idx == len(seq_dict) or seq_len > max_seq_len:
+            pdb_ids, seqs, seq_lens = zip(*batch)
+            batch = list()
+
+            # add_special_tokens adds extra token at the end of each sequence
+            token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
+            input_ids = torch.tensor(token_encoding['input_ids']).to(device)
+            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
+
+            try:
+                with torch.no_grad():
+                    # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
+                    embedding_repr = model(input_ids, attention_mask=attention_mask)
+            except RuntimeError:
+                print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
+                continue
+
+            for batch_idx, identifier in enumerate(pdb_ids):  # for each protein in the current mini-batch
+                s_len = seq_lens[batch_idx]
+                # slice off padding --> batch-size x seq_len x embedding_dim
+                emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
+                if per_residue:  # store per-residue embeddings (Lx1024)
+                    results["residue_embs"][identifier] = emb.detach().cpu().numpy().squeeze()
+    print("emb_count:", len(results["residue_embs"]))
+
+    passed_time = time.time() - start
+    avg_time = passed_time / len(results["residue_embs"]) if per_residue else passed_time / len(results["protein_embs"])
+    print('\n############# EMBEDDING STATS #############')
+    print('Total number of per-residue embeddings: {}'.format(len(results["residue_embs"])))
+    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
+        passed_time / 60, avg_time))
+    print('\n############# END #############')
+    return results
+
+
+def __create_dir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
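
With the default output_path, the embeddings for a single input sequence land in a compressed NumPy archive whose name follows from per_residue_path + f"{i}.npz" with i == 0. A minimal sketch of reading them back; the file path and the (L, 1024) shape are inferred from the code above, not stated by the commit itself:

# Hedged sketch: load the per-residue embeddings written by generate_embedings().
import numpy as np

emb_file = "ps4_data/data/protT5/output/per_residue_embeddings0.npz"  # default location, i == 0
data = np.load(emb_file)
emb = data["0"]        # key "0" comes from all_seqs = {"0": input_seq}
print(emb.shape)       # expected (sequence_length, 1024) ProtT5 features per residue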