Upload 21 files
Browse files- bin/README.md +1 -0
- demo/__init__.py +0 -0
- demo/config.yaml +49 -0
- demo/modules/__init__.py +19 -0
- demo/modules/blocks.py +66 -0
- demo/modules/compute_score.py +127 -0
- demo/modules/init_model.py +118 -0
- demo/modules/search.py +304 -0
- demo/modules/tmalign.py +78 -0
- demo/run.py +22 -0
- model/ProTrek/protein_encoder.py +95 -0
- model/ProTrek/protrek_trimodal_model.py +874 -0
- model/ProTrek/structure_encoder.py +86 -0
- model/ProTrek/text_encoder.py +81 -0
- model/abstract_model.py +401 -0
- model/model_interface.py +104 -0
- utils/constants.py +54 -0
- utils/downloader.py +180 -0
- utils/foldseek_util.py +121 -0
- utils/lr_scheduler.py +187 -0
- utils/mpr.py +397 -0
bin/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Place the Foldseek binary file here (the demo also runs `bin/TMalign`, so place the TMalign binary here as well)
|
demo/__init__.py
ADDED
File without changes
|
demo/config.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_dir: /sujin/Models/ProTrek/ProTrek_650M_UniRef50
|
2 |
+
faiss_config:
|
3 |
+
IO_FLAG_MMAP: True
|
4 |
+
sequence_index_dir:
|
5 |
+
- name: UniRef50
|
6 |
+
index_dir: /mnt/5t/faiss_index/UniRef50/ProTrek_650M_UniRef50/sequence
|
7 |
+
- name: Swiss-Prot
|
8 |
+
index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/sequence
|
9 |
+
- name: PDB
|
10 |
+
index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/sequence
|
11 |
+
- name: Uncharacterized
|
12 |
+
index_dir: /mnt/5t/faiss_index/Uncharacterized/ProTrek_650M_UniRef50/sequence
|
13 |
+
|
14 |
+
structure_index_dir:
|
15 |
+
- name: Swiss-Prot
|
16 |
+
index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/structure
|
17 |
+
- name: PDB
|
18 |
+
index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/structure
|
19 |
+
|
20 |
+
text_index_dir:
|
21 |
+
- name: UniProt
|
22 |
+
index_dir: /mnt/5t/faiss_index/UniRef50/ProTrek_650M_UniRef50/text
|
23 |
+
- name: Swiss-Prot
|
24 |
+
index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/text
|
25 |
+
|
26 |
+
#model_dir: /sujin/Models/ProTrek/ProTrek_35M_UniRef50
|
27 |
+
#
|
28 |
+
#faiss_config:
|
29 |
+
# IO_FLAG_MMAP: True
|
30 |
+
#
|
31 |
+
#sequence_index_dir:
|
32 |
+
## - name: UniRef50
|
33 |
+
## index_dir: /sujin/Datasets/ProTrek/faiss_index/UniRef50/ProTrek_650M_UniRef50/sequence
|
34 |
+
# - name: Swiss-Prot
|
35 |
+
# index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/sequence
|
36 |
+
## - name: PDB
|
37 |
+
## index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/sequence
|
38 |
+
#
|
39 |
+
#structure_index_dir:
|
40 |
+
# - name: Swiss-Prot
|
41 |
+
# index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/structure
|
42 |
+
## - name: PDB
|
43 |
+
## index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/structure
|
44 |
+
#
|
45 |
+
#text_index_dir:
|
46 |
+
## - name: UniProt
|
47 |
+
## index_dir: /sujin/Datasets/ProTrek/faiss_index/UniRef50/ProTrek_650M_UniRef50/text
|
48 |
+
# - name: Swiss-Prot
|
49 |
+
# index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/text
|
demo/modules/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
sys.path += []
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
|
8 |
+
def main():
    """Entry-point placeholder; it currently performs no work."""
    return None
|
10 |
+
|
11 |
+
|
12 |
+
def get_args():
    """Parse the command line with an option-less ``ArgumentParser``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    return argparse.ArgumentParser().parse_args()
|
15 |
+
|
16 |
+
|
17 |
+
# Script entry point: parse the command line, then run main().
if __name__ == '__main__':
    # The parsed namespace is stored but main() is called without it.
    args = get_args()
    main()
|
demo/modules/blocks.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from utils.foldseek_util import get_struc_seq
|
4 |
+
|
5 |
+
|
6 |
+
####################################################
|
7 |
+
# gradio blocks #
|
8 |
+
####################################################
|
9 |
+
def upload_pdb_button(visible: bool = True, chain_visible: bool = True):
    """
    Build a column holding a chain-id textbox and a structure upload button.

    Args:
        visible: Whether the upload button is visible or not
        chain_visible: Whether the chain textbox is visible or not

    Returns:
        Tuple ``(upload button, chain textbox)``
    """
    with gr.Column(scale=0):
        # Chain id to extract; pre-filled with "A" and editable by the user
        chain_textbox = gr.Textbox(
            label="Chain (to be extracted from the pdb file)",
            value="A",
            visible=chain_visible,
            interactive=True,
        )

        upload_button = gr.UploadButton(label="Upload .pdb/.cif file", visible=visible)

    return upload_button, chain_textbox
|
25 |
+
|
26 |
+
|
27 |
+
####################################################
|
28 |
+
# Trigger functions #
|
29 |
+
####################################################
|
30 |
+
def parse_pdb_file(input_type: str, file: str, chain: str) -> str:
    """
    Parse the uploaded structure file and return one chain as text.

    Args:
        input_type: Requested modality. "sequence" selects the first element
            returned for the chain; any other value selects the second element,
            lower-cased.

        file: Path to the uploaded file

        chain: Chain to be extracted from the pdb file

    Returns:
        Protein sequence or (lower-cased) Foldseek sequence

    Raises:
        gr.Error: If the chain is missing from the parsed result, or if parsing
            fails for any other reason.
    """
    try:
        parsed_seqs = get_struc_seq("bin/foldseek", file, [chain])[chain]

    except KeyError as e:
        # The parsed result has no entry for the requested chain.
        raise gr.Error(
            f"Chain '{chain}' not found in the pdb file. Please check the chain id and try again."
        ) from e

    except Exception as e:
        # Any other failure is not a chain-id problem (presumably a missing
        # Foldseek binary or an unreadable file - TODO confirm), so report the
        # underlying error instead of blaming the chain id.
        raise gr.Error(
            f"Failed to parse the uploaded structure file (chain '{chain}'): {e}"
        ) from e

    if input_type == "sequence":
        return parsed_seqs[0]

    return parsed_seqs[1].lower()
|
53 |
+
|
54 |
+
|
55 |
+
def set_upload_visible(visible: bool) -> gr.Interface:
    """
    Produce a component update that toggles the upload button's visibility.

    Args:
        visible: Whether the block is visible or not

    Returns:
        The object returned by ``gr.update(visible=visible)``
    """
    visibility_update = gr.update(visible=visible)
    return visibility_update
|
demo/modules/compute_score.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .init_model import model
|
5 |
+
from .blocks import upload_pdb_button, parse_pdb_file
|
6 |
+
|
7 |
+
|
8 |
+
input_types = ["sequence", "structure", "text"]
|
9 |
+
|
10 |
+
input_examples = {
|
11 |
+
"sequence": [
|
12 |
+
"MQLQRLGAPLLKRLVGGCIRQSTAPIMPCVVVSGSGGFLTPVRTYMPLPNDQSDFSPYIEIDLPSESRIQSLHKSGLAAQEWVACEKVHGTNFGIYLINQGDHEVVRFAKRSGIMDPNENFFGYHILIDEFTAQIRILNDLLKQKYGLSRVGRLVLNGELFGAKYKHPLVPKSEKWCTLPNGKKFPIAGVQIQREPFPQYSPELHFFAFDIKYSVSGAEEDFVLLGYDEFVEFSSKVPNLLYARALVRGTLDECLAFDVENFMTPLPALLGLGNYPLEGNLAEGVVIRHVRRGDPAVEKHNVSTIIKLRCSSFMELKHPGKQKELKETFIDTVRSGALRRVRGNVTVISDSMLPQVEAAANDLLLNNVSDGRLSNVLSKIGREPLLSGEVSQVDVALMLAKDALKDFLKEVDSLVLNTTLAFRKLLITNVYFESKRLVEQKWKELMQEEAAAQSEAIPPLSPAAPTKGE",
|
13 |
+
"MSLSTEQMLRDYPRSMQINGQIPKNAIHETYGNDGVDVFIAGSGPIGATYAKLCVEAGLRVVMVEIGAADSFYAVNAEEGTAVPYVPGYHKKNEIEFQKDIDRFVNVIKGALQQVSVPVRNQNVPTLDPGAWSAPPGSSAISNGKNPHQREFENLSAEAVTRGVGGMSTHWTCSTPRIHPPMESLPGIGRPKLSNDPAEDDKEWNELYSEAERLIGTSTKEFDESIRHTLVLRSLQDAYKDRQRIFRPLPLACHRLKNAPEYVEWHSAENLFHSIYNDDKQKKLFTLLTNHRCTRLALTGGYEKKIGAAEVRNLLATRNPSSQLDSYIMAKVYVLASGAIGNPQILYNSGFSGLQVTPRNDSLIPNLGRYITEQPMAFCQIVLRQEFVDSVRDDPYGLPWWKEAVAQHIAKNPTDALPIPFRDPEPQVTTPFTEEHPWHTQIHRDAFSYGAVGPEVDSRVIVDLRWFGATDPEANNLLVFQNDVQDGYSMPQPTFRYRPSTASNVRARKMMADMCEVASNLGGYLPTSPPQFMDPGLALHLAGTTRIGFDKATTVADNNSLVWDFANLYVAGNGTIRTGFGENPTLTSMCHAIKSARSIINTLKGGTDGKNTGEHRNL",
|
14 |
+
"MGVHECPAWLWLLLSLLSLPLGLPVLGAPPRLICDSRVLERYLLEAKEAENITTGCAEHCSLNENITVPDTKVNFYAWKRMEVGQQAVEVWQGLALLSEAVLRGQALLVNSSQPWEPLQLHVDKAVSGLRSLTTLLRALGAQKEAISPPDAASAAPLRTITADTFRKLFRVYSNFLRGKLKLYTGEACRTGDR"
|
15 |
+
],
|
16 |
+
|
17 |
+
"structure": [
|
18 |
+
"ddddddddddddddddddddddddddddddddpdpddpddpqpdddfddpdqqlddadddfaaddpvqvvlcvvvvvlqakkfkwfdadffkkkwkwadpdpdidifidtnvgtdglqpddllclvcvvlsvqlvvllqvvvcvvvvapafrmkmfiwgkdalddpfppadadpdwhagsvgdidgsvpgdrdddpaqhahsdiaietewiwiarnsdpvriqtafqvvvcvsqvprpphhyidgqfmggnllnlldpqqpaaqlrnqqvvnqvgddpprggqfikmfrrpprppvvcvsvrhgihtdghlvnvcvvdppcsvvcccnrcvprnvvscvvvvndhdtdvlsrhhpvlsvllvqllvlldpvllvvldvvvdlpclqvvvqdllnsllsslvvsvvvsvvpddpvnvpgdpvsvvvssvsssvsssvvsvvcvvvvnvvsvvvvvvvddppdpdddpddd",
|
19 |
+
"dpdplvvqppdddplqappppfaadpvcvlvdpvaaaeeeeaqallsllllllclvlvgfyeyefqaeqpdwdddpddvpdddftqtqfapcqppvclqpqqvllvvqvvfwdwqeaefdqpppvpddppddhddppdgdddqqhdppfdpqqdlgqatwgghrntcqnhdpqfddawadadpvahqgtfdaldpdpvvrvvlvvvllvvlcvqlvkdqclqvpflqqcllqvllcvvcvvppwhkgggtgswhadpvhsldirhttsssscvvqrvdpssvssydyhyskhqqewhaghdpfgetawtkiarnccvvpvpdrgihigghrfyeypralprvllrcvssvqalqdpggdprhnqdqffalkwfwwkkkfkfffdpvsqvcqcvppppdpssnvqlvvqcvvcvpdpgsgdssrakhfmwtdadpvqqktktwidghhndddddppddpsrmimimiihwafrdrqfgwgfdppgdhpvrttrihtrddgdpvsvvsvvvrlvvsvvssvstgdtdprgpididrrnsvnlieqrqaedddsvngqayqlqhgpsyphygyfdrnhrngigngdcvsvrssssvsnsvvsscvvvvdpdddppdddddd",
|
20 |
+
"ddppppdcvvvvvvvvvppppppvppldplvvlldvvllvvqlvllvvllvvcvvpdpnfflqdwqkafdlddpvvvvvpddlllllqlllvrlvsllvrlvsslvslvpdpdrdvvnnvssvvlnvssvvvnvssvslvsvvsnppddppprdddgdididrgssvssvsvssnsvgsvvvssvvssvvvvd"
|
21 |
+
],
|
22 |
+
|
23 |
+
"text": [
|
24 |
+
"RNA-editing ligase in kinetoplastid mitochondrial.",
|
25 |
+
"Oxidase which catalyzes the oxidation of various aldopyranoses and disaccharides.",
|
26 |
+
"Erythropoietin for regulation of erythrocyte proliferation and differentiation."
|
27 |
+
]
|
28 |
+
}
|
29 |
+
|
30 |
+
# Default example rows: each sequence example paired with the text example at the same position.
samples = [list(pair) for pair in zip(input_examples["sequence"], input_examples["text"])]
|
31 |
+
|
32 |
+
|
33 |
+
def compute_score(input_type_1: str, input_1: str, input_type_2: str, input_2: str):
    """
    Compute the similarity score between two inputs of possibly different modalities.

    Args:
        input_type_1: Modality of the first input ("sequence", "structure"; anything
            else is encoded as text)
        input_1: Content of the first input
        input_type_2: Modality of the second input (same convention)
        input_2: Content of the second input

    Returns:
        The score ``repr_1 @ repr_2.T / model.temperature`` formatted with 4 decimals
    """
    def encode(kind: str, content: str):
        # Pick the encoder of the shared model that matches the modality
        if kind == "sequence":
            return model.get_protein_repr([content])
        if kind == "structure":
            return model.get_structure_repr([content])
        return model.get_text_repr([content])

    with torch.no_grad():
        repr_1 = encode(input_type_1, input_1)
        repr_2 = encode(input_type_2, input_2)
        score = repr_1 @ repr_2.T / model.temperature

    return f"{score.item():.4f}"
|
50 |
+
|
51 |
+
|
52 |
+
def change_input_type(choice_1: str, choice_2: str):
    """
    React to a change of either input-type dropdown.

    Rebuilds the module-level ``samples`` from the examples of the two chosen
    modalities, clears both input boxes and shows the upload widgets only for
    non-text inputs.

    Args:
        choice_1: Selected type of input 1
        choice_2: Selected type of input 2

    Returns:
        Tuple of 7 values: dataset update, two empty strings, then visibility
        updates for (upload button 1, chain box 1, upload button 2, chain box 2)
    """
    # Change examples if input type is changed
    global samples
    samples = [
        [first, second]
        for first, second in zip(input_examples[choice_1], input_examples[choice_2])
    ]

    # Upload widgets only make sense for sequence/structure inputs
    show_upload_1 = choice_1 != "text"
    show_upload_2 = choice_2 != "text"

    return (
        gr.update(samples=samples),
        "",
        "",
        gr.update(visible=show_upload_1),
        gr.update(visible=show_upload_1),
        gr.update(visible=show_upload_2),
        gr.update(visible=show_upload_2),
    )
|
73 |
+
|
74 |
+
|
75 |
+
# Load example from dataset
|
76 |
+
def load_example(example_id):
    """Return the example row at ``example_id`` from the module-level ``samples``."""
    selected_row = samples[example_id]
    return selected_row
|
78 |
+
|
79 |
+
|
80 |
+
# Build the block for computing protein-text similarity
|
81 |
+
def build_score_computation():
|
82 |
+
gr.Markdown(f"# Compute similarity score between two modalities")
|
83 |
+
with gr.Row(equal_height=True):
|
84 |
+
with gr.Column():
|
85 |
+
# Compute similarity score between sequence and text
|
86 |
+
with gr.Row():
|
87 |
+
input_1 = gr.Textbox(label="Input 1")
|
88 |
+
|
89 |
+
# Choose the type of input 1
|
90 |
+
input_type_1 = gr.Dropdown(input_types, label="Input type", value="sequence",
|
91 |
+
interactive=True, visible=True)
|
92 |
+
|
93 |
+
# Provide an upload button to upload a pdb file
|
94 |
+
upload_btn_1, chain_box_1 = upload_pdb_button(visible=True)
|
95 |
+
upload_btn_1.upload(parse_pdb_file, inputs=[input_type_1, upload_btn_1, chain_box_1], outputs=[input_1])
|
96 |
+
|
97 |
+
with gr.Row():
|
98 |
+
input_2 = gr.Textbox(label="Input 2")
|
99 |
+
|
100 |
+
# Choose the type of input 2
|
101 |
+
input_type_2 = gr.Dropdown(input_types, label="Input type", value="text",
|
102 |
+
interactive=True, visible=True)
|
103 |
+
|
104 |
+
# Provide an upload button to upload a pdb file
|
105 |
+
upload_btn_2, chain_box_2 = upload_pdb_button(visible=False)
|
106 |
+
upload_btn_2.upload(parse_pdb_file, inputs=[input_type_2, upload_btn_2, chain_box_2], outputs=[input_2])
|
107 |
+
|
108 |
+
# Provide examples
|
109 |
+
examples = gr.Dataset(samples=samples, type="index", components=[input_1, input_2], label="Input examples")
|
110 |
+
|
111 |
+
# Add click event to examples
|
112 |
+
examples.click(fn=load_example, inputs=[examples], outputs=[input_1, input_2])
|
113 |
+
|
114 |
+
compute_btn = gr.Button(value="Compute")
|
115 |
+
|
116 |
+
# Change examples based on input type
|
117 |
+
input_type_1.change(fn=change_input_type, inputs=[input_type_1, input_type_2],
|
118 |
+
outputs=[examples, input_1, input_2, upload_btn_1, chain_box_1,
|
119 |
+
upload_btn_2, chain_box_2])
|
120 |
+
|
121 |
+
input_type_2.change(fn=change_input_type, inputs=[input_type_1, input_type_2],
|
122 |
+
outputs=[examples, input_1, input_2, upload_btn_1, chain_box_1,
|
123 |
+
upload_btn_2, chain_box_2])
|
124 |
+
|
125 |
+
similarity_score = gr.Label(label="similarity score")
|
126 |
+
compute_btn.click(fn=compute_score, inputs=[input_type_1, input_1, input_type_2, input_2],
|
127 |
+
outputs=[similarity_score])
|
demo/modules/init_model.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import faiss
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import os
|
5 |
+
import yaml
|
6 |
+
import glob
|
7 |
+
|
8 |
+
from easydict import EasyDict
|
9 |
+
from utils.constants import sequence_level
|
10 |
+
from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
def _first_match(pattern: str) -> str:
    """Return the first path matching ``pattern``, or raise if nothing matches."""
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(
            f"No file or directory matches '{pattern}'. Check 'model_dir' in demo/config.yaml."
        )
    return matches[0]


def load_model():
    """
    Build the ProTrek tri-modal model from the files under ``config.model_dir``.

    The protein encoder config (``esm2_*``), the structure encoder config
    (``foldseek_*``) and the checkpoint (``*.pt``) are located by glob; the first
    match of each pattern is used.

    Returns:
        The model, switched to eval mode.

    Raises:
        FileNotFoundError: If one of the expected files/directories is missing
            (previously this surfaced as a bare IndexError).
    """
    model_config = {
        "protein_config": _first_match(f"{config.model_dir}/esm2_*"),
        "text_config": f"{config.model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "structure_config": _first_match(f"{config.model_dir}/foldseek_*"),
        "load_protein_pretrained": False,
        "load_text_pretrained": False,
        "from_checkpoint": _first_match(f"{config.model_dir}/*.pt"),
    }

    model = ProTrekTrimodalModel(**model_config)
    model.eval()
    return model
|
27 |
+
|
28 |
+
|
29 |
+
def load_faiss_index(index_path: str):
    """
    Read a FAISS index from disk and set its metric to inner product.

    Args:
        index_path: Path to the index file

    Returns:
        The loaded index; it is read with ``faiss.IO_FLAG_MMAP`` when
        ``config.faiss_config.IO_FLAG_MMAP`` is truthy.
    """
    use_mmap = config.faiss_config.IO_FLAG_MMAP
    read_args = (index_path, faiss.IO_FLAG_MMAP) if use_mmap else (index_path,)

    loaded_index = faiss.read_index(*read_args)
    loaded_index.metric_type = faiss.METRIC_INNER_PRODUCT
    return loaded_index
|
37 |
+
|
38 |
+
|
39 |
+
def _read_ids(id_path: str):
    """Read a header-less TSV of ids and return its values as a flat array."""
    return pd.read_csv(id_path, sep="\t", header=None).values.flatten()


def _load_db_entry(db, index_filename: str):
    """Load ``<index_dir>/<index_filename>`` plus ``<index_dir>/ids.tsv`` for one database."""
    index_dir = db["index_dir"]
    index = load_faiss_index(f"{index_dir}/{index_filename}")
    ids = _read_ids(f"{index_dir}/ids.tsv")
    return {"index": index, "ids": ids}


def load_index():
    """
    Load every FAISS index listed in the config.

    Returns:
        Tuple ``(all_index, valid_subsections)``:
        - ``all_index["sequence"][db]`` and ``all_index["structure"][db]`` are
          ``{"index", "ids"}`` dicts; ``all_index["text"][db][subsection]`` has
          the same shape, one entry per subsection whose index file exists.
        - ``valid_subsections[db]`` is the sorted list of subsections found for
          that text database.
    """
    all_index = {}

    # Load protein sequence index
    all_index["sequence"] = {}
    for db in tqdm(config.sequence_index_dir, desc="Loading sequence index..."):
        all_index["sequence"][db["name"]] = _load_db_entry(db, "sequence.index")

    # Load protein structure index
    print("Loading structure index...")
    all_index["structure"] = {}
    for db in tqdm(config.structure_index_dir, desc="Loading structure index..."):
        all_index["structure"][db["name"]] = _load_db_entry(db, "structure.index")

    # Load text index
    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        text_dir = f"{db['index_dir']}/subsections"
        all_index["text"][db_name] = {}
        valid_subsections[db_name] = set()

        # "Global" is indexed alongside the regular subsections. NOTE: this
        # mutates the imported ``sequence_level`` collection in place, exactly as
        # the original code did; subsections without an index file are skipped.
        sequence_level.add("Global")
        for subsection in tqdm(sequence_level):
            stem = subsection.replace(' ', '_')
            index_path = f"{text_dir}/{stem}.index"
            if not os.path.exists(index_path):
                continue

            text_index = load_faiss_index(index_path)
            text_ids = _read_ids(f"{text_dir}/{stem}_ids.tsv")

            all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids}
            valid_subsections[db_name].add(subsection)

    # Sort valid_subsections so the UI shows them in a stable order
    for db_name in valid_subsections:
        valid_subsections[db_name] = sorted(valid_subsections[db_name])

    return all_index, valid_subsections
|
101 |
+
|
102 |
+
|
103 |
+
# Load the config file
|
104 |
+
# Repository root = three levels above this file (<root>/demo/modules/init_model.py).
# Resolved from the absolute path so it is also correct when ``__file__`` is
# relative or uses a non-"/" separator (the former rsplit("/", 3) was not).
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
config_path = os.path.join(root_dir, "demo", "config.yaml")
with open(config_path, 'r', encoding='utf-8') as r:
    config = EasyDict(yaml.safe_load(r))

# The model is always placed on CUDA; there is no CPU fallback here.
device = "cuda"

print("Loading model...")
model = load_model()
model.to(device)

# Load all FAISS indexes once at import time; other modules import these globals.
all_index, valid_subsections = load_index()
print("Done...")
|
117 |
+
# model = None
|
118 |
+
# all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}
|
demo/modules/search.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import pandas as pd
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from scipy.stats import norm
|
8 |
+
from .init_model import model, all_index, valid_subsections
|
9 |
+
from .blocks import upload_pdb_button, parse_pdb_file
|
10 |
+
|
11 |
+
|
12 |
+
tmp_file_path = "/tmp/results.tsv"
|
13 |
+
tmp_plot_path = "/tmp/histogram.svg"
|
14 |
+
|
15 |
+
# Samples for input
|
16 |
+
samples = [
|
17 |
+
["Proteins with zinc bindings."],
|
18 |
+
["Proteins locating at cell membrane."],
|
19 |
+
["Protein that serves as an enzyme."]
|
20 |
+
]
|
21 |
+
|
22 |
+
# Databases for different modalities
|
23 |
+
# Currently selected database per modality; starts with the first database loaded for each.
now_db = {
    modality: list(all_index[modality].keys())[0]
    for modality in ("sequence", "structure", "text")
}
|
28 |
+
|
29 |
+
|
30 |
+
def clear_results():
    """Return an empty result text plus two updates that hide their target components."""
    cleared_text = ""
    hide_first = gr.update(visible=False)
    hide_second = gr.update(visible=False)
    return cleared_text, hide_first, hide_second
|
32 |
+
|
33 |
+
|
34 |
+
def plot(scores) -> None:
    """
    Plot the distribution of scores and fit a normal distribution.

    The figure is written to ``tmp_plot_path``. A dedicated figure is created
    and closed on every call, so nothing is drawn on (or left behind in)
    pyplot's implicit global figure.

    Args:
        scores: List of scores
    """
    fig, ax = plt.subplots()
    try:
        ax.hist(scores, bins=100, density=True, alpha=0.6)
        ax.set_title('Distribution of similarity scores in the database', fontsize=15)
        ax.set_xlabel('Similarity score', fontsize=15)
        ax.set_ylabel('Density', fontsize=15)

        mu, std = norm.fit(scores)

        # Plot the Gaussian over the x-range chosen by the histogram
        xmin, xmax = ax.get_xlim()
        _, ymax = ax.get_ylim()
        x = np.linspace(xmin, xmax, 100)
        p = norm.pdf(x, mu, std)
        ax.plot(x, p)

        # Plot total number of scores
        ax.text(xmax, 0.9 * ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

        # The output format follows the extension of tmp_plot_path
        fig.savefig(tmp_plot_path)
    finally:
        # Always release the figure, even if fitting or saving fails
        plt.close(fig)
|
60 |
+
|
61 |
+
|
62 |
+
# Search from database
|
63 |
+
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
    """
    Embed the query, search the currently selected database and format the hits.

    Args:
        input: Query content (sequence, structure string or text)
        nprobe: Number of clusters to probe; only applied when the index exposes ``nprobe``
        topk: Maximum number of hits to return
        input_type: Modality of the query ("sequence" is mapped to the model's "protein" encoder)
        query_type: Modality of the database to search
        subsection_type: Text subsection to search when ``query_type`` is "text"

    Returns:
        Tuple ``(markdown table, download button pointing at tmp_file_path,
        update showing the histogram at tmp_plot_path)``

    Raises:
        gr.Error: If ``nprobe`` exceeds the number of clusters or ``topk`` exceeds the index size
    """
    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    db = now_db[query_type]
    if query_type == "text":
        entry = all_index["text"][db][subsection_type]
    else:
        entry = all_index[query_type][db]
    index, ids = entry["index"], entry["ids"]

    if check_index_ivf(query_type, subsection_type):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Keep only entries scoring above -1 (presumably this drops the placeholder
    # rows FAISS returns when fewer than ntotal results are found - TODO confirm)
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]
    scores = scores / model.temperature.item()
    plot(scores)

    # After filtering there may be fewer than topk hits left
    top_scores = scores[:topk]
    top_ranks = ranks[:topk]

    # Write the results to a temporary file for downloading.
    # Iterate over the hits that actually exist rather than range(topk), which
    # raised IndexError whenever the filter left fewer than topk entries.
    with open(tmp_file_path, "w") as w:
        w.write("Id\tMatching score\n")
        for rank, score in zip(top_ranks, top_scores):
            w.write(f"{ids[rank]}\t{score}\n")

    # Get topk ids
    topk_ids = []
    for rank in top_ranks:
        now_id = ids[rank]
        if query_type == "text":
            topk_ids.append(now_id)
        elif db != "PDB":
            # Provide link to uniprot website
            topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
        else:
            # Provide link to pdb website
            pdb_id = now_id.split("-")[0]
            topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")

    # Only the first `limit` rows are rendered; the file holds all of them
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[limit])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))
|
137 |
+
|
138 |
+
|
139 |
+
def change_input_type(choice: str):
    """
    React to a change of the input-type radio.

    Replaces the module-level ``samples`` with the examples of the chosen
    modality (left untouched for an unrecognised choice), clears the input box
    and shows the upload widgets only for non-text input.

    Args:
        choice: Selected input type

    Returns:
        Tuple ``(dataset update, "", upload button update, chain box update)``
    """
    # Change examples if input type is changed
    global samples
    examples_by_type = {
        "text": [
            ["Proteins with zinc bindings."],
            ["Proteins locating at cell membrane."],
            ["Protein that serves as an enzyme."]
        ],
        "sequence": [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
        ],
        "structure": [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
        ],
    }
    if choice in examples_by_type:
        samples = examples_by_type[choice]

    # Set visibility of upload button
    show_upload = choice != "text"

    return gr.update(samples=samples), "", gr.update(visible=show_upload), gr.update(visible=show_upload)
|
170 |
+
|
171 |
+
|
172 |
+
# Load example from dataset
|
173 |
+
def load_example(example_id):
    """Return the first field of the example row at ``example_id`` in the module-level ``samples``."""
    selected_row = samples[example_id]
    return selected_row[0]
|
175 |
+
|
176 |
+
|
177 |
+
# Change the visibility of subsection type
|
178 |
+
def change_output_type(query_type: str, subsection_type: str):
    """
    React to a change of the output-type radio.

    Args:
        query_type: Selected output type
        subsection_type: Currently selected text subsection

    Returns:
        Tuple of updates for (subsection dropdown visibility, nprobe slider
        visibility, database dropdown choices and value)
    """
    show_nprobe = check_index_ivf(query_type, subsection_type)
    # The subsection dropdown is only meaningful when text is returned
    show_subsection = query_type == "text"

    subsection_update = gr.update(visible=show_subsection)
    nprobe_update = gr.update(visible=show_nprobe)
    db_update = gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type])

    return subsection_update, nprobe_update, db_update
|
187 |
+
|
188 |
+
|
189 |
+
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
    """
    Tell whether the currently selected index exposes an ``nprobe`` attribute.

    Args:
        index_type: Type of index.
        subsection_type: If the "index_type" is "text", get the index based on the subsection type.

    Returns:
        Whether the index is of IVF type or not (i.e. ``hasattr(index, "nprobe")``).
    """
    db = now_db[index_type]

    if index_type in ("sequence", "structure"):
        # Sequence and structure indexes are stored one per database
        index = all_index[index_type][db]["index"]

    elif index_type == "text":
        # Text indexes are stored per database and per subsection
        index = all_index["text"][db][subsection_type]["index"]

    return hasattr(index, "nprobe")
|
211 |
+
|
212 |
+
|
213 |
+
def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.

    Args:
        query_type: The output type.
        subsection_type: The text subsection selected before the change.
        db_type: The database to search.

    Returns:
        Tuple ``(subsection dropdown update, nprobe slider visibility update)``
    """
    now_db[query_type] = db_type

    if query_type == "text":
        choices = list(valid_subsections[now_db["text"]])
        # The dropdown is reset to "Function" when the new database has it,
        # otherwise to its first subsection (or nothing if it has none).
        if "Function" in choices:
            new_subsection = "Function"
        else:
            new_subsection = choices[0] if choices else None
        subsection_update = gr.update(choices=choices, value=new_subsection)

        # Check the index that will actually be selected after the update; the
        # previous subsection may not exist in the new database.
        if new_subsection is None:
            nprobe_visible = False
        else:
            nprobe_visible = check_index_ivf(query_type, new_subsection)
    else:
        subsection_update = gr.update(visible=False)
        nprobe_visible = check_index_ivf(query_type, subsection_type)

    return subsection_update, gr.update(visible=nprobe_visible)
|
229 |
+
|
230 |
+
|
231 |
+
# Build the searching block
|
232 |
+
def build_search_module():
|
233 |
+
gr.Markdown(f"# Search from Swiss-Prot database (the whole UniProt database will be supported soon)")
|
234 |
+
with gr.Row(equal_height=True):
|
235 |
+
with gr.Column():
|
236 |
+
# Set input type
|
237 |
+
input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")
|
238 |
+
|
239 |
+
with gr.Row():
|
240 |
+
# Set output type
|
241 |
+
query_type = gr.Radio(
|
242 |
+
["sequence", "structure", "text"],
|
243 |
+
label="Output type (e.g. 'sequence' means returning qualified sequences)",
|
244 |
+
value="sequence",
|
245 |
+
scale=2,
|
246 |
+
)
|
247 |
+
|
248 |
+
# If the output type is "text", provide an option to choose the subsection of text
|
249 |
+
subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
|
250 |
+
interactive=True, visible=False, scale=0)
|
251 |
+
|
252 |
+
db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
|
253 |
+
interactive=True, visible=True, scale=0)
|
254 |
+
|
255 |
+
with gr.Row():
|
256 |
+
# Input box
|
257 |
+
input = gr.Text(label="Input")
|
258 |
+
|
259 |
+
# Provide an upload button to upload a pdb file
|
260 |
+
upload_btn, chain_box = upload_pdb_button(visible=False)
|
261 |
+
upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])
|
262 |
+
|
263 |
+
|
264 |
+
# If the index is of IVF type, provide an option to choose the number of clusters.
|
265 |
+
nprobe_visible = check_index_ivf(query_type.value)
|
266 |
+
nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
|
267 |
+
label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
|
268 |
+
|
269 |
+
# Add event listener to output type
|
270 |
+
query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
|
271 |
+
outputs=[subsection_type, nprobe, db_type])
|
272 |
+
|
273 |
+
# Add event listener to db type
|
274 |
+
db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
|
275 |
+
outputs=[subsection_type, nprobe])
|
276 |
+
|
277 |
+
# Choose topk results
|
278 |
+
topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")
|
279 |
+
|
280 |
+
# Provide examples
|
281 |
+
examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")
|
282 |
+
|
283 |
+
# Add click event to examples
|
284 |
+
examples.click(fn=load_example, inputs=[examples], outputs=input)
|
285 |
+
|
286 |
+
# Change examples based on input type
|
287 |
+
input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])
|
288 |
+
|
289 |
+
with gr.Row():
|
290 |
+
search_btn = gr.Button(value="Search")
|
291 |
+
clear_btn = gr.Button(value="Clear")
|
292 |
+
|
293 |
+
with gr.Row():
|
294 |
+
with gr.Column():
|
295 |
+
results = gr.Markdown(label="results", height=450)
|
296 |
+
download_btn = gr.DownloadButton(label="Download results", visible=False)
|
297 |
+
|
298 |
+
# Plot the distribution of scores
|
299 |
+
histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
|
300 |
+
|
301 |
+
search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
|
302 |
+
outputs=[results, download_btn, histogram])
|
303 |
+
|
304 |
+
clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
|
demo/modules/tmalign.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
|
4 |
+
from .blocks import upload_pdb_button
|
5 |
+
from utils.downloader import download_pdb, download_af2
|
6 |
+
|
7 |
+
|
8 |
+
root_dir = __file__.rsplit("/", 3)[0]
|
9 |
+
structure_types = ["AlphaFoldDB", "PDB"]
|
10 |
+
|
11 |
+
|
12 |
+
def upload_structure(file: str):
    """
    Pass the path of an uploaded structure file straight through.

    Args:
        file: Path of the uploaded file

    Returns:
        The same path, unchanged
    """
    uploaded_path = file
    return uploaded_path
|
14 |
+
|
15 |
+
|
16 |
+
def get_structure_path(structure: str, structure_type: str) -> str:
    """
    Resolve a structure given by the user to a file path on disk.

    Args:
        structure: Either the path of a manually uploaded file (starts with "/") or an
                   ID of a structure to download

        structure_type: Database the ID belongs to, "AlphaFoldDB" or "PDB". Ignored if the
                        structure is manually uploaded

    Returns:
        Path of the structure file

    Raises:
        ValueError: If ``structure`` is empty, if the ID contains path components, or if
                    ``structure_type`` is not supported
    """
    if not structure or not structure.strip():
        raise ValueError("No structure was provided")

    # If the structure is manually uploaded
    if structure.startswith("/"):
        return structure

    # The ID becomes part of the cache file name, so it must not contain path components
    structure_id = structure.strip()
    if "/" in structure_id or "\\" in structure_id or ".." in structure_id:
        raise ValueError(f"Invalid structure ID: {structure_id!r}")

    # If the structure is a Uniprot ID, download the structure from AlphaFoldDB
    if structure_type == "AlphaFoldDB":
        file_format, download = "pdb", download_af2

    # If the structure is a PDB ID, download the structure from PDB
    elif structure_type == "PDB":
        file_format, download = "cif", download_pdb

    else:
        raise ValueError(f"Unsupported structure type: {structure_type!r}")

    save_path = f"{root_dir}/demo/cache/{structure_id}.{file_format}"
    if not os.path.exists(save_path):
        # Make sure the cache directory exists before downloading into it
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        download(structure_id, file_format, save_path)

    return save_path
|
34 |
+
|
35 |
+
|
36 |
+
def tmalign(structure_1: str, structure_type_1: str, structure_2: str, structure_type_2: str):
    """
    Run TMalign on two protein structures and return its textual report.

    Args:
        structure_1: Path of an uploaded file or an ID of the first structure

        structure_type_1: Database the first ID belongs to ("AlphaFoldDB" or "PDB")

        structure_2: Path of an uploaded file or an ID of the second structure

        structure_type_2: Database the second ID belongs to ("AlphaFoldDB" or "PDB")

    Returns:
        Standard output of ``bin/TMalign``, or an error message if the program could not
        be started
    """
    import subprocess

    structure_path_1 = get_structure_path(structure_1, structure_type_1)
    structure_path_2 = get_structure_path(structure_2, structure_type_2)

    # The paths derive from user input. Passing them as an argument list (no shell) prevents
    # them from being interpreted as shell syntax, and subprocess.run closes the pipe for us
    cmd = ["bin/TMalign", structure_path_1, structure_path_2]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, text=True, check=False)
    except OSError as e:
        return f"Failed to run bin/TMalign: {e}"

    return result.stdout
|
45 |
+
|
46 |
+
|
47 |
+
# Build the block for computing the TM-score between two protein structures
|
48 |
+
def build_TMalign():
    """
    Build the Gradio components for computing the TM-score between two protein structures.

    The components are created in the currently active Gradio context, so this function is
    meant to be called inside a ``gr.Blocks`` context.
    """
    gr.Markdown(f"# Calculate TM-score between two protein structures")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Inputs for the first structure: ID / path textbox, database type and upload button
            with gr.Row():
                structure_1 = gr.Textbox(label="Protein structure 1 (input Uniprot ID or PDB ID or upload a pdb file)")

                structure_type_1 = gr.Dropdown(structure_types, label="Structure type (if the structure is manually uploaded, ignore this field)",
                                               value="AlphaFoldDB", interactive=True, visible=True)

                # Provide an upload button to upload a pdb file
                upload_btn_1, _ = upload_pdb_button(visible=True, chain_visible=False)
                # The path of the uploaded file is written into the textbox
                upload_btn_1.upload(upload_structure, inputs=[upload_btn_1], outputs=[structure_1])

            # Inputs for the second structure, same layout as the first
            with gr.Row():
                structure_2 = gr.Textbox(label="Protein structure 2 (input Uniprot ID or PDB ID or upload a pdb file)")

                structure_type_2 = gr.Dropdown(structure_types, label="Structure type (if the structure is manually uploaded, ignore this field)",
                                               value="AlphaFoldDB", interactive=True, visible=True)

                # Provide an upload button to upload a pdb file
                upload_btn_2, _ = upload_pdb_button(visible=True, chain_visible=False)
                upload_btn_2.upload(upload_structure, inputs=[upload_btn_2], outputs=[structure_2])

            compute_btn = gr.Button(value="Compute TM-score")
            # Read-only text area that shows the text returned by ``tmalign``
            tmscore = gr.TextArea(label="TM-score", interactive=False)

            compute_btn.click(tmalign, inputs=[structure_1, structure_type_1, structure_2, structure_type_2],
                              outputs=[tmscore])
|
78 |
+
|
demo/run.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Entry point of the demo: assembles all Gradio modules into one app."""
import os
import sys

# Resolve the repository root (parent of the directory containing this file) from the
# absolute location of this file, so that it is correct for a relative ``__file__`` and
# independent of the platform's path separator
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if root_dir not in sys.path:
    sys.path.append(root_dir)

import gradio as gr

from modules.search import build_search_module
from modules.compute_score import build_score_computation
from modules.tmalign import build_TMalign


# Build demo
with gr.Blocks() as demo:
    build_search_module()
    build_score_computation()
    build_TMalign()


if __name__ == '__main__':
    # Run the demo
    demo.launch()
|
model/ProTrek/protein_encoder.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
from torch.nn.functional import normalize
|
5 |
+
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
|
6 |
+
|
7 |
+
|
8 |
+
class ProteinEncoder(torch.nn.Module):
    """
    Encoder that maps protein sequences to fixed-size representations using an ESM model
    followed by a linear projection.
    """

    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim    : Output dimension of the protein representation

            load_pretrained: Whether to load pretrained weights

            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        if load_pretrained:
            self.model = EsmForMaskedLM.from_pretrained(config_path)
        else:
            self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # Remove contact head
        self.model.esm.contact_head = None

        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute protein representation for the given proteins
        Args:
            proteins: A list of protein sequences
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized protein representations: [num_proteins, out_dim]
        """
        device = next(self.parameters()).device

        # Nothing to encode: return an empty [0, out_dim] tensor instead of failing in torch.cat
        if len(proteins) == 0:
            return torch.zeros(0, self.out.out_features, dtype=self.out.weight.dtype, device=device)

        protein_repr = []
        if verbose:
            iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
        else:
            iterator = range(0, len(proteins), batch_size)

        for i in iterator:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
            output, _ = self.forward(protein_inputs)

            protein_repr.append(output)

        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein sequence into protein representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits : [batch, seq_len, vocab_size], or None if ``get_mask_logits`` is False
        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state
        # The hidden state at the first sequence position is used as the sequence-level representation
        reprs = last_hidden_state[:, 0, :]
        reprs = self.out(reprs)

        # Get logits for masked tokens
        if get_mask_logits:
            mask_logits = self.model.lm_head(last_hidden_state)
        else:
            mask_logits = None

        return reprs, mask_logits
|
model/ProTrek/protrek_trimodal_model.py
ADDED
@@ -0,0 +1,874 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.distributed as dist
|
3 |
+
import torchmetrics
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import copy
|
9 |
+
import faiss
|
10 |
+
import time
|
11 |
+
import pandas as pd
|
12 |
+
import random
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
from .protein_encoder import ProteinEncoder
|
16 |
+
from .structure_encoder import StructureEncoder
|
17 |
+
from .text_encoder import TextEncoder
|
18 |
+
from ..abstract_model import AbstractModel
|
19 |
+
from ..model_interface import register_model
|
20 |
+
from utils.mpr import MultipleProcessRunnerSimplifier
|
21 |
+
from torch.nn.functional import normalize, cross_entropy
|
22 |
+
from utils.constants import residue_level, sequence_level
|
23 |
+
from sklearn.metrics import roc_auc_score
|
24 |
+
|
25 |
+
|
26 |
+
def multilabel_cross_entropy(logits, labels):
    """
    Compute cross entropy loss for multilabel classification. See "https://arxiv.org/pdf/2208.02955.pdf"

    For every sample the loss is log(1 + sum(exp(s_neg - s_pos))) over all pairs of a negative
    and a positive logit. The result is averaged over the samples.

    Args:
        logits: [num_samples, num_classes]
        labels: [num_samples, num_classes]. 1 marks a positive class and 0 a negative class

    Returns:
        Scalar tensor containing the mean loss over the samples
    """

    loss = 0
    for pred, label in zip(logits, labels):
        pos_logits = pred[label == 1]
        neg_logits = pred[label == 0]

        # Differences between every negative logit and every positive logit
        diff = (neg_logits.unsqueeze(-1) - pos_logits).flatten()

        # log(1 + sum(exp(diff))) == logsumexp([0, diff]). Unlike the explicit exp,
        # logsumexp does not overflow when a negative logit is far above a positive one
        loss += torch.logsumexp(torch.cat([diff.new_zeros(1), diff]), dim=0)

    return loss / len(logits)
|
43 |
+
|
44 |
+
# pred = (1 - 2 * labels) * logits
|
45 |
+
# pred_neg = pred - labels * 1e12
|
46 |
+
# pred_pos = pred - (1 - labels) * 1e12
|
47 |
+
#
|
48 |
+
# zeros = torch.zeros_like(logits[..., :1], dtype=logits.dtype)
|
49 |
+
# pred_neg = torch.cat([pred_neg, zeros], dim=-1)
|
50 |
+
# pred_pos = torch.cat([pred_pos, zeros], dim=-1)
|
51 |
+
#
|
52 |
+
# neg_loss = torch.logsumexp(pred_neg, dim=-1)
|
53 |
+
# pos_loss = torch.logsumexp(pred_pos, dim=-1)
|
54 |
+
#
|
55 |
+
# return (neg_loss + pos_loss).mean()
|
56 |
+
|
57 |
+
|
58 |
+
@register_model
|
59 |
+
class ProTrekTrimodalModel(AbstractModel):
|
60 |
+
def __init__(self,
             protein_config: str,
             text_config: str,
             structure_config: str = None,
             repr_dim: int = 1024,
             temperature: float = 0.07,
             load_protein_pretrained: bool = True,
             load_text_pretrained: bool = True,
             use_mlm_loss: bool = False,
             use_zlpr_loss: bool = False,
             use_saprot: bool = False,
             gradient_checkpointing: bool = False,
             **kwargs):
    """
    Store the model settings and hand the remaining keyword arguments to the parent class.

    Args:
        protein_config: Path to the config file for protein sequence encoder

        text_config: Path to the config file for text encoder

        structure_config: Path to the config file for structure encoder

        repr_dim: Output dimension of the protein and text representation

        temperature: Temperature for softmax

        load_protein_pretrained: Whether to load pretrained weights for protein encoder

        load_text_pretrained: Whether to load pretrained weights for text encoder

        use_mlm_loss: Whether to use masked language modeling loss

        use_zlpr_loss: Whether to use zlpr loss. See "https://arxiv.org/pdf/2208.02955.pdf"

        use_saprot: Whether to use SaProt as protein encoder

        gradient_checkpointing: Whether to use gradient checkpointing for protein encoder

        **kwargs: Forwarded unchanged to the parent constructor
    """
    # Paths of the encoder configs
    self.protein_config, self.text_config = protein_config, text_config
    self.structure_config = structure_config

    # Whether pretrained weights are loaded for the encoders
    self.load_protein_pretrained = load_protein_pretrained
    self.load_text_pretrained = load_text_pretrained

    # Representation and loss settings
    self.repr_dim, self.temperature = repr_dim, temperature
    self.use_mlm_loss, self.use_zlpr_loss = use_mlm_loss, use_zlpr_loss

    # Encoder options
    self.use_saprot = use_saprot
    self.gradient_checkpointing = gradient_checkpointing

    # All settings are stored before the parent constructor runs
    # NOTE(review): presumably the parent constructor reads them (e.g. via initialize_model) - confirm
    super().__init__(**kwargs)
|
109 |
+
|
110 |
+
def initialize_metrics(self, stage: str) -> dict:
    """
    Create the accuracy metrics for one stage.

    Args:
        stage: Name of the stage, used as prefix of every metric name

    Returns:
        A dictionary mapping metric names to ``torchmetrics.Accuracy`` objects. Retrieval
        accuracies between protein and text are always present; masked-token accuracies are
        added if ``use_mlm_loss`` is set, and structure-related metrics are added if a
        structure encoder is configured
    """
    has_structure = self.structure_config is not None
    metrics = {}

    # Retrieval accuracy between protein and text
    for name in ("protein_text_acc", "text_protein_acc"):
        metrics[f"{stage}_{name}"] = torchmetrics.Accuracy()

    # Accuracy of masked token prediction; label -1 is ignored
    if self.use_mlm_loss:
        masked_modalities = ["protein", "structure"] if has_structure else ["protein"]
        for modality in masked_modalities:
            metrics[f"{stage}_{modality}_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)

    # Retrieval accuracy between structure and the other two modalities
    if has_structure:
        for name in ("structure_protein_acc", "structure_text_acc",
                     "text_structure_acc", "protein_structure_acc"):
            metrics[f"{stage}_{name}"] = torchmetrics.Accuracy()

    return metrics
|
128 |
+
|
129 |
+
def initialize_model(self):
    """
    Build the encoders and the learnable temperature and collect them in ``self.model``.

    The structure encoder is only created if ``structure_config`` is not None.
    """
    # Initialize encoders
    self.protein_encoder = ProteinEncoder(self.protein_config,
                                          self.repr_dim,
                                          self.load_protein_pretrained,
                                          self.gradient_checkpointing)

    self.text_encoder = TextEncoder(self.text_config,
                                    self.repr_dim,
                                    self.load_text_pretrained,
                                    self.gradient_checkpointing)

    # Learnable temperature
    # (the float stored in __init__ is replaced by a Parameter initialised with that value)
    self.temperature = torch.nn.Parameter(torch.tensor(self.temperature))

    # self.model is used for saving and loading
    # NOTE(review): the list mixes a Parameter with whole modules; the order of the entries
    # presumably determines the names in saved checkpoints - do not reorder without checking
    self.model = torch.nn.ParameterList([self.temperature,
                                         self.protein_encoder,
                                         self.text_encoder])

    # If the structure encoder is specified
    if self.structure_config is not None:
        self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim)
        self.model.append(self.structure_encoder)
|
153 |
+
|
154 |
+
def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
    """
    Compute text representations by delegating to the text encoder.

    Args:
        texts: A list of texts
        batch_size: Batch size passed to the encoder
        verbose: Progress flag passed to the encoder
    """
    encoder = self.text_encoder
    return encoder.get_repr(texts, batch_size, verbose)
|
156 |
+
|
157 |
+
def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
    """
    Compute structure representations by delegating to the structure encoder.

    Args:
        proteins: A list of inputs for the structure encoder
        batch_size: Batch size passed to the encoder
        verbose: Progress flag passed to the encoder
    """
    encoder = self.structure_encoder
    return encoder.get_repr(proteins, batch_size, verbose)
|
159 |
+
|
160 |
+
def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
    """
    Compute protein representations by delegating to the protein encoder.

    Args:
        proteins: A list of protein sequences
        batch_size: Batch size passed to the encoder
        verbose: Progress flag passed to the encoder
    """
    encoder = self.protein_encoder
    return encoder.get_repr(proteins, batch_size, verbose)
|
162 |
+
|
163 |
+
def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None):
    """
    Encode one batch with every configured encoder.

    Args:
        protein_inputs: A dictionary for protein encoder
        text_inputs   : A dictionary for text encoder
        structure_inputs: A dictionary for structure encoder. Only used if a structure
                          encoder is configured

    Returns:
        [text_repr, protein_repr, protein_mask_logits], followed by
        [structure_repr, structure_mask_logits] if a structure encoder is configured
    """
    with_mask_logits = self.use_mlm_loss

    protein_repr, protein_mask_logits = self.protein_encoder(protein_inputs, with_mask_logits)
    text_repr = self.text_encoder(text_inputs)

    # Without a structure encoder only text and protein outputs are returned
    if self.structure_config is None:
        return [text_repr, protein_repr, protein_mask_logits]

    structure_repr, structure_mask_logits = self.structure_encoder(structure_inputs, with_mask_logits)
    return [text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits]
|
180 |
+
|
181 |
+
def loss_func(self, stage: str, outputs, labels):
    """
    Compute the contrastive loss between all modality pairs and, optionally, the masked
    language modeling loss.

    Args:
        stage: Name of the current stage. Selects the metrics to update and, for "train",
               triggers logging and resetting of the train metrics

        outputs: Outputs of ``forward``: [text_repr, protein_repr, protein_mask_logits],
                 followed by [structure_repr, structure_mask_logits] if a structure encoder
                 is configured

        labels: A dictionary that is only read if ``use_mlm_loss`` is set:
            - seq_labels: Labels for the masked protein tokens (-1 is ignored)
            - struc_labels: Labels for the masked structure tokens (-1 is ignored)

    Returns:
        Mean of all individual losses
    """
    has_structure = self.structure_config is not None
    if has_structure:
        text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs
    else:
        text_repr, protein_repr, protein_mask_logits = outputs

    device = text_repr.device

    # Normalized representations of the current batch and mask logits, keyed by modality
    local_repr = {
        "text": normalize(text_repr, dim=-1),
        "protein": normalize(protein_repr, dim=-1),
    }
    mask_logits = {"protein": protein_mask_logits}
    if has_structure:
        local_repr["structure"] = normalize(structure_repr, dim=-1)
        mask_logits["structure"] = structure_mask_logits

    # Gather representations from all GPUs. The gathered copies are detached, so gradients
    # only flow through the representations of the current batch
    gather_order = ["protein", "text"] + (["structure"] if has_structure else [])
    all_repr = {}
    for name in gather_order:
        repr_ = local_repr[name]
        all_repr[name] = self.all_gather(repr_).view(-1, repr_.shape[-1]).detach()

    # Batch size
    rank = dist.get_rank()
    bs = local_repr["text"].shape[0]

    # The i-th sample of this process matches the (rank * bs + i)-th gathered sample
    bs_labels = torch.arange(rank * bs, rank * bs + bs, device=device)

    if has_structure:
        pairs = {
            "protein": ["structure", "text"],
            "structure": ["protein", "text"],
            "text": ["protein", "structure"]
        }
    else:
        pairs = {
            "protein": ["text"],
            "text": ["protein"]
        }

    # NOTE: ``use_zlpr_loss`` is not consulted here; plain cross entropy is used for every pair
    loss_list = []
    for k, values in pairs.items():
        for v in values:
            # Only calculate the similarity for the current batch
            sim = torch.matmul(local_repr[k], all_repr[v].T).div(self.temperature)

            loss = cross_entropy(sim, bs_labels)
            self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
            loss_list.append(loss)

    # Masked language modeling loss
    if self.use_mlm_loss:
        k_label = [("protein", labels["seq_labels"])]
        if has_structure:
            k_label.append(("structure", labels["struc_labels"]))

        for k, label in k_label:
            logits = mask_logits[k]
            # merge the first and second dimension of logits
            logits = logits.view(-1, logits.shape[-1])
            label = label.flatten().to(device)
            mlm_loss = cross_entropy(logits, label, ignore_index=-1)
            loss_list.append(mlm_loss)
            self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label)

    loss = sum(loss_list) / len(loss_list)

    if stage == "train":
        log_dict = self.get_log_dict("train")
        log_dict["train_loss"] = loss
        self.log_info(log_dict)

        # Reset train metrics
        self.reset_metrics("train")

    return loss
|
325 |
+
|
326 |
+
def padded_gather(self, tensor: torch.Tensor):
    """
    Gather tensors from all GPUs, allowing different shapes at the batch dimension.

    Every process pads its tensor with zero rows up to the largest batch size, the padded
    tensors are gathered, and the padding rows of each process are removed again.

    Args:
        tensor: [num_rows, dim] tensor. ``num_rows`` may differ between processes

    Returns:
        [total_rows, dim] tensor containing the rows of all processes, in the order in
        which ``all_gather`` returns the processes
    """

    # Get the size of the tensor on every process
    size = tensor.shape[0]
    all_sizes = self.all_gather(torch.tensor(size, device=tensor.device)).reshape(-1).tolist()
    max_size = max(all_sizes)

    # Pad the tensor
    if size != max_size:
        tmp = torch.zeros(max_size, tensor.shape[-1], dtype=tensor.dtype, device=tensor.device)
        tmp[:size] = tensor
        tensor = tmp

    # assumes all_gather stacks the per-process tensors along a new leading dimension
    gathered = self.all_gather(tensor).reshape(len(all_sizes), max_size, tensor.shape[-1])

    # Drop the padding of each process separately. Truncating the flattened result would keep
    # padding rows and lose real rows whenever a process other than the last one is short
    return torch.cat([gathered[i, :n] for i, n in enumerate(all_sizes)], dim=0)
|
346 |
+
|
347 |
+
def _get_protein_indices(self):
    """
    Encode all proteins in ``self.uniprot2label`` and build an inner-product faiss index.

    The proteins are split into contiguous shards, one per process. Each process encodes
    its shard and the results are gathered on all processes.

    Returns:
        A ``faiss.IndexFlatIP`` containing the protein representations
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    if not self.use_saprot:
        proteins = [entry["seq"] for entry in self.uniprot2label.values()]

    else:
        # Interleave amino acid and foldseek tokens position by position
        proteins = []
        for entry in self.uniprot2label.values():
            aa_seq, foldseek_seq = entry["seq"], entry["foldseek"]
            assert len(aa_seq) == len(foldseek_seq)
            proteins.append("".join(aa + fs for aa, fs in zip(aa_seq, foldseek_seq)))

    # Shard handled by the current process
    span = math.ceil(len(proteins) / world_size)
    begin = rank * span
    shard = proteins[begin: begin + span]

    # Display the progress bar on the rank 0 process
    show_progress = self.trainer.local_rank == 0

    # Get protein representations
    shard_repr = self.protein_encoder.get_repr(shard, batch_size=1, verbose=show_progress)
    all_repr = self.padded_gather(shard_repr)

    # Construct faiss index
    index = faiss.IndexFlatIP(all_repr.shape[-1])
    index.add(all_repr.cpu().numpy())
    return index
|
377 |
+
|
378 |
+
def _get_structure_indices(self):
    """
    Encode the foldseek sequences of all proteins in ``self.uniprot2label`` and build an
    inner-product faiss index.

    The sequences are split into contiguous shards, one per process. Each process encodes
    its shard and the results are gathered on all processes.

    Returns:
        A ``faiss.IndexFlatIP`` containing the structure representations
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    foldseek_seqs = [entry["foldseek"] for entry in self.uniprot2label.values()]

    # Shard handled by the current process
    span = math.ceil(len(foldseek_seqs) / world_size)
    begin = rank * span
    shard = foldseek_seqs[begin: begin + span]

    # Display the progress bar on the rank 0 process
    show_progress = self.trainer.local_rank == 0

    # Get structure representations
    shard_repr = self.structure_encoder.get_repr(shard, batch_size=1, verbose=show_progress)
    all_repr = self.padded_gather(shard_repr)

    # Construct faiss index
    index = faiss.IndexFlatIP(all_repr.shape[-1])
    index.add(all_repr.cpu().numpy())
    return index
|
397 |
+
|
398 |
+
def _get_text_indices(self):
    """
    Encode the texts in ``self.label2text`` and build one inner-product faiss index per
    subsection, plus a "Total" index that aggregates entries of the other subsections.

    The texts of every subsection are split into contiguous shards, one per process. Each
    process encodes its shard and the results are gathered on all processes.

    Returns:
        A dictionary mapping subsection names (including "Total") to ``faiss.IndexFlatIP``
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Display the progress bar on the rank 0 process
    show_progress = self.trainer.local_rank == 0
    subsections = self.label2text.keys()
    if show_progress:
        subsections = tqdm(subsections, desc="Get text representations")

    def embed(text_list):
        # Mean of the text representations, normalized to unit length
        reprs = self.text_encoder.get_repr(text_list)
        return torch.nn.functional.normalize(reprs.mean(dim=0, keepdim=True), dim=-1)

    text_embeddings = {}
    for subsection in subsections:
        # "Total" is assembled from the other subsections below
        if subsection == "Total":
            continue

        # Only use the first text for efficiency
        candidates = [text_list[:1] for text_list in self.label2text[subsection].values()]

        # Shard handled by the current process
        span = math.ceil(len(candidates) / world_size)
        begin = rank * span
        shard_embeddings = [embed(text_list) for text_list in candidates[begin: begin + span]]

        if len(shard_embeddings) == 0:
            # This process has nothing to encode but still has to take part in the gather
            local = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device)
        else:
            local = torch.cat(shard_embeddings, dim=0)

        text_embeddings[subsection] = self.padded_gather(local)

    # Aggregate text embeddings for global retrieval
    # Every entry of "Total" has the form "<subsection>|<index within the subsection>"
    total_embeddings = []
    for ref in self.label2text["Total"].values():
        subsection, position = ref.split("|")
        total_embeddings.append(text_embeddings[subsection][int(position)])

    text_embeddings["Total"] = torch.stack(total_embeddings)

    # Construct faiss index
    text_indices = {}
    for subsection, embeddings in text_embeddings.items():
        index = faiss.IndexFlatIP(embeddings.shape[-1])
        text_indices[subsection] = index
        index.add(embeddings.cpu().numpy())

    return text_indices
|
452 |
+
|
453 |
+
def _protein2text(self, modality: str, protein_indices, text_indices: dict):
    """
    Evaluate protein-to-text retrieval on the Swiss-Prot entries.

    For every (subsection, protein) pair the protein embedding is used to rank all
    texts of that subsection, and AP, reciprocal rank and ROC-AUC are computed from
    the positions of the ground-truth texts. The work is sharded across distributed
    ranks and the per-query scores are gathered back before averaging.

    Args:
        modality: Prefix used in the returned metric names (e.g. "Sequence")

        protein_indices: Index holding the protein embeddings. It must provide
            ``reconstruct`` (presumably a faiss index - confirm against caller)

        text_indices: Mapping from subsection name to the index of its text embeddings

    Returns:
        A dict with ``{modality}2Text_{group}_{mrr|map|auc}`` entries, where group is
        "Total", "residue-level", "sequence-level" and "all"
    """
    # Local import: only needed for locating a writable scratch directory.
    import tempfile

    def do(process_id, idx, row, writer):
        subsection, uniprot_id, prob_idx, label = row

        # Retrieve ranking results
        p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
        text_inds = text_indices[subsection]
        sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal)
        sim_scores, rank_inds = sim_scores[0], rank_inds[0]

        # 1-based positions of the ground-truth texts in the ranking
        label = set(label)
        ranks = np.array([pos + 1 for pos, rk in enumerate(rank_inds) if rk in label],
                         dtype=np.int64)

        if ranks.size == 0:
            # No ground-truth text was retrieved: score the query as a complete miss
            # instead of crashing on ``ranks[0]``.
            ap, mrr, auc = 0.0, 0.0, 0
        else:
            # Average Precision (AP)
            ap = np.mean(np.arange(1, ranks.size + 1) / ranks)

            # Reciprocal rank of the best-ranked ground-truth text
            mrr = 1 / ranks[0]

            # AUC. ``sim_scores`` is in ranking order, so ``ranks - 1`` indexes it directly
            true_labels = np.zeros_like(sim_scores)
            true_labels[ranks - 1] = 1
            n_positive = true_labels.sum()
            if n_positive == 0 or n_positive == true_labels.shape[0]:
                # AUC is undefined when only one class is present
                auc = 0
            else:
                auc = roc_auc_score(true_labels, sim_scores)

        writer.write(json.dumps([ap, mrr, auc]) + "\n")

    inputs = []
    swissprot_subsections = set()
    for subsection in text_indices.keys():
        for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()):
            if uniprot_id in self.swissprot_ids and subsection in labels:
                swissprot_subsections.add(subsection)
                inputs.append((subsection, uniprot_id, i, labels[subsection]))

    # Shuffle with a fixed seed so every rank derives the same order
    random.seed(20000812)
    random.shuffle(inputs)

    # Split inputs into chunks for parallel processing
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    span = math.ceil(len(inputs) / world_size)
    sub_inputs = inputs[rank * span: (rank + 1) * span]

    # Display the progress bar on the rank 0 process
    verbose = self.trainer.local_rank == 0
    if verbose:
        print("Evaluating on each subsection...")

    # Time stamp and rank keep the scratch file name unique; the system temp
    # directory replaces a machine-specific absolute path.
    tmp_path = os.path.join(tempfile.gettempdir(), f"{time.time()}_{rank}.tsv")
    mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                          return_results=True)
    try:
        outputs = mpr.run()
    finally:
        # Always clean up the scratch file, even when the runner fails
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Aggregate results
    tensor_outputs = []
    for output in outputs:
        ap, mrr, auc = json.loads(output)
        tensor_outputs.append([float(ap), float(mrr), float(auc)])

    # ``reshape`` keeps the (n, 3) layout when this rank received no inputs
    tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device).reshape(-1, 3)
    tensor_outputs = self.padded_gather(tensor_outputs)

    # Record results
    avg_results = {subsection: {"map": [], "mrr": [], "auc": []}
                   for subsection in swissprot_subsections}

    # NOTE(review): assumes the gathered rows come back in the same order as
    # ``inputs`` (rank-major, in-rank order preserved) - confirm against padded_gather
    for row, scores in zip(inputs, tensor_outputs):
        ap, mrr, auc = scores
        bucket = avg_results[row[0]]
        bucket["map"].append(ap.cpu().item())
        bucket["mrr"].append(mrr.cpu().item())
        bucket["auc"].append(auc.cpu().item())

    results = {
        f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
        f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]),
        f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]),
    }

    # Macro-average the per-subsection metrics for each level
    for level, labels in [("residue-level", residue_level),
                          ("sequence-level", sequence_level),
                          ("all", residue_level | sequence_level)]:

        mrrs = []
        maps = []
        aucs = []
        for subsection in labels:
            if subsection in avg_results:
                mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                maps.append(np.mean(avg_results[subsection]["map"]))
                aucs.append(np.mean(avg_results[subsection]["auc"]))

        results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs)
        results[f"{modality}2Text_{level}_map"] = np.mean(maps)
        results[f"{modality}2Text_{level}_auc"] = np.mean(aucs)

    return results
|
569 |
+
|
570 |
+
def _text2protein(self, modality: str, protein_indices, text_indices: dict):
    """
    Evaluate text-to-protein retrieval on the texts attached to Swiss-Prot entries.

    For every (subsection, text) pair the text embedding is used to rank all proteins,
    and AP, reciprocal rank and ROC-AUC are computed from the positions of the proteins
    annotated with that text. The work is sharded across distributed ranks and the
    per-query scores are gathered back before averaging.

    Args:
        modality: Suffix used in the returned metric names (e.g. "Sequence")

        protein_indices: Index holding the protein embeddings. It must provide
            ``search`` and ``ntotal`` (presumably a faiss index - confirm against caller)

        text_indices: Mapping from subsection name to the index of its text embeddings

    Returns:
        A dict with ``Text2{modality}_{group}_{mrr|map|auc}`` entries, where group is
        "Total", "residue-level", "sequence-level" and "all"
    """
    # Local import: only needed for locating a writable scratch directory.
    import tempfile

    def do(process_id, idx, row, writer):
        subsection, text_id, label = row

        # Retrieve ranking results
        t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1)
        sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal)
        sim_scores, rank_inds = sim_scores[0], rank_inds[0]

        # 1-based positions of the ground-truth proteins in the ranking
        label = set(label)
        ranks = np.array([pos + 1 for pos, rk in enumerate(rank_inds) if rk in label],
                         dtype=np.int64)

        if ranks.size == 0:
            # No ground-truth protein was retrieved: score the query as a complete
            # miss instead of crashing on ``ranks[0]``.
            ap, mrr, auc = 0.0, 0.0, 0
        else:
            # Average Precision (AP)
            ap = np.mean(np.arange(1, ranks.size + 1) / ranks)

            # Reciprocal rank of the best-ranked ground-truth protein
            mrr = 1 / ranks[0]

            # AUC. ``sim_scores`` is in ranking order, so ``ranks - 1`` indexes it directly
            true_labels = np.zeros_like(sim_scores)
            true_labels[ranks - 1] = 1
            n_positive = true_labels.sum()
            if n_positive == 0 or n_positive == true_labels.shape[0]:
                # AUC is undefined when only one class is present
                auc = 0
            else:
                auc = roc_auc_score(true_labels, sim_scores)

        writer.write(json.dumps([ap, mrr, auc]) + "\n")

    # subsection -> text id -> row numbers of the proteins annotated with that text
    text2label = {}
    swissprot_subsections = set()
    for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()):
        # Only evaluate the texts in Swiss-Prot
        if uniprot_id not in self.swissprot_ids:
            continue

        for subsection, text_ids in subsections.items():
            # These two keys hold the sequences themselves, not text labels
            if subsection == "seq" or subsection == "foldseek":
                continue

            swissprot_subsections.add(subsection)
            per_text = text2label.setdefault(subsection, {})
            for text_id in text_ids:
                per_text.setdefault(text_id, []).append(i)

    # Iterate the subsections in sorted order: plain set iteration over strings
    # depends on the per-process hash seed, so different ranks could otherwise build
    # differently ordered inputs and the shards/gathered results would not line up.
    inputs = []
    for subsection in sorted(swissprot_subsections):
        for text_id, label in text2label[subsection].items():
            inputs.append((subsection, text_id, label))

    # Shuffle with a fixed seed so every rank derives the same order
    random.seed(20000812)
    random.shuffle(inputs)

    # Split inputs into chunks for parallel processing
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    span = math.ceil(len(inputs) / world_size)
    sub_inputs = inputs[rank * span: (rank + 1) * span]

    # Display the progress bar on the rank 0 process
    verbose = self.trainer.local_rank == 0
    if verbose:
        print("Evaluating on each text...")

    # Time stamp and rank keep the scratch file name unique; the system temp
    # directory replaces a machine-specific absolute path.
    tmp_path = os.path.join(tempfile.gettempdir(), f"{time.time()}_{rank}.tsv")
    mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                          return_results=True)
    try:
        outputs = mpr.run()
    finally:
        # Always clean up the scratch file, even when the runner fails
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Aggregate results
    tensor_outputs = []
    for output in outputs:
        ap, mrr, auc = json.loads(output)
        tensor_outputs.append([float(ap), float(mrr), float(auc)])

    # ``reshape`` keeps the (n, 3) layout when this rank received no inputs
    tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device).reshape(-1, 3)
    tensor_outputs = self.padded_gather(tensor_outputs)

    # Record results
    avg_results = {subsection: {"map": [], "mrr": [], "auc": []}
                   for subsection in swissprot_subsections}

    # NOTE(review): assumes the gathered rows come back in the same order as
    # ``inputs`` (rank-major, in-rank order preserved) - confirm against padded_gather
    for row, scores in zip(inputs, tensor_outputs):
        ap, mrr, auc = scores
        bucket = avg_results[row[0]]
        bucket["map"].append(ap.cpu().item())
        bucket["mrr"].append(mrr.cpu().item())
        bucket["auc"].append(auc.cpu().item())

    results = {
        f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
        f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]),
        f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]),
    }

    # Macro-average the per-subsection metrics for each level
    for level, labels in [("residue-level", residue_level),
                          ("sequence-level", sequence_level),
                          ("all", residue_level | sequence_level)]:

        mrrs = []
        maps = []
        aucs = []
        for subsection in labels:
            if subsection in avg_results:
                mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                maps.append(np.mean(avg_results[subsection]["map"]))
                aucs.append(np.mean(avg_results[subsection]["auc"]))

        results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs)
        results[f"Text2{modality}_{level}_map"] = np.mean(maps)
        results[f"Text2{modality}_{level}_auc"] = np.mean(aucs)

    return results
|
700 |
+
|
701 |
+
def retrieval_eval(self) -> dict:
    """
    Run the retrieval evaluation in both directions and merge the metrics.

    Returns:
        The union of the protein-to-text and text-to-protein metric dicts
    """
    # Build both search indices up front (proteins first, then texts)
    protein_index = self._get_protein_indices()
    text_index = self._get_text_indices()

    # NOTE: only the sequence modality is evaluated; structure retrieval is disabled
    merged = {}
    for evaluate in (self._protein2text, self._text2protein):
        merged.update(evaluate("Sequence", protein_index, text_index))

    return merged
|
723 |
+
|
724 |
+
def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
|
725 |
+
while True:
|
726 |
+
masked_tokens = copy.copy(tokens)
|
727 |
+
labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)
|
728 |
+
vocab = [k for k in tokenizer.get_vocab().keys()]
|
729 |
+
|
730 |
+
for i in range(len(tokens)):
|
731 |
+
token = tokens[i]
|
732 |
+
|
733 |
+
prob = random.random()
|
734 |
+
if prob < mask_ratio:
|
735 |
+
prob /= mask_ratio
|
736 |
+
labels[i + 1] = tokenizer.convert_tokens_to_ids(token)
|
737 |
+
|
738 |
+
if prob < 0.8:
|
739 |
+
# 80% random change to mask token
|
740 |
+
if self.use_saprot:
|
741 |
+
token = "#" + token[-1]
|
742 |
+
else:
|
743 |
+
token = tokenizer.mask_token
|
744 |
+
elif prob < 0.9:
|
745 |
+
# 10% chance to change to random token
|
746 |
+
token = random.choice(vocab)
|
747 |
+
else:
|
748 |
+
# 10% chance to keep current token
|
749 |
+
pass
|
750 |
+
|
751 |
+
masked_tokens[i] = token
|
752 |
+
|
753 |
+
# Check if there is at least one masked token
|
754 |
+
if (labels != -1).any():
|
755 |
+
return masked_tokens, labels
|
756 |
+
|
757 |
+
def mlm_eval(self) -> float:
    """
    Measure masked-token prediction accuracy of the protein encoder.

    The proteins are sharded across the distributed ranks, each sequence is randomly
    masked with a ratio of 0.15, and the counts of masked and correctly predicted
    tokens are gathered from all ranks.

    Returns:
        Fraction of masked tokens whose arg-max prediction equals the original token
    """
    n_ranks = dist.get_world_size()
    this_rank = dist.get_rank()

    entries = self.uniprot2label.values()
    if self.use_saprot:
        # Interleave each residue with its foldseek character
        proteins = []
        for entry in entries:
            residues = entry["seq"]
            states = entry["foldseek"]
            assert len(residues) == len(states)
            proteins.append("".join([aa + st for aa, st in zip(residues, states)]))
    else:
        proteins = [entry["seq"] for entry in entries]

    chunk = math.ceil(len(proteins) / n_ranks)
    shard = proteins[this_rank * chunk: (this_rank + 1) * chunk]

    # Only the local rank 0 process shows a progress bar
    iterator = tqdm(shard, desc="Computing mlm...") if self.trainer.local_rank == 0 else shard

    n_masked = torch.tensor([0], dtype=torch.long, device=self.device)
    n_correct = torch.tensor([0], dtype=torch.long, device=self.device)
    for protein in iterator:
        pieces = self.protein_encoder.tokenizer.tokenize(protein)
        masked_pieces, labels = self._apply_bert_mask(pieces, self.protein_encoder.tokenizer, 0.15)

        encoded = self.protein_encoder.tokenizer(" ".join(masked_pieces), return_tensors="pt")
        encoded = {key: value.to(self.device) for key, value in encoded.items()}
        _, logits = self.protein_encoder(encoded, get_mask_logits=True)

        labels = labels.to(self.device)
        # Score only the positions that were selected for masking
        is_masked = labels != -1
        predictions = logits.squeeze(0).argmax(dim=-1)[is_masked]
        targets = labels[is_masked]

        n_masked += len(predictions)
        n_correct += (predictions == targets).sum()

    # Gather all results (same collective order on every rank: total, then correct)
    n_masked = self.padded_gather(n_masked).sum()
    n_correct = self.padded_gather(n_correct).sum()

    return (n_correct / n_masked).cpu().item()
|
809 |
+
|
810 |
+
def _load_eval_data(self, stage):
|
811 |
+
# Load the data
|
812 |
+
lmdb_dir = eval(f"self.trainer.datamodule.{stage}_lmdb")
|
813 |
+
uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
|
814 |
+
label2text_path = os.path.join(lmdb_dir, "label2text.json")
|
815 |
+
swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")
|
816 |
+
|
817 |
+
self.uniprot2label = json.load(open(uniprot2label_path, "r"))
|
818 |
+
self.label2text = json.load(open(label2text_path, "r"))
|
819 |
+
self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist())
|
820 |
+
self.k = 3
|
821 |
+
|
822 |
+
def on_test_start(self):
    """
    Run the whole test-time evaluation before the test loop starts.

    Loads the "test" annotations, computes the retrieval metrics (plus the masked-token
    accuracy when the MLM loss is enabled), then logs and prints them with a
    ``test_`` prefix.
    """
    self._load_eval_data("test")

    metrics = {}
    for name, value in self.retrieval_eval().items():
        metrics["test_" + name] = value

    if self.use_mlm_loss:
        metrics["test_mask_acc"] = self.mlm_eval()

    self.log_info(metrics)
    print(metrics)
|
831 |
+
|
832 |
+
def on_validation_start(self):
    """
    Run the whole validation-time evaluation before the validation loop starts.

    Frees cached CUDA memory, loads the "valid" annotations, computes the retrieval
    metrics (plus the masked-token accuracy when the MLM loss is enabled), logs them
    with a ``valid_`` prefix and finally triggers the checkpoint-saving check.
    """
    # Clear the cache
    torch.cuda.empty_cache()

    self._load_eval_data("valid")

    metrics = {}
    for name, value in self.retrieval_eval().items():
        metrics["valid_" + name] = value

    if self.use_mlm_loss:
        metrics["valid_mask_acc"] = self.mlm_eval()

    self.log_info(metrics)

    # The current step is passed as the monitored value with mode "max"
    self.check_save_condition(self.step, mode="max")
|
845 |
+
|
846 |
+
def test_step(self, batch, batch_idx):
|
847 |
+
return
|
848 |
+
|
849 |
+
def validation_step(self, batch, batch_idx):
    """Do nothing per batch; the validation metrics are computed in ``on_validation_start``."""
    return None
|
851 |
+
|
852 |
+
def on_train_epoch_end(self):
    """Run the parent hook, then draw a fresh training subset if a fixed size is configured."""
    super().on_train_epoch_end()

    train_dataset = self.trainer.datamodule.train_dataset
    if train_dataset.fixed_dataset_num is None:
        return

    # Re-sample the subset of the training data
    train_dataset.sample_subset()
|
857 |
+
|
858 |
+
# def test_epoch_end(self, outputs):
|
859 |
+
# log_dict = self.get_log_dict("test")
|
860 |
+
# log_dict["test_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
|
861 |
+
#
|
862 |
+
# print(log_dict)
|
863 |
+
# self.log_info(log_dict)
|
864 |
+
#
|
865 |
+
# self.reset_metrics("test")
|
866 |
+
#
|
867 |
+
# def validation_epoch_end(self, outputs):
|
868 |
+
# log_dict = self.get_log_dict("valid")
|
869 |
+
# log_dict["valid_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
|
870 |
+
#
|
871 |
+
# self.log_info(log_dict)
|
872 |
+
# self.reset_metrics("valid")
|
873 |
+
# self.check_save_condition(log_dict["valid_loss"], mode="min")
|
874 |
+
|
model/ProTrek/structure_encoder.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
|
5 |
+
from torch.nn.functional import normalize
|
6 |
+
|
7 |
+
|
8 |
+
class StructureEncoder(torch.nn.Module):
    """ESM-based encoder that maps protein structural sequences to fixed-size vectors."""

    def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim: Output dimension of the structure representation

            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        # The model is built from the config only; no pretrained weights are loaded here
        self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # Remove contact head
        self.model.esm.contact_head = None

        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute protein structure representation for the given proteins
        Args:
            proteins: A list of protein structural sequences
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized representations of shape [len(proteins), out_dim]
        """
        # ``torch.cat`` cannot handle an empty list, so answer empty input directly
        if len(proteins) == 0:
            return self.out.weight.new_zeros((0, self.out.out_features))

        device = next(self.parameters()).device

        batch_starts = range(0, len(proteins), batch_size)
        if verbose:
            batch_starts = tqdm(batch_starts, desc="Computing protein embeddings")

        protein_repr = []
        for start in batch_starts:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[start:start + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
            # Call the module (not ``forward``) so that registered hooks are honoured
            output, _ = self(protein_inputs)

            protein_repr.append(output)

        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein structure into protein representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits : [batch, seq_len, vocab_size], or None when not requested
        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state
        # The hidden state of the first token is projected to the output space
        reprs = self.out(last_hidden_state[:, 0, :])

        # Get logits for masked tokens
        mask_logits = self.model.lm_head(last_hidden_state) if get_mask_logits else None

        return reprs, mask_logits
|
model/ProTrek/text_encoder.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
from torch.nn.functional import normalize
|
5 |
+
from transformers import BertConfig, BertModel, BertTokenizer
|
6 |
+
|
7 |
+
|
8 |
+
class TextEncoder(torch.nn.Module):
    """BERT-based encoder that maps texts to fixed-size vectors."""

    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim: Output dimension of the text representation

            load_pretrained: Whether to load pretrained weights

            gradient_checkpointing: Whether to enable gradient checkpointing
        """
        super().__init__()
        config = BertConfig.from_pretrained(config_path)
        if load_pretrained:
            self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
        else:
            self.model = BertModel(config, add_pooling_layer=False)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.encoder.gradient_checkpointing = gradient_checkpointing

        self.tokenizer = BertTokenizer.from_pretrained(config_path)

    def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute text representation for the given texts
        Args:
            texts: A list of strings
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized representations of shape [len(texts), out_dim]
        """
        # ``torch.cat`` cannot handle an empty list, so answer empty input directly
        if len(texts) == 0:
            return self.out.weight.new_zeros((0, self.out.out_features))

        device = next(self.parameters()).device

        batch_starts = range(0, len(texts), batch_size)
        if verbose:
            batch_starts = tqdm(batch_starts, desc="Computing text embeddings")

        text_repr = []
        for start in batch_starts:
            # Inputs longer than 512 tokens are truncated
            text_inputs = self.tokenizer.batch_encode_plus(texts[start: start + batch_size],
                                                           return_tensors="pt",
                                                           truncation=True,
                                                           max_length=512,
                                                           padding=True)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            output = self(text_inputs)

            text_repr.append(output)

        text_repr = torch.cat(text_repr, dim=0)
        return normalize(text_repr, dim=-1)

    def forward(self, inputs: dict):
        """
        Encode text into text representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
                - token_type_ids: [batch, seq_len]

        Returns:
            text_repr: [batch, text_repr_dim]
        """
        # The hidden state of the first token is projected to the output space
        first_token_state = self.model(**inputs).last_hidden_state[:, 0, :]
        return self.out(first_token_state)
|
model/abstract_model.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import abc
|
3 |
+
import os
|
4 |
+
import copy
|
5 |
+
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from utils.lr_scheduler import *
|
8 |
+
from torch import distributed as dist
|
9 |
+
|
10 |
+
|
11 |
+
class AbstractModel(pl.LightningModule):
|
12 |
+
def __init__(self,
             lr_scheduler_kwargs: dict = None,
             optimizer_kwargs: dict = None,
             save_path: str = None,
             from_checkpoint: str = None,
             load_prev_scheduler: bool = False,
             save_weights_only: bool = True,):
    """

    Args:
        lr_scheduler_kwargs: Kwargs for lr_scheduler
        optimizer_kwargs: Kwargs for optimizer
        save_path: Save trained model
        from_checkpoint: Load model from checkpoint
        load_prev_scheduler: Whether load previous scheduler from checkpoint
        save_weights_only: Whether save only weights or also optimizer and lr_scheduler

    """
    super().__init__()
    self.initialize_model()

    # Build the metrics of every stage and expose each one as an attribute
    self.metrics = {}
    for stage in ("train", "valid", "test"):
        stage_metrics = self.initialize_metrics(stage)
        for metric_name, metric in stage_metrics.items():
            setattr(self, metric_name, metric)

        self.metrics[stage] = stage_metrics

    # Fall back to a constant scheduler with learning rate 0
    if lr_scheduler_kwargs is None:
        lr_scheduler_kwargs = {
            "class": "ConstantLRScheduler",
            "init_lr": 0,
        }
        print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
    self.lr_scheduler_kwargs = lr_scheduler_kwargs

    # Fall back to AdamW
    if optimizer_kwargs is None:
        optimizer_kwargs = {
            "class": "AdamW",
            "betas": (0.9, 0.98),
            "weight_decay": 0.01,
        }
        print("No optimizer_kwargs provided. The default optimizer is AdamW.")
    self.optimizer_kwargs = optimizer_kwargs
    self.init_optimizers()

    self.save_path = save_path
    self.save_weights_only = save_weights_only

    # temp_step is used for accumulating gradients
    self.temp_step = 0
    self.step = 0
    self.epoch = 0

    self.load_prev_scheduler = load_prev_scheduler
    self.from_checkpoint = from_checkpoint
    if from_checkpoint:
        self.load_checkpoint(from_checkpoint)
|
78 |
+
|
79 |
+
@abc.abstractmethod
def initialize_model(self) -> None:
    """
    Build the model. Subclasses must override this method.

    The whole model has to be stored as "self.model" so that saving and loading work.
    """
    raise NotImplementedError()
|
86 |
+
|
87 |
+
@abc.abstractmethod
def forward(self, *args, **kwargs):
    """
    Run the forward pass. Subclasses must override this method.
    """
    raise NotImplementedError()
|
93 |
+
|
94 |
+
@abc.abstractmethod
def initialize_metrics(self, stage: str) -> dict:
    """
    Create the metrics of one stage. Subclasses must override this method.

    Args:
        stage: "train", "valid" or "test"

    Returns:
        A dictionary mapping metric names to metric objects for the given stage
    """
    raise NotImplementedError()
|
105 |
+
|
106 |
+
@abc.abstractmethod
def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
    """
    Compute the loss. Subclasses must override this method.

    Args:
        stage: "train", "valid" or "test"
        outputs: model outputs for calculating loss
        labels: labels for calculating loss

    Returns:
        loss
    """
    raise NotImplementedError()
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def load_weights(model, weights):
|
123 |
+
model_dict = model.state_dict()
|
124 |
+
|
125 |
+
unused_params = []
|
126 |
+
missed_params = list(model_dict.keys())
|
127 |
+
|
128 |
+
for k, v in weights.items():
|
129 |
+
if k in model_dict.keys():
|
130 |
+
model_dict[k] = v
|
131 |
+
missed_params.remove(k)
|
132 |
+
|
133 |
+
else:
|
134 |
+
unused_params.append(k)
|
135 |
+
|
136 |
+
if len(missed_params) > 0:
|
137 |
+
print(f"\033[31mSome weights of {type(model).__name__} were not "
|
138 |
+
f"initialized from the model checkpoint: {missed_params}\033[0m")
|
139 |
+
|
140 |
+
if len(unused_params) > 0:
|
141 |
+
print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")
|
142 |
+
|
143 |
+
model.load_state_dict(model_dict)
|
144 |
+
|
145 |
+
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer,
    optimizer_closure=None,
) -> None:
    """Run the parent optimizer step and maintain the module's own step counters."""
    super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

    # ``step`` advances once every ``accumulate_grad_batches`` calls
    self.temp_step += 1
    if self.temp_step != self.trainer.accumulate_grad_batches:
        return

    self.step += 1
    self.temp_step = 0
|
158 |
+
|
159 |
+
# For pytorch-lightning 1.9.5
|
160 |
+
# def optimizer_step(
|
161 |
+
# self,
|
162 |
+
# epoch: int,
|
163 |
+
# batch_idx: int,
|
164 |
+
# optimizer,
|
165 |
+
# optimizer_idx: int = 0,
|
166 |
+
# optimizer_closure=None,
|
167 |
+
# on_tpu: bool = False,
|
168 |
+
# using_native_amp: bool = False,
|
169 |
+
# using_lbfgs: bool = False,
|
170 |
+
# ) -> None:
|
171 |
+
# super().optimizer_step(
|
172 |
+
# epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
|
173 |
+
# )
|
174 |
+
# self.temp_step += 1
|
175 |
+
# if self.temp_step == self.trainer.accumulate_grad_batches:
|
176 |
+
# self.step += 1
|
177 |
+
# self.temp_step = 0
|
178 |
+
|
179 |
+
def on_train_epoch_end(self):
    """Advance the module's own epoch counter by one."""
    self.epoch = self.epoch + 1
|
181 |
+
|
182 |
+
def training_step(self, batch, batch_idx):
    """
    Compute and log the training loss of one batch.

    Args:
        batch: Pair of (inputs, labels); inputs is unpacked as keyword arguments
        batch_idx: Index of the batch (unused)

    Returns:
        The training loss
    """
    model_inputs, targets = batch

    train_loss = self.loss_func('train', self(**model_inputs), targets)

    self.log("loss", train_loss, prog_bar=True)
    return train_loss
|
200 |
+
|
201 |
+
def validation_step(self, batch, batch_idx):
    """
    Compute the validation loss of one batch and record it in ``self.valid_outputs``.

    Args:
        batch: A pair ``(inputs, labels)``; ``inputs`` is unpacked as keyword arguments for the model call
        batch_idx: Index of the batch (not used here)

    Returns:
        The value returned by ``self.loss_func('valid', outputs, labels)``
    """
    model_inputs, targets = batch
    valid_loss = self.loss_func('valid', self(**model_inputs), targets)
    self.valid_outputs.append(valid_loss)
    return valid_loss
|
207 |
+
|
208 |
+
def test_step(self, batch, batch_idx):
|
209 |
+
inputs, labels = batch
|
210 |
+
outputs = self(**inputs)
|
211 |
+
|
212 |
+
loss = self.loss_func('test', outputs, labels)
|
213 |
+
self.test_outputs.append(loss)
|
214 |
+
return loss
|
215 |
+
|
216 |
+
def on_train_start(self) -> None:
    """
    Restore the training progress stored in ``self.prev_schechuler``, if present.

    Nothing happens when the attribute is missing or ``None``. Otherwise the
    global step, epoch, best value and lr-scheduler state are restored from it.
    The optimizer state is then loaded either through the DeepSpeed engine
    (when the trainer strategy exposes ``deepspeed_engine``) or directly from
    the stored ``"optimizer"`` entry.

    Raises:
        Exception: If restoring fails. The original error is attached as ``__cause__``.
    """
    if getattr(self, "prev_schechuler", None) is None:
        return

    try:
        self.step = self.prev_schechuler["global_step"]
        self.epoch = self.prev_schechuler["epoch"]
        self.best_value = self.prev_schechuler["best_value"]
        self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
        print(f"Previous training global step: {self.step}")
        print(f"Previous training epoch: {self.epoch}")
        print(f"Previous best value: {self.best_value}")
        print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")

        # Load optimizer state
        if hasattr(self.trainer.strategy, "deepspeed_engine"):
            # For DeepSpeed strategy. Best effort: a failure here is only reported.
            try:
                self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
            except Exception as e:
                print(e)

        else:
            # For DDP strategy
            self.optimizer.load_state_dict(self.prev_schechuler["optimizer"])

    except Exception as e:
        print(e)
        # Chain the original error so the traceback shows what actually went wrong
        raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False") from e
|
244 |
+
|
245 |
+
def on_validation_epoch_start(self) -> None:
    """Start every validation epoch with an empty ``self.valid_outputs`` list."""
    self.valid_outputs = []
|
247 |
+
|
248 |
+
def on_test_epoch_start(self) -> None:
    """Start every test epoch with an empty ``self.test_outputs`` list."""
    self.test_outputs = []
|
250 |
+
|
251 |
+
def load_checkpoint(self, from_checkpoint: str) -> None:
    """
    Load model weights (and optionally the previous training state) from a checkpoint.

    Args:
        from_checkpoint: Path to a checkpoint file, or to a directory ``<name>``
                         that contains ``<name>.pt``.
    """

    # If ``from_checkpoint`` is a directory, load the checkpoint in it
    if os.path.isdir(from_checkpoint):
        # normpath removes a trailing separator; without it "ckpt/" has an
        # empty basename and the lookup would target "ckpt/.pt".
        basename = os.path.basename(os.path.normpath(from_checkpoint))
        from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

    state_dict = torch.load(from_checkpoint, map_location=self.device)
    self.load_weights(self.model, state_dict["model"])

    if self.load_prev_scheduler:
        # Keep everything except the weights for ``on_train_start`` to restore
        state_dict.pop("model")
        self.prev_schechuler = state_dict
|
268 |
+
|
269 |
+
def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
    """
    Save model to save_path

    Args:
        save_path: Path to save model
        save_info: Other info to save. The dict is copied, so the caller's object is left untouched
        save_weights_only: Whether only save model weights. If False, the global step, epoch,
                           best value and lr-scheduler state are stored too, plus the optimizer
                           state when the strategy has no ``deepspeed_engine``
    """
    parent_dir = os.path.dirname(save_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Shallow copy: do not add "model" etc. to the dict owned by the caller
    state_dict = {} if save_info is None else dict(save_info)
    state_dict["model"] = self.model.state_dict()

    # Convert model weights to fp32
    for k, v in state_dict["model"].items():
        state_dict["model"][k] = v.float()

    if not save_weights_only:
        state_dict["global_step"] = self.step
        state_dict["epoch"] = self.epoch
        state_dict["best_value"] = getattr(self, "best_value", None)
        state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()

        # If not using DeepSpeed, save optimizer state
        if not hasattr(self.trainer.strategy, "deepspeed_engine"):
            state_dict["optimizer"] = self.optimizers().optimizer.state_dict()

    torch.save(state_dict, save_path)
|
298 |
+
|
299 |
+
def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
    """
    Check whether to save model. If save_path is not None and now_value is the best, save model.

    Args:
        now_value: Current metric value
        mode: "min" or "max", meaning whether the lower the better or the higher the better
        save_info: Other info to save
    """

    assert mode in ["min", "max"], "mode should be 'min' or 'max'"

    if self.save_path is None:
        return

    # In case there are variables to be included in the save path.
    # NOTE(review): the path is evaluated as an f-string, so it can reference
    # ``self``, ``now_value``, ``mode`` and ``save_info``, but it also executes
    # arbitrary expressions. ``self.save_path`` must come from trusted config only.
    save_path = eval(f"f'{self.save_path}'")

    parent_dir = os.path.dirname(save_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Check whether to save model
    best_value = getattr(self, "best_value", None)
    if best_value is not None:
        not_improved = (mode == "min" and now_value >= best_value) or \
                       (mode == "max" and now_value <= best_value)
        if not_improved:
            return

    self.best_value = now_value

    # For DeepSpeed strategy
    if hasattr(self.trainer.strategy, "deepspeed_engine"):
        if not self.save_weights_only:
            self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")

        # Save a complete checkpoint
        if dist.get_rank() == 0:
            # normpath keeps the basename non-empty when save_path ends with a separator
            basename = os.path.basename(os.path.normpath(save_path))
            ckpt_path = os.path.join(save_path, f"{basename}.pt")
            self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)

    # For normal situation
    else:
        if dist.get_rank() == 0:
            self.save_checkpoint(save_path, save_info, self.save_weights_only)
|
340 |
+
|
341 |
+
def reset_metrics(self, stage) -> None:
    """
    Call ``reset()`` on every metric registered for the given stage

    Args:
        stage: "train", "valid" or "test"
    """
    stage_metrics = self.metrics[stage]
    for metric_name in stage_metrics:
        stage_metrics[metric_name].reset()
|
349 |
+
|
350 |
+
def get_log_dict(self, stage: str) -> dict:
    """
    Collect the current value of every metric of a stage

    Args:
        stage: "train", "valid" or "test"

    Returns:
        A dictionary mapping each metric name to the result of its ``compute()`` call
    """
    log_dict = {}
    for metric_name, metric in self.metrics[stage].items():
        log_dict[metric_name] = metric.compute()
    return log_dict
|
361 |
+
|
362 |
+
def log_info(self, info: dict) -> None:
    """
    Record metrics during training and testing

    Only acts when a logger is attached and the current process has rank 0.
    The current learning rate and epoch are added to ``info`` before logging.

    Args:
        info: dict of metrics
    """
    logger = getattr(self, "logger", None)
    if logger is None or dist.get_rank() != 0:
        return

    info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
    info["epoch"] = self.epoch
    logger.log_metrics(info, step=self.step)
|
372 |
+
|
373 |
+
def init_optimizers(self):
    """
    Build ``self.optimizer`` and ``self.lr_scheduler`` from the stored keyword dicts.

    ``self.optimizer_kwargs`` must contain ``"class"`` (name of an optimizer in
    ``torch.optim``) and ``"weight_decay"``; the remaining entries are passed to the
    optimizer. ``self.lr_scheduler_kwargs`` must contain ``"class"`` and ``"init_lr"``;
    ``init_lr`` is also used as the optimizer's learning rate. Both dicts are
    deep-copied first, so they are left unchanged.
    """
    optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)

    # No decay for layer norm and bias
    no_decay = ('LayerNorm.weight', 'bias')
    weight_decay = optimizer_kwargs.pop("weight_decay")

    decay_params, no_decay_params = [], []
    for name, param in self.model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    # Resolve the optimizer class by attribute lookup instead of eval, so the
    # configured name cannot execute arbitrary code. Dotted names still work.
    optimizer_cls = torch.optim
    for attr in optimizer_kwargs.pop('class').split("."):
        optimizer_cls = getattr(optimizer_cls, attr)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                   lr=self.lr_scheduler_kwargs['init_lr'],
                                   **optimizer_kwargs)

    tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
    lr_scheduler = tmp_kwargs.pop("class")
    # NOTE(review): the scheduler name is still resolved with eval against this
    # module's namespace; keep ``lr_scheduler_kwargs["class"]`` to trusted config.
    self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
|
395 |
+
|
396 |
+
def configure_optimizers(self):
    """
    Return the optimizer together with its scheduler configuration.

    The scheduler entry declares an interval of "step" with frequency 1.
    """
    config = {"optimizer": self.optimizer}
    config["lr_scheduler"] = {
        "scheduler": self.lr_scheduler,
        "interval": "step",
        "frequency": 1,
    }
    return config
|
model/model_interface.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import glob
|
4 |
+
|
5 |
+
|
6 |
+
# register all available models through *_model.py files
|
7 |
+
# def construct_model():
|
8 |
+
# model_dir = os.path.dirname(__file__)
|
9 |
+
#
|
10 |
+
# # lists all model files
|
11 |
+
# model_list = []
|
12 |
+
# for root, _, names in os.walk(model_dir):
|
13 |
+
# for name in names:
|
14 |
+
# if name.endswith('_model.py'):
|
15 |
+
# sub_dirs = root.replace(model_dir, '').split(os.sep)
|
16 |
+
# model_list.append((sub_dirs, name[:-3]))
|
17 |
+
#
|
18 |
+
# # load model_config.yaml, controlling which models to be loaded
|
19 |
+
# model_config = yaml.safe_load(open(f"{model_dir}/model_config.yaml", "r"))
|
20 |
+
#
|
21 |
+
# if model_config["verbose"]:
|
22 |
+
# print("*" * 30 + f" Loading model " + "*" * 30)
|
23 |
+
#
|
24 |
+
# # register models
|
25 |
+
# for sub_dirs, name in model_list:
|
26 |
+
# if name in model_config["models"]:
|
27 |
+
# if len(sub_dirs) > 1:
|
28 |
+
# cmd = f"from {'.'.join(sub_dirs)} import {name}"
|
29 |
+
# else:
|
30 |
+
# cmd = f"from . import {name}"
|
31 |
+
#
|
32 |
+
# exec(cmd)
|
33 |
+
#
|
34 |
+
# if model_config["verbose"]:
|
35 |
+
# info = f"Loaded model: {name}"
|
36 |
+
# print(f"\033[32m{info}\033[0m")
|
37 |
+
# else:
|
38 |
+
# if model_config["verbose"]:
|
39 |
+
# info = f"Skipped model: {name}"
|
40 |
+
# print(f"\033[31m{info}\033[0m")
|
41 |
+
#
|
42 |
+
# if model_config["verbose"]:
|
43 |
+
# print("*" * 75)
|
44 |
+
#
|
45 |
+
#
|
46 |
+
# # register function as a wrapper for all models
|
47 |
+
# def register_model(cls):
|
48 |
+
# model_dict[cls.__name__] = cls
|
49 |
+
# return cls
|
50 |
+
#
|
51 |
+
#
|
52 |
+
# model_dict = {}
|
53 |
+
# construct_model()
|
54 |
+
#
|
55 |
+
#
|
56 |
+
# class ModelInterface:
|
57 |
+
# @classmethod
|
58 |
+
# def get_available_models(cls):
|
59 |
+
# return model_dict.keys()
|
60 |
+
#
|
61 |
+
# @classmethod
|
62 |
+
# def init_model(cls, model: str, **kwargs):
|
63 |
+
# """
|
64 |
+
#
|
65 |
+
# Args:
|
66 |
+
# model : Class name of model you want to use. Must be in model_dict.keys()
|
67 |
+
# **kwargs: Kwargs for model initialization
|
68 |
+
#
|
69 |
+
# Returns: Corresponding model
|
70 |
+
#
|
71 |
+
# """
|
72 |
+
# assert model in model_dict.keys(), f"class {model} doesn't exist!"
|
73 |
+
# return model_dict[model](**kwargs)
|
74 |
+
|
75 |
+
|
76 |
+
########################################################################
|
77 |
+
# Version 2 #
|
78 |
+
########################################################################
|
79 |
+
# register function as a wrapper for all models
|
80 |
+
def register_model(cls):
    """
    Class decorator that records ``cls`` as a model class.

    The class is stored both as the most recently registered class (``now_cls``)
    and under the name of the module that defines it, so it can still be found
    after other models have been registered.
    """
    global now_cls
    now_cls = cls
    _registered_models[cls.__module__] = cls
    return cls


# Most recently registered model class
now_cls = None
# Registered model classes keyed by the name of their defining module
_registered_models = {}


class ModelInterface:
    @classmethod
    def init_model(cls, model_py_path: str, **kwargs):
        """
        Import a model file relative to this package and instantiate its registered class.

        Args:
            model_py_path: Path of the model's py file relative to this package, with
                           components separated by ``os.sep``. A trailing ".py" is ignored.
            **kwargs: Kwargs for model initialization

        Returns: Corresponding model

        Raises:
            ImportError: If this module does not belong to a package, or the model module cannot be imported
            TypeError: If no class has been registered with ``register_model``
        """
        import importlib

        if model_py_path.endswith(".py"):
            model_py_path = model_py_path[:-3]
        sub_dirs = model_py_path.split(os.sep)

        package = __package__ or __name__.rpartition(".")[0]
        if not package:
            raise ImportError("attempted relative import with no known parent package")

        # Importing the module runs its ``@register_model`` decorator on first import only
        module = importlib.import_module("." + ".".join(sub_dirs), package)

        # Look the class up by module name. ``now_cls`` alone is stale when the
        # requested module had already been imported before another model was registered.
        model_cls = _registered_models.get(module.__name__, now_cls)
        if model_cls is None:
            raise TypeError(f"No model class has been registered by '{model_py_path}'")

        return model_cls(**kwargs)
|
utils/constants.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
|
3 |
+
|
4 |
+
aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
aa_list = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]

foldseek_seq_vocab = "ACDEFGHIKLMNPQRSTVWY#"
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

struc_unit = "abcdefghijklmnopqrstuvwxyz"


def create_vocab(size: int) -> dict:
    """
    Build a vocabulary of ``size`` lowercase tokens plus a final "#" entry.

    Tokens are the first ``size`` strings, in ``itertools.product`` order, over
    ``struc_unit``, using the shortest token length that yields at least ``size``
    distinct tokens.

    Args:
        size: Size of the vocabulary (number of tokens, not counting "#")

    Returns:
        vocab: Dict mapping indices ``0..size-1`` to tokens and index ``size`` to "#"

    Raises:
        ValueError: If ``size`` is negative
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")

    token_len = 1
    while size > len(struc_unit) ** token_len:
        token_len += 1

    # islice stops after exactly ``size`` tokens, which also covers size == 0
    tokens = itertools.islice(itertools.product(struc_unit, repeat=token_len), size)
    vocab = {i: "".join(token) for i, token in enumerate(tokens)}
    vocab[size] = "#"
    return vocab
|
33 |
+
|
34 |
+
# ProTrek
|
35 |
+
# Annotation names grouped in the set ``residue_level``
residue_level = {
    "Active site", "Binding site", "Site", "DNA binding",
    "Natural variant", "Mutagenesis",
    "Transmembrane", "Topological domain", "Intramembrane",
    "Signal peptide", "Propeptide", "Transit peptide",
    "Chain", "Peptide",
    "Modified residue", "Lipidation", "Glycosylation",
    "Disulfide bond", "Cross-link",
    "Domain", "Repeat", "Compositional bias",
    "Region", "Coiled coil", "Motif",
}

# Annotation names grouped in the set ``sequence_level``
sequence_level = {
    "Function", "Miscellaneous", "Caution",
    "Catalytic activity", "Cofactor", "Activity regulation",
    "Biophysicochemical properties", "Pathway",
    "Involvement in disease", "Allergenic properties",
    "Toxic dose", "Pharmaceutical use", "Disruption phenotype",
    "Subcellular location", "Post-translational modification",
    "Subunit", "Domain (non-positional annotation)",
    "Sequence similarities", "RNA Editing",
    "Tissue specificity", "Developmental stage", "Induction",
    "Biotechnology", "Polymorphism", "GO annotation",
    "Proteomes", "Protein names", "Gene names",
    "Organism", "Taxonomic lineage", "Virus host",
}

# Annotation names grouped in the set ``raw_text_level`` (all of them also appear in ``sequence_level``)
raw_text_level = {
    "Function", "Subunit", "Tissue specificity",
    "Disruption phenotype", "Post-translational modification",
    "Induction", "Miscellaneous", "Sequence similarities",
    "Developmental stage", "Domain (non-positional annotation)",
    "Activity regulation", "Caution", "Polymorphism",
    "Toxic dose", "Allergenic properties", "Pharmaceutical use",
    "Cofactor", "Biophysicochemical properties",
    "Subcellular location", "RNA Editing",
}
|
utils/downloader.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
from utils.mpr import MultipleProcessRunner
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class Downloader(MultipleProcessRunner):
    """
    Download files whose URLs follow a common template
    """

    def __init__(self, base_url, save_path, overwrite=False, skip_error_info=False, **kwargs):
        """

        Args:
            base_url: URL template of the remote file; ``{}`` is filled with each data item
            save_path: Local path template; ``{}`` is filled with each data item
            overwrite: whether overwrite existing files
            skip_error_info: whether to silence the message echoed when a download fails
            **kwargs: Forwarded to ``MultipleProcessRunner`` (presumably including ``data``,
                      which is read back as ``self.data`` below; verify against the parent class)
        """
        super().__init__(**kwargs)

        self.base_url = base_url
        self.save_path = save_path
        self.overwrite = overwrite
        self.skip_error_info = skip_error_info

        if not overwrite:
            # remove existing files in data
            self.data = [uniprot for uniprot in tqdm(self.data, desc="Filtering out existing files...")
                         if not os.path.exists(self.save_path.format(uniprot))]

    def _aggregate(self, final_path: str, sub_paths):
        pass

    def _target_static(self, process_id, data, sub_path, *args):
        """Download every item of ``data`` with wget, removing the file of a failed download."""
        import shlex

        for i, uniprot in enumerate(data):
            url = self.base_url.format(uniprot)
            # Quote everything interpolated into the shell command so that
            # spaces or shell metacharacters cannot break or alter it
            quoted_url = shlex.quote(url)
            quoted_path = shlex.quote(self.save_path.format(uniprot))

            # shell cmd to download files
            wget = f"wget -q -o /dev/null {quoted_url} -O {quoted_path}"

            # "-f" and ";" make sure the error is reported even if there is no file to remove
            rm = f"rm -f {quoted_path}"
            err = "echo " + shlex.quote(f"Error: {url} cannot be downloaded!")
            if self.skip_error_info:
                err += ">/dev/null"

            os.system(f"{wget} || ({rm}; {err})")

            self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} Downloading files...")

    def run(self):
        """
        Run this function to download files
        """
        super().run()

    def __len__(self):
        return len(self.data)

    @staticmethod
    def clear_empty_files(path):
        """
        Clear empty files in specific directory

        Args:
            path: Directory to clean

        Returns:
            Number of removed files
        """
        cnt = 0
        for file in tqdm(os.listdir(path), desc="Clearing empty files..."):
            file_path = os.path.join(path, file)
            if os.path.getsize(file_path) == 0:
                os.remove(file_path)
                cnt += 1
        print(f"Removed {cnt} empty files")
        return cnt
|
72 |
+
|
73 |
+
|
74 |
+
class AlphaDBDownloader(Downloader):
    """
    Downloader for per-protein files of the AlphaFold2 database
    """
    def __init__(self, uniprot_ids, type: str, save_dir: str, **kwargs):
        """

        Args:
            uniprot_ids: Uniprot ids
            type: Which type of files to download. Must be one of ['pdb', 'mmcif', 'plddt', "pae"]
            save_dir: Saving directory
            **kwargs: Forwarded to ``Downloader``
        """

        # file type -> (remote URL template, local file name template)
        templates = {
            "pdb": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.pdb", "{}.pdb"),
            "mmcif": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.cif", "{}.cif"),
            "plddt": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-confidence_v4.json", "{}.json"),
            "pae": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-predicted_aligned_error_v4.json", "{}.json"),
        }
        remote_template, local_name = templates[type]

        super().__init__(
            data=uniprot_ids,
            base_url=remote_template,
            save_path=os.path.join(save_dir, local_name),
            **kwargs,
        )
|
105 |
+
|
106 |
+
|
107 |
+
class PDBDownloader(Downloader):
    """
    Downloader for structure files of the PDB
    """
    def __init__(self, pdb_ids, type: str, save_dir: str, **kwargs):
        """

        Args:
            pdb_ids: PDB ids
            type: Which type of files to download. Must be one of ['pdb', 'mmcif']
            save_dir: Saving directory
            **kwargs: Forwarded to ``Downloader``
        """

        # file type -> (remote URL template, local file name template)
        templates = {
            "pdb": ("https://files.rcsb.org/download/{}.pdb", "{}.pdb"),
            "mmcif": ("https://files.rcsb.org/download/{}.cif", "{}.cif"),
        }
        remote_template, local_name = templates[type]

        super().__init__(
            data=pdb_ids,
            base_url=remote_template,
            save_path=os.path.join(save_dir, local_name),
            **kwargs,
        )
|
134 |
+
|
135 |
+
|
136 |
+
class CATHDownloader(Downloader):
    def __init__(self, cath_ids, save_dir, **kwargs):
        """
        Download pdb files from CATH

        Args:
            cath_ids: CATH ids
            save_dir: Saving directory
            **kwargs: Forwarded to ``Downloader``
        """

        super().__init__(
            data=cath_ids,
            base_url="http://www.cathdb.info/version/v4_3_0/api/rest/id/{}.pdb",
            save_path=os.path.join(save_dir, "{}.pdb"),
            **kwargs,
        )
|
149 |
+
|
150 |
+
|
151 |
+
def download_pdb(pdb_id: str, format: str, save_path: str):
    """
    Download a structure file from PDB with wget

    If the download fails, the (partial) file is removed and an error message is echoed.

    Args:
        pdb_id: PDB id
        format: File format, must be one of ['pdb', 'cif']
        save_path: Saving path
    """
    import shlex

    url = f"https://files.rcsb.org/download/{pdb_id}.{format}"
    # Quote everything interpolated into the shell command so that spaces or
    # shell metacharacters cannot break or alter it
    quoted_url = shlex.quote(url)
    quoted_path = shlex.quote(save_path)

    wget = f"wget -q -o /dev/null {quoted_url} -O {quoted_path}"
    # "-f" and ";" make sure the error is reported even if there is no file to remove
    rm = f"rm -f {quoted_path}"
    err = "echo " + shlex.quote(f"Error: {url} cannot be downloaded!")
    os.system(f"{wget} || ({rm}; {err})")
|
165 |
+
|
166 |
+
|
167 |
+
def download_af2(uniprot_id: str, format: str, save_path: str):
    """
    Download files from AlphaFold2 database with wget

    If the download fails, the (partial) file is removed and an error message is echoed.

    Args:
        uniprot_id: Uniprot id
        format: File format, must be one of ['pdb', 'cif', 'plddt', 'pae']
        save_path: Saving path
    """
    import shlex

    # 'plddt' and 'pae' are JSON files with their own names (same URLs as used by
    # AlphaDBDownloader); every other format is an extension of the model file.
    special_files = {
        "plddt": "confidence_v4.json",
        "pae": "predicted_aligned_error_v4.json",
    }
    file_name = special_files.get(format, f"model_v4.{format}")
    url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-{file_name}"

    # Quote everything interpolated into the shell command so that spaces or
    # shell metacharacters cannot break or alter it
    quoted_url = shlex.quote(url)
    quoted_path = shlex.quote(save_path)

    wget = f"wget -q -o /dev/null {quoted_url} -O {quoted_path}"
    # "-f" and ";" make sure the error is reported even if there is no file to remove
    rm = f"rm -f {quoted_path}"
    err = "echo " + shlex.quote(f"Error: {url} cannot be downloaded!")
    os.system(f"{wget} || ({rm}; {err})")
|
utils/foldseek_util.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import re
|
6 |
+
import sys
|
7 |
+
sys.path.append(".")
|
8 |
+
|
9 |
+
|
10 |
+
# Get structural seqs from pdb file
|
11 |
+
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = False,
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Get structural seqs from a structure file by running ``foldseek structureto3didescriptor``

    Args:
        foldseek: Binary executable file of foldseek

        path: Path to pdb file

        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.

        process_id: Process ID for temporary files. This is used for parallel processing.

        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.

        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq). Only the first record of each chain is kept.
    """
    import shlex

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    # The temporary output is written to the current working directory
    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"
    verbosity = "" if foldseek_verbose else "-v 0 "
    # Quote the paths so that spaces or shell metacharacters cannot break the command
    cmd = (f"{shlex.quote(foldseek)} structureto3didescriptor {verbosity}--threads 1 "
           f"--chain-name-mode 1 {shlex.quote(path)} {shlex.quote(tmp_save_path)}")
    os.system(cmd)

    seq_dict = {}
    name = os.path.basename(path)
    plddts = None
    try:
        with open(tmp_save_path, "r") as r:
            for line in r:
                desc, seq, struc_seq = line.split("\t")[:3]

                # Mask low plddt
                if plddt_mask:
                    # The scores depend only on ``path``: read them once, not once per chain
                    if plddts is None:
                        plddts = extract_plddt(path)
                    assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                    # Mask regions with plddt < threshold
                    indices = np.where(plddts < plddt_threshold)[0]
                    np_seq = np.array(list(struc_seq))
                    np_seq[indices] = "#"
                    struc_seq = "".join(np_seq)

                # The chain id is what follows the last "_" of the record name
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    if chain not in seq_dict:
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Always clean up, including when parsing fails half-way
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)

    return seq_dict
|
77 |
+
|
78 |
+
|
79 |
+
def extract_plddt(pdb_path: str) -> np.ndarray:
    """
    Extract plddt scores from pdb file.

    The score of a residue is the mean of the B-factor column over all of its
    ATOM records. Residues are keyed by residue number only (the chain id is
    ignored) and returned in order of first appearance.

    Args:
        pdb_path: Path to pdb file.

    Returns:
        plddts: plddt scores.
    """
    plddt_dict = {}
    with open(pdb_path, "r") as r:
        for line in r:
            if line[:6].strip() != "ATOM":
                continue

            try:
                # PDB is a fixed-width format: residue number in columns 23-26 and
                # B-factor in columns 61-66. Slicing stays correct when neighbouring
                # fields abut (e.g. "A1000" or "1.00100.00"), which whitespace
                # splitting cannot handle.
                pos = int(line[22:26])
                plddt = float(line[60:66])
            except ValueError:
                # Not strictly column-aligned: fall back to whitespace tokens
                splits = line.split()
                if len(splits[4]) == 1:
                    pos = int(splits[5])
                else:
                    # Chain id and residue number are fused, e.g. "A1000"
                    pos = int(splits[4][1:])
                plddt = float(splits[-2])

            plddt_dict.setdefault(pos, []).append(plddt)

    plddts = np.array([np.mean(v) for v in plddt_dict.values()])
    return plddts
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
    # Manual smoke test; the paths below are machine-specific
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # get_struc_seq takes no ``plddt_path`` argument (passing one raises TypeError):
    # plddt scores are read from the structure file itself when plddt_mask=True
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())
|
utils/lr_scheduler.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
|
4 |
+
|
5 |
+
|
6 |
+
class ConstantLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 ):
        """
        This is an implementation of constant learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate
        """

        self.init_lr = init_lr
        # Forward ``verbose`` only when it was requested: newer PyTorch releases
        # deprecate the argument, so the default must not be passed explicitly.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return every attribute of the scheduler except the optimizer."""
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        """Overwrite the scheduler's attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return ``init_lr`` for every parameter group."""
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        return [self.init_lr for _ in self.optimizer.param_groups]
|
43 |
+
|
44 |
+
|
45 |
+
class CosineAnnealingLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 cosine_steps: int = 10000,
                 ):
        """
        This is an implementation of cosine annealing learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            cosine_steps: Number of steps for cosine annealing
        """

        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        # Forward ``verbose`` only when it was requested: newer PyTorch releases
        # deprecate the argument, so the default must not be passed explicitly.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return every attribute of the scheduler except the optimizer."""
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        """Overwrite the scheduler's attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """
        Linear warmup from ``init_lr`` to ``max_lr``, then cosine decay to ``final_lr``.

        Once ``cosine_steps`` steps of decay have passed, the rate stays at ``final_lr``.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        # ``warmup_steps > 0`` guards the division below when warmup is disabled
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        elif self.cosine_steps <= 0:
            # No decay phase configured: go straight to the final rate
            lr = self.final_lr

        else:
            # Clamp so the cosine does not swing back up after the decay has finished
            elapsed = min(step_no - self.warmup_steps, self.cosine_steps)
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
                 * (1 + math.cos(math.pi * elapsed / self.cosine_steps))

        return [lr for _ in self.optimizer.param_groups]
|
107 |
+
|
108 |
+
|
109 |
+
class Esm2LRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 start_decay_after_n_steps: int = 500000,
                 end_decay_after_n_steps: int = 5000000,
                 on_use: bool = True,
                 ):
        """
        This is an implementation of ESM2's learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            start_decay_after_n_steps: Start decay after this number of steps

            end_decay_after_n_steps: End decay after this number of steps

            on_use: Whether to use this scheduler. If ``False``, the scheduler will not change the learning rate
                    and keeps returning the base learning rates of the optimizer. Default: ``True``
        """

        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        # Forward ``verbose`` only when it was requested: newer PyTorch releases
        # deprecate the argument, so the default must not be passed explicitly.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return every attribute of the scheduler except the optimizer."""
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        """Overwrite the scheduler's attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """
        Linear warmup to ``max_lr``, a plateau, then linear decay to ``final_lr``.

        The plateau lasts until ``start_decay_after_n_steps``; the decay ends at
        ``end_decay_after_n_steps``, after which the rate stays at ``final_lr``.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch
        if not self.on_use:
            return list(self.base_lrs)

        # ``warmup_steps > 0`` guards the division below when warmup is disabled
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        elif step_no <= self.start_decay_after_n_steps:
            lr = self.max_lr

        elif step_no <= self.end_decay_after_n_steps:
            # Reaching this branch implies end > start, so the divisor is positive
            portion = (step_no - self.start_decay_after_n_steps) / (self.end_decay_after_n_steps - self.start_decay_after_n_steps)
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)

        else:
            lr = self.final_lr

        return [lr for _ in self.optimizer.param_groups]
|
utils/mpr.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import sys
|
5 |
+
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
from math import ceil
|
9 |
+
|
10 |
+
|
11 |
+
class MultipleProcessRunner:
    """
    Abstract class for running tasks with multiple processes.

    The data is cut into ``n_process`` contiguous slices and each slice is
    handed to one worker. There are three abstract methods that should be
    implemented:
        1. __len__()    : return the length of data
        2. _target()    : target function for each process
        3. _aggregate() : aggregate results from each process
    """

    def __init__(self,
                 data,
                 save_path=None,
                 n_process=1,
                 verbose=True,
                 total_only=True,
                 log_step=1,
                 start_method='fork'):
        """
        Args:
            data : data to be processed that can be sliced

            save_path : final output path

            n_process: number of process

            verbose : if True, display progress bar

            total_only: If True, only total progress bar is displayed

            log_step : For total progress bar, Next log will be printed when
                       ``current iteration`` - ``last log iteration`` >= log_step

            start_method: start method for multiprocessing
        """
        self.data = data
        self.save_path = save_path
        self.n_process = n_process
        self.verbose = verbose
        self.total_only = total_only
        self.log_step = log_step
        self.start_method = start_method

        # get terminal width to format output
        # (index 0 of os.get_terminal_size() is the number of columns)
        try:
            self.terminal_y = os.get_terminal_size()[0]

        except Exception as e:
            # Best effort: fall back to a default width in _display_all
            print(e)
            print("Can't get terminal size, set terminal_y = None")
            self.terminal_y = None

    def _s2hms(self, seconds: float):
        """
        convert second format of time into hour:minute:second format
        """
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)

        return "%02d:%02d:%02d" % (h, m, s)

    def _display_time(self, st_time, now, total):
        """
        Build the ``[elapsed < remaining, speed]`` suffix of a progress line.

        Args:
            st_time: start time as returned by ``time.time()``
            now: number of finished iterations
            total: total number of iterations
        """
        running_time = time.time() - st_time

        # Without a finished iteration or any elapsed time neither the
        # remaining time nor the speed can be estimated (and the divisions
        # below would fail), so show placeholders instead.
        if now <= 0 or running_time <= 0:
            return f' [{self._s2hms(max(running_time, 0))} < --:--:--, ?it/s]'

        rest_time = running_time * (total - now) / now
        if now > running_time:
            iter_sec = f"{now / running_time:.2f}it/s"
        else:
            iter_sec = f"{running_time / now:.2f}s/it"

        return f' [{self._s2hms(running_time)} < {self._s2hms(rest_time)}, {iter_sec}]'

    def _display_bar(self, now, total, length):
        """
        Draw a textual bar of ``length`` cells with ``now / total`` of them filled.
        """
        length = max(length, 0)
        if total <= 0:
            # Nothing to do: draw an empty bar rather than dividing by zero
            num = 0
        else:
            num = max(0, min(now, total)) * length // total

        return '[' + '#' * num + '_' * (length - num) + ']'

    def _display_all(self, now, total, desc, st_time):
        """
        Compose a full, colored progress line: description, bar, percentage,
        counters and timing information.
        """
        length = 50
        percent = int(now / total * 100) if total > 0 else 0
        time_display = self._display_time(st_time, now, total)

        def render(bar_length):
            progress_bar = self._display_bar(now, total, bar_length)
            return f'{desc}{progress_bar} {percent:02d}% {now}/{total}{time_display}'

        display = render(length)

        # Clean a line: pad with spaces up to the terminal width, or shrink
        # the bar when the line would not fit.
        width = self.terminal_y if self.terminal_y is not None else 100
        num_space = width - len(display)
        if num_space > 0:
            display += ' ' * num_space
        else:
            display = render(length + num_space)

        # Set color
        return f"\033[31m{display}\033[0m"

    # Print progress bar at specific position in terminal
    def terminal_progress_bar(self,
                              process_id: int,
                              now: int,
                              total: int,
                              desc: str = ''):
        """
        Args:
            process_id: process id
            now: now iteration number
            total: total iteration number
            desc: description
        """
        st_time = self.process_st_time[process_id]

        # Aggregate total information
        self.counts[process_id] = now
        self._total_display(self.process_st_time["total"])

        if not self.total_only:
            process_display = self._display_all(now, total, desc, st_time)
            if self.terminal_y is not None:
                sys.stdout.write(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8")
                sys.stdout.flush()
            else:
                print(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8", flush=True)

    # Print global information
    def _total_display(self, st_time):
        """
        Print the overall progress line if no other process is currently
        printing it and at least ``log_step`` iterations finished since the
        last print.
        """
        # ``total_display_callable`` is used as a flag: 1 means free, 0 means busy
        if self.total_display_callable.value == 1:
            self.total_display_callable.value = 0

            cnt = sum(self.counts[i] for i in range(self.n_process))
            if cnt - self.last_cnt.value >= self.log_step:
                total_display = self._display_all(cnt, self.__len__(), "Total: ", st_time)
                self.last_cnt.value = cnt

                # Row of the total bar: below the per-process bars when they are shown
                x = self.n_process + 1 if not self.total_only else 0
                print(f"\r\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True, end="")

            self.total_display_callable.value = 1

    def _init_shared_state(self, mp):
        """Create the counters and flags shared between the worker processes."""
        # total number of data that is already processed
        self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})

        # record start time for each process
        self.process_st_time = {"total": time.time()}

        # set a lock to call total number display
        self.total_display_callable = mp.Value('d', 1)

        # Save last log iteration number
        self.last_cnt = mp.Value('d', 0)

    def _prepare_jobs(self):
        """
        Split the data into one slice per process.

        Returns: ``(jobs, save_dir)`` where ``jobs`` is a list of
                 ``(process_id, data_slice, sub_path)`` tuples and ``save_dir`` is
                 the temporary directory holding the sub-results (``None`` when
                 ``save_path`` is not set).
        """
        num_per_process = ceil(self.__len__() / self.n_process)

        save_dir = None
        suffix = ""
        # Create a directory to save sub-results
        if self.save_path is not None:
            file_name, suffix = os.path.splitext(self.save_path)
            save_dir = f"{file_name}{suffix}_temp"
            os.makedirs(save_dir, exist_ok=True)

        jobs = []
        for i in range(self.n_process):
            st = i * num_per_process
            ed = st + num_per_process

            # construct slice and sub path for sub process
            data_slice = self.data[st: ed]
            sub_path = None if save_dir is None else f"{save_dir}/temp_{i}{suffix}"
            jobs.append((i, data_slice, sub_path))

        return jobs, save_dir

    def run(self):
        """
        The function is used to run a multi-process task
        Returns: return the result of function '_aggregate()'
        """

        import multiprocess as mp
        mp.set_start_method(self.start_method, force=True)

        self._init_shared_state(mp)
        jobs, save_dir = self._prepare_jobs()

        process_list = []
        for input_args in jobs:
            # The start time has to be recorded before the process is started
            # so that the worker sees it.
            self.process_st_time[input_args[0]] = time.time()
            p = mp.Process(target=self._target, args=input_args)
            p.start()
            process_list.append(p)

        for p in process_list:
            p.join()

        # aggregate results and remove temporary directory
        sub_paths = [sub_path for _, _, sub_path in jobs]
        results = self._aggregate(self.save_path, sub_paths)
        if save_dir is not None:
            os.rmdir(save_dir)

        return results

    def parallel_run(self):
        """
        Same as 'run()' but the workers are dispatched through joblib.
        Returns: return the result of function '_aggregate()'
        """
        import multiprocess as mp
        from joblib import Parallel, delayed

        self._init_shared_state(mp)
        jobs, save_dir = self._prepare_jobs()
        for input_args in jobs:
            self.process_st_time[input_args[0]] = time.time()

        # Start parallel processing. Each job tuple is unpacked into the
        # (process_id, data, sub_path) parameters of '_target'.
        Parallel(n_jobs=self.n_process)(delayed(self._target)(*input_args) for input_args in jobs)

        # aggregate results and remove temporary directory
        sub_paths = [sub_path for _, _, sub_path in jobs]
        results = self._aggregate(self.save_path, sub_paths)
        if save_dir is not None:
            os.rmdir(save_dir)

        return results

    @abc.abstractmethod
    def _aggregate(self, final_path: str, sub_paths):
        """
        This function is used to aggregate results from sub processes into a file

        Args:
            final_path: path to save final results
            sub_paths : list of sub paths

        Returns: None or desirable results specified by user

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _target(self, process_id, data, sub_path):
        """
        The main body to operate data in one process

        Args:
            process_id: process id
            data      : data slice
            sub_path  : sub path to save results
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __len__(self):
        """Return the number of items in the data."""
        raise NotImplementedError
|
305 |
+
|
306 |
+
|
307 |
+
class MultipleProcessRunnerSimplifier(MultipleProcessRunner):
    """
    A simplified version of MultipleProcessRunner.
    User only need to implement the function 'do', then it will be automatically executed
    in every iteration after call the function 'run'.
    If 'save_path' is specified, it will open a file in the 'sub_path' into which
    user can write results, and results will be aggregated into 'save_path'.

    The procedure would be like:
        ...
        with open(sub_path, 'w') as w:
            for i, d in enumerate(data):
                self.do(process_id, i, d, w) # You can write results into the file.
        ...

    The 'do' function should be like:
        def do(process_id, idx, data, writer):
            ...

    If 'save_path' is None, the argument 'writer' will be set to None, unless
    'return_results' is True: then a temporary file in the current working
    directory is used and its lines are returned by 'run'.

    """

    def __init__(self,
                 data,
                 do,
                 save_path=None,
                 n_process=1,
                 verbose=True,
                 total_only=True,
                 log_step=1,
                 return_results=False,
                 start_method='fork'):
        """
        Args:
            data : data to be processed that can be sliced

            do : callable invoked as ``do(process_id, idx, data, writer)`` for every item

            save_path : final output path

            n_process: number of process

            verbose : if True, display progress bar

            total_only: If True, only total progress bar is displayed

            log_step : see MultipleProcessRunner

            return_results: if True, the lines written by the workers are returned
                            as a list of strings (without the trailing newline)

            start_method: start method for multiprocessing
        """
        super().__init__(data=data,
                         save_path=save_path,
                         n_process=n_process,
                         verbose=verbose,
                         total_only=total_only,
                         log_step=log_step,
                         start_method=start_method)
        self.do = do
        self.return_results = return_results
        # Used in the names of the fallback temporary files. Set here as well
        # so that entry points that do not go through 'run()' can build them.
        self.start_time = time.time()

    def run(self):
        self.start_time = time.time()
        return super().run()

    def _resolve_sub_path(self, sub_path, process_id):
        """Return 'sub_path', or the fallback temporary file when results must be returned."""
        if sub_path is None and self.return_results:
            return f"MultipleProcessRunnerSimplifier_{self.start_time}_{process_id}.tmp"
        return sub_path

    def _aggregate(self, final_path: str, sub_paths):
        """
        Concatenate the sub-result files into 'final_path' and delete them.

        Args:
            final_path: path to save final results, or None
            sub_paths : list of sub paths

        Returns: list of result lines if 'return_results' is True, else an empty list
        """
        results = []

        w = open(final_path, 'w') if final_path is not None else None
        try:
            if self.verbose:
                iterator = tqdm(enumerate(sub_paths), "Aggregating results...")
            else:
                iterator = enumerate(sub_paths)

            for i, sub_path in iterator:
                sub_path = self._resolve_sub_path(sub_path, i)
                if sub_path is None:
                    continue

                with open(sub_path, 'r') as r:
                    for line in r:
                        if w is not None:
                            w.write(line)

                        if self.return_results:
                            # Strip only a real line break: the last line of a
                            # file may come without one.
                            results.append(line[:-1] if line.endswith("\n") else line)

                os.remove(sub_path)
        finally:
            # Make sure the aggregated file is flushed and closed
            if w is not None:
                w.close()

        return results

    def _target(self, process_id, data, sub_path):
        """
        Call 'do' on every item of the slice, reporting progress if verbose.

        Args:
            process_id: process id
            data      : data slice
            sub_path  : sub path to save results
        """
        sub_path = self._resolve_sub_path(sub_path, process_id)

        w = open(sub_path, 'w') if sub_path is not None else None
        try:
            for i, d in enumerate(data):
                self.do(process_id, i, d, w)
                if self.verbose:
                    self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} running...")
        finally:
            # Close the writer even if 'do' raises
            if w is not None:
                w.close()

    def __len__(self):
        return len(self.data)
|
397 |
+
|