LTEnjoy commited on
Commit
52da96f
1 Parent(s): 973b006

Upload 21 files

Browse files
bin/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # Place the Foldseek binary file here
demo/__init__.py ADDED
File without changes
demo/config.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_dir: /sujin/Models/ProTrek/ProTrek_650M_UniRef50
2
+ faiss_config:
3
+ IO_FLAG_MMAP: True
4
+ sequence_index_dir:
5
+ - name: UniRef50
6
+ index_dir: /mnt/5t/faiss_index/UniRef50/ProTrek_650M_UniRef50/sequence
7
+ - name: Swiss-Prot
8
+ index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/sequence
9
+ - name: PDB
10
+ index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/sequence
11
+ - name: Uncharacterized
12
+ index_dir: /mnt/5t/faiss_index/Uncharacterized/ProTrek_650M_UniRef50/sequence
13
+
14
+ structure_index_dir:
15
+ - name: Swiss-Prot
16
+ index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/structure
17
+ - name: PDB
18
+ index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/structure
19
+
20
+ text_index_dir:
21
+ - name: UniProt
22
+ index_dir: /mnt/5t/faiss_index/UniRef50/ProTrek_650M_UniRef50/text
23
+ - name: Swiss-Prot
24
+ index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/text
25
+
26
+ #model_dir: /sujin/Models/ProTrek/ProTrek_35M_UniRef50
27
+ #
28
+ #faiss_config:
29
+ # IO_FLAG_MMAP: True
30
+ #
31
+ #sequence_index_dir:
32
+ ## - name: UniRef50
33
+ ## index_dir: /sujin/Datasets/ProTrek/faiss_index/UniRef50/ProTrek_650M_UniRef50/sequence
34
+ # - name: Swiss-Prot
35
+ # index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/sequence
36
+ ## - name: PDB
37
+ ## index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/sequence
38
+ #
39
+ #structure_index_dir:
40
+ # - name: Swiss-Prot
41
+ # index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/structure
42
+ ## - name: PDB
43
+ ## index_dir: /sujin/Datasets/ProTrek/faiss_index/PDB/ProTrek_650M_UniRef50/structure
44
+ #
45
+ #text_index_dir:
46
+ ## - name: UniProt
47
+ ## index_dir: /sujin/Datasets/ProTrek/faiss_index/UniRef50/ProTrek_650M_UniRef50/text
48
+ # - name: Swiss-Prot
49
+ # index_dir: /sujin/Datasets/ProTrek/faiss_index/SwissProt/ProTrek_650M_UniRef50/text
demo/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Package initializer for ``demo.modules``.

This file historically doubled as a CLI stub; ``main`` and ``get_args`` are
kept for backward compatibility but currently do nothing useful.
"""
import sys  # kept: was used by a removed sys.path tweak; harmless to retain

import argparse


def main():
    """Entry-point placeholder; intentionally does nothing (yet)."""
    pass


def get_args():
    """Parse command-line arguments (no options are defined yet).

    Returns:
        argparse.Namespace: The parsed (empty) argument namespace.
    """
    parser = argparse.ArgumentParser()
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    main()
demo/modules/blocks.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from utils.foldseek_util import get_struc_seq
4
+
5
+
6
+ ####################################################
7
+ # gradio blocks #
8
+ ####################################################
9
def upload_pdb_button(visible: bool = True, chain_visible: bool = True):
    """
    Build an upload button (plus a chain-selection textbox) for pdb/cif files.

    Args:
        visible: Whether the upload button is shown initially.
        chain_visible: Whether the chain textbox is shown initially.

    Returns:
        Tuple of (upload button, chain textbox) gradio components.
    """
    with gr.Column(scale=0):
        # Textbox selecting which chain gets extracted from the uploaded file
        chain_box = gr.Textbox(
            label="Chain (to be extracted from the pdb file)",
            value="A",
            visible=chain_visible,
            interactive=True,
        )

        upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", visible=visible)

    return upload_btn, chain_box
25
+
26
+
27
+ ####################################################
28
+ # Trigger functions #
29
+ ####################################################
30
def parse_pdb_file(input_type: str, file: str, chain: str) -> str:
    """
    Parse the uploaded structure file and return the requested sequence.

    Args:
        input_type: Type of input. If "sequence", the amino-acid sequence is
            returned; any other value returns the Foldseek structural sequence.

        file: Path to the uploaded .pdb/.cif file

        chain: Chain to be extracted from the structure file

    Returns:
        Protein sequence, or the lower-cased Foldseek sequence

    Raises:
        gr.Error: If the chain cannot be extracted (wrong chain id,
            unreadable file, or missing Foldseek binary).
    """
    try:
        # get_struc_seq returns a dict keyed by chain id; each value holds
        # (aa_seq, foldseek_seq, combined_seq)
        parsed_seqs = get_struc_seq("bin/foldseek", file, [chain])[chain]
    except Exception as e:
        # Chain the original exception so the real cause (bad chain id,
        # corrupt file, missing bin/foldseek) is not lost in the logs.
        raise gr.Error(f"Failed to extract chain '{chain}' from the uploaded file. "
                       f"Please check the chain id and the file, and make sure the "
                       f"Foldseek binary is placed at bin/foldseek.") from e

    if input_type == "sequence":
        return parsed_seqs[0]
    else:
        # Foldseek 3Di tokens are lower-cased for the structure encoder
        return parsed_seqs[1].lower()
53
+
54
+
55
def set_upload_visible(visible: bool) -> gr.Interface:
    """
    Toggle the visibility of the upload button.

    Args:
        visible: Whether the component should be shown.

    Returns:
        gr.Interface: A gradio update setting the component's visibility.
    """
    return gr.update(visible=visible)
demo/modules/compute_score.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from .init_model import model
5
+ from .blocks import upload_pdb_button, parse_pdb_file
6
+
7
+
8
# Modalities supported on either side of the comparison
input_types = ["sequence", "structure", "text"]

# Example inputs for each modality, used to populate the examples dataset.
# NOTE(review): "structure" entries look like Foldseek 3Di strings, not
# amino-acid sequences — confirm against the structure encoder's tokenizer.
input_examples = {
    "sequence": [
        "MQLQRLGAPLLKRLVGGCIRQSTAPIMPCVVVSGSGGFLTPVRTYMPLPNDQSDFSPYIEIDLPSESRIQSLHKSGLAAQEWVACEKVHGTNFGIYLINQGDHEVVRFAKRSGIMDPNENFFGYHILIDEFTAQIRILNDLLKQKYGLSRVGRLVLNGELFGAKYKHPLVPKSEKWCTLPNGKKFPIAGVQIQREPFPQYSPELHFFAFDIKYSVSGAEEDFVLLGYDEFVEFSSKVPNLLYARALVRGTLDECLAFDVENFMTPLPALLGLGNYPLEGNLAEGVVIRHVRRGDPAVEKHNVSTIIKLRCSSFMELKHPGKQKELKETFIDTVRSGALRRVRGNVTVISDSMLPQVEAAANDLLLNNVSDGRLSNVLSKIGREPLLSGEVSQVDVALMLAKDALKDFLKEVDSLVLNTTLAFRKLLITNVYFESKRLVEQKWKELMQEEAAAQSEAIPPLSPAAPTKGE",
        "MSLSTEQMLRDYPRSMQINGQIPKNAIHETYGNDGVDVFIAGSGPIGATYAKLCVEAGLRVVMVEIGAADSFYAVNAEEGTAVPYVPGYHKKNEIEFQKDIDRFVNVIKGALQQVSVPVRNQNVPTLDPGAWSAPPGSSAISNGKNPHQREFENLSAEAVTRGVGGMSTHWTCSTPRIHPPMESLPGIGRPKLSNDPAEDDKEWNELYSEAERLIGTSTKEFDESIRHTLVLRSLQDAYKDRQRIFRPLPLACHRLKNAPEYVEWHSAENLFHSIYNDDKQKKLFTLLTNHRCTRLALTGGYEKKIGAAEVRNLLATRNPSSQLDSYIMAKVYVLASGAIGNPQILYNSGFSGLQVTPRNDSLIPNLGRYITEQPMAFCQIVLRQEFVDSVRDDPYGLPWWKEAVAQHIAKNPTDALPIPFRDPEPQVTTPFTEEHPWHTQIHRDAFSYGAVGPEVDSRVIVDLRWFGATDPEANNLLVFQNDVQDGYSMPQPTFRYRPSTASNVRARKMMADMCEVASNLGGYLPTSPPQFMDPGLALHLAGTTRIGFDKATTVADNNSLVWDFANLYVAGNGTIRTGFGENPTLTSMCHAIKSARSIINTLKGGTDGKNTGEHRNL",
        "MGVHECPAWLWLLLSLLSLPLGLPVLGAPPRLICDSRVLERYLLEAKEAENITTGCAEHCSLNENITVPDTKVNFYAWKRMEVGQQAVEVWQGLALLSEAVLRGQALLVNSSQPWEPLQLHVDKAVSGLRSLTTLLRALGAQKEAISPPDAASAAPLRTITADTFRKLFRVYSNFLRGKLKLYTGEACRTGDR"
    ],

    "structure": [
        "ddddddddddddddddddddddddddddddddpdpddpddpqpdddfddpdqqlddadddfaaddpvqvvlcvvvvvlqakkfkwfdadffkkkwkwadpdpdidifidtnvgtdglqpddllclvcvvlsvqlvvllqvvvcvvvvapafrmkmfiwgkdalddpfppadadpdwhagsvgdidgsvpgdrdddpaqhahsdiaietewiwiarnsdpvriqtafqvvvcvsqvprpphhyidgqfmggnllnlldpqqpaaqlrnqqvvnqvgddpprggqfikmfrrpprppvvcvsvrhgihtdghlvnvcvvdppcsvvcccnrcvprnvvscvvvvndhdtdvlsrhhpvlsvllvqllvlldpvllvvldvvvdlpclqvvvqdllnsllsslvvsvvvsvvpddpvnvpgdpvsvvvssvsssvsssvvsvvcvvvvnvvsvvvvvvvddppdpdddpddd",
        "dpdplvvqppdddplqappppfaadpvcvlvdpvaaaeeeeaqallsllllllclvlvgfyeyefqaeqpdwdddpddvpdddftqtqfapcqppvclqpqqvllvvqvvfwdwqeaefdqpppvpddppddhddppdgdddqqhdppfdpqqdlgqatwgghrntcqnhdpqfddawadadpvahqgtfdaldpdpvvrvvlvvvllvvlcvqlvkdqclqvpflqqcllqvllcvvcvvppwhkgggtgswhadpvhsldirhttsssscvvqrvdpssvssydyhyskhqqewhaghdpfgetawtkiarnccvvpvpdrgihigghrfyeypralprvllrcvssvqalqdpggdprhnqdqffalkwfwwkkkfkfffdpvsqvcqcvppppdpssnvqlvvqcvvcvpdpgsgdssrakhfmwtdadpvqqktktwidghhndddddppddpsrmimimiihwafrdrqfgwgfdppgdhpvrttrihtrddgdpvsvvsvvvrlvvsvvssvstgdtdprgpididrrnsvnlieqrqaedddsvngqayqlqhgpsyphygyfdrnhrngigngdcvsvrssssvsnsvvsscvvvvdpdddppdddddd",
        "ddppppdcvvvvvvvvvppppppvppldplvvlldvvllvvqlvllvvllvvcvvpdpnfflqdwqkafdlddpvvvvvpddlllllqlllvrlvsllvrlvsslvslvpdpdrdvvnnvssvvlnvssvvvnvssvslvsvvsnppddppprdddgdididrgssvssvsvssnsvgsvvvssvvssvvvvd"
    ],

    "text": [
        "RNA-editing ligase in kinetoplastid mitochondrial.",
        "Oxidase which catalyzes the oxidation of various aldopyranoses and disaccharides.",
        "Erythropoietin for regulation of erythrocyte proliferation and differentiation."
    ]
}

# Default example pairs shown at start-up: (sequence, text) combinations.
# Rebound by change_input_type when the user switches modalities.
samples = [[s1, s2] for s1, s2 in zip(input_examples["sequence"], input_examples["text"])]
31
+
32
+
33
def compute_score(input_type_1: str, input_1: str, input_type_2: str, input_2: str):
    """
    Compute the similarity score between two inputs of arbitrary modalities.

    Args:
        input_type_1: Modality of the first input ("sequence", "structure" or "text").
        input_1: First input string.
        input_type_2: Modality of the second input.
        input_2: Second input string.

    Returns:
        The similarity score (representation dot product scaled by the model
        temperature), formatted with 4 decimal places.
    """
    def _encode(modality: str, value: str):
        # Dispatch to the encoder matching the modality; anything that is
        # not "sequence"/"structure" is treated as free text.
        if modality == "sequence":
            return model.get_protein_repr([value])
        if modality == "structure":
            return model.get_structure_repr([value])
        return model.get_text_repr([value])

    with torch.no_grad():
        repr_1 = _encode(input_type_1, input_1)
        repr_2 = _encode(input_type_2, input_2)
        score = repr_1 @ repr_2.T / model.temperature

    return f"{score.item():.4f}"
50
+
51
+
52
def change_input_type(choice_1: str, choice_2: str):
    """
    Refresh the example dataset and upload-button visibility after either
    input-type dropdown changes.

    Args:
        choice_1: Newly selected modality of input 1.
        choice_2: Newly selected modality of input 2.

    Returns:
        Updates for (examples, input_1, input_2, upload_btn_1, chain_box_1,
        upload_btn_2, chain_box_2): refreshed examples, cleared inputs, and
        upload widgets hidden for text inputs.
    """
    # Rebuild the shared example pairs for the new modality combination
    global samples
    samples = [list(pair) for pair in zip(input_examples[choice_1], input_examples[choice_2])]

    # Sequence/structure inputs can be filled from an uploaded file; text cannot
    visible_1 = choice_1 != "text"
    visible_2 = choice_2 != "text"

    return (
        gr.update(samples=samples),
        "",
        "",
        gr.update(visible=visible_1),
        gr.update(visible=visible_1),
        gr.update(visible=visible_2),
        gr.update(visible=visible_2),
    )
73
+
74
+
75
+ # Load example from dataset
76
def load_example(example_id):
    """Return the example pair (input 1, input 2) at the clicked dataset index."""
    return samples[example_id]
78
+
79
+
80
+ # Build the block for computing protein-text similarity
81
def build_score_computation():
    """
    Build the gradio UI for computing a similarity score between two inputs
    of arbitrary modalities (sequence / structure / text).

    Lays out two input rows (textbox + modality dropdown + optional pdb
    upload), an examples dataset, and wires the change/click events.
    """
    gr.Markdown(f"# Compute similarity score between two modalities")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Compute similarity score between sequence and text
            with gr.Row():
                input_1 = gr.Textbox(label="Input 1")

                # Choose the type of input 1
                input_type_1 = gr.Dropdown(input_types, label="Input type", value="sequence",
                                           interactive=True, visible=True)

                # Provide an upload button to upload a pdb file
                upload_btn_1, chain_box_1 = upload_pdb_button(visible=True)
                upload_btn_1.upload(parse_pdb_file, inputs=[input_type_1, upload_btn_1, chain_box_1], outputs=[input_1])

            with gr.Row():
                input_2 = gr.Textbox(label="Input 2")

                # Choose the type of input 2
                input_type_2 = gr.Dropdown(input_types, label="Input type", value="text",
                                           interactive=True, visible=True)

                # Provide an upload button to upload a pdb file
                upload_btn_2, chain_box_2 = upload_pdb_button(visible=False)
                upload_btn_2.upload(parse_pdb_file, inputs=[input_type_2, upload_btn_2, chain_box_2], outputs=[input_2])

            # Provide examples (type="index" so the click handler receives an int)
            examples = gr.Dataset(samples=samples, type="index", components=[input_1, input_2], label="Input examples")

            # Add click event to examples
            examples.click(fn=load_example, inputs=[examples], outputs=[input_1, input_2])

            compute_btn = gr.Button(value="Compute")

            # Change examples based on input type
            input_type_1.change(fn=change_input_type, inputs=[input_type_1, input_type_2],
                                outputs=[examples, input_1, input_2, upload_btn_1, chain_box_1,
                                         upload_btn_2, chain_box_2])

            input_type_2.change(fn=change_input_type, inputs=[input_type_1, input_type_2],
                                outputs=[examples, input_1, input_2, upload_btn_1, chain_box_1,
                                         upload_btn_2, chain_box_2])

        # Output label showing the computed score
        similarity_score = gr.Label(label="similarity score")

    compute_btn.click(fn=compute_score, inputs=[input_type_1, input_1, input_type_2, input_2],
                      outputs=[similarity_score])
demo/modules/init_model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import numpy as np
3
+ import pandas as pd
4
+ import os
5
+ import yaml
6
+ import glob
7
+
8
+ from easydict import EasyDict
9
+ from utils.constants import sequence_level
10
+ from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
11
+ from tqdm import tqdm
12
+
13
+
14
def load_model():
    """
    Instantiate the ProTrek tri-modal model from the directory given by
    ``config.model_dir`` and switch it to eval mode.

    Returns:
        The loaded ProTrekTrimodalModel. Sub-encoders are created without
        their own pretrained weights because the full checkpoint is loaded.
    """
    model_dir = config.model_dir

    # Sub-model configs are discovered by naming convention inside model_dir
    protein_config = glob.glob(f"{model_dir}/esm2_*")[0]
    structure_config = glob.glob(f"{model_dir}/foldseek_*")[0]
    text_config = f"{model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    checkpoint = glob.glob(f"{model_dir}/*.pt")[0]

    loaded = ProTrekTrimodalModel(
        protein_config=protein_config,
        text_config=text_config,
        structure_config=structure_config,
        load_protein_pretrained=False,
        load_text_pretrained=False,
        from_checkpoint=checkpoint,
    )
    loaded.eval()
    return loaded
27
+
28
+
29
def load_faiss_index(index_path: str):
    """
    Read a faiss index from disk and force inner-product scoring.

    Args:
        index_path: Path to the serialized faiss index.

    Returns:
        The loaded faiss index (memory-mapped when IO_FLAG_MMAP is set in the
        config, which avoids loading the whole index into RAM).
    """
    extra_flags = (faiss.IO_FLAG_MMAP,) if config.faiss_config.IO_FLAG_MMAP else ()
    index = faiss.read_index(index_path, *extra_flags)

    # Scores are embedding dot products, so use the inner-product metric
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    return index
37
+
38
+
39
def _load_db_indexes(db_list, index_name: str) -> dict:
    """Load '<index_name>.index' plus 'ids.tsv' for every database in db_list."""
    loaded = {}
    for db in tqdm(db_list, desc=f"Loading {index_name} index..."):
        index_dir = db["index_dir"]

        db_index = load_faiss_index(f"{index_dir}/{index_name}.index")

        # ids.tsv maps faiss row number -> entry id (e.g. UniProt accession)
        ids = pd.read_csv(f"{index_dir}/ids.tsv", sep="\t", header=None).values.flatten()

        loaded[db["name"]] = {"index": db_index, "ids": ids}
    return loaded


def load_index():
    """
    Load all faiss indexes declared in the config.

    Returns:
        Tuple of:
          - all_index: {"sequence"|"structure"|"text": {db_name: ...}} where
            sequence/structure entries hold {"index", "ids"} and text entries
            hold one {"index", "ids"} dict per subsection.
          - valid_subsections: {db_name: sorted list of subsections that
            actually have an index file on disk}.
    """
    # Protein sequence and structure indexes share the same layout
    all_index = {
        "sequence": _load_db_indexes(config.sequence_index_dir, "sequence"),
        "structure": _load_db_indexes(config.structure_index_dir, "structure"),
    }

    # Text indexes: one per UniProt-style subsection, per database.
    # "Global" is an extra pseudo-subsection covering the whole entry text;
    # adding it here (once, not per database) keeps the loop below uniform.
    sequence_level.add("Global")

    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        text_dir = f"{db['index_dir']}/subsections"
        all_index["text"][db_name] = {}
        valid_subsections[db_name] = set()

        for subsection in tqdm(sequence_level):
            stem = subsection.replace(' ', '_')
            index_path = f"{text_dir}/{stem}.index"
            # Not every subsection is indexed on disk (e.g. "Taxonomic
            # lineage" is deliberately excluded); skip the missing ones.
            if not os.path.exists(index_path):
                continue

            text_index = load_faiss_index(index_path)
            text_ids = pd.read_csv(f"{text_dir}/{stem}_ids.tsv", sep="\t", header=None).values.flatten()

            all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids}
            valid_subsections[db_name].add(subsection)

    # Present subsections in a stable, sorted order for the UI dropdown
    for db_name in valid_subsections:
        valid_subsections[db_name] = sorted(valid_subsections[db_name])

    return all_index, valid_subsections
101
+
102
+
103
# Load the config file.
# NOTE(review): path handling assumes POSIX separators; rsplit("/", 3) strips
# ".../demo/modules/init_model.py" to obtain the project root — confirm on
# non-POSIX deployments.
root_dir = __file__.rsplit("/", 3)[0]
config_path = f"{root_dir}/demo/config.yaml"
with open(config_path, 'r', encoding='utf-8') as r:
    config = EasyDict(yaml.safe_load(r))

# Device used for model inference (assumes a CUDA-capable GPU is available)
device = "cuda"

print("Loading model...")
model = load_model()
model.to(device)

# Faiss indexes plus the text subsections that actually exist on disk;
# imported by the search/compute modules at startup.
all_index, valid_subsections = load_index()
print("Done...")
# Debug stub: uncomment the two lines below (and comment the loading above)
# to start the UI without the heavy model/index loading.
# model = None
# all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}
demo/modules/search.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+
7
+ from scipy.stats import norm
8
+ from .init_model import model, all_index, valid_subsections
9
+ from .blocks import upload_pdb_button, parse_pdb_file
10
+
11
+
12
# Temporary files used to hand results/plots to the gradio download/image widgets
tmp_file_path = "/tmp/results.tsv"
tmp_plot_path = "/tmp/histogram.svg"

# Samples for input (rebound by change_input_type when the modality changes)
samples = [
    ["Proteins with zinc bindings."],
    ["Proteins locating at cell membrane."],
    ["Protein that serves as an enzyme."]
]

# Databases for different modalities: tracks the currently selected database
# per output modality; mutated in place by change_db_type.
now_db = {
    "sequence": list(all_index["sequence"].keys())[0],
    "structure": list(all_index["structure"].keys())[0],
    "text": list(all_index["text"].keys())[0]
}
28
+
29
+
30
def clear_results():
    """Reset the results markdown and hide the download button and histogram."""
    empty_markdown = ""
    return empty_markdown, gr.update(visible=False), gr.update(visible=False)
32
+
33
+
34
def plot(scores) -> None:
    """
    Save a histogram of the score distribution, overlaid with a fitted
    normal curve, to ``tmp_plot_path``.

    Args:
        scores: Array-like of similarity scores for the whole database.
    """
    # Density histogram of all scores
    plt.hist(scores, bins=100, density=True, alpha=0.6)
    plt.title('Distribution of similarity scores in the database', fontsize=15)
    plt.xlabel('Similarity score', fontsize=15)
    plt.ylabel('Density', fontsize=15)

    # Overlay a Gaussian fitted to the empirical distribution
    mu, std = norm.fit(scores)
    xmin, xmax = plt.xlim()
    _, ymax = plt.ylim()
    grid = np.linspace(xmin, xmax, 100)
    plt.plot(grid, norm.pdf(grid, mu, std))

    # Annotate the database size in the top-right corner
    plt.text(xmax, 0.9 * ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

    # Persist as svg for the gradio image widget, then clear the axes so the
    # next call starts from an empty plot
    plt.savefig(tmp_plot_path)
    plt.cla()
60
+
61
+
62
+ # Search from database
63
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
    """
    Embed the input and retrieve the top-k matches from the selected database.

    Args:
        input: Query string (sequence, Foldseek string, or free text).
        nprobe: Number of IVF clusters to probe (ignored for flat indexes).
        topk: Number of results to return.
        input_type: Modality of the query ("sequence", "structure" or "text").
        query_type: Modality of the results to retrieve.
        subsection_type: Text subsection to search when query_type is "text".

    Returns:
        Tuple of (markdown results table, download-button update, histogram
        image update).

    Raises:
        gr.Error: If nprobe exceeds the index's cluster count, or topk
            exceeds the database size.
    """
    # Model methods are named get_protein_repr / get_structure_repr /
    # get_text_repr, so "sequence" maps to "protein"
    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    # Pick the index + id mapping for the currently selected database
    db = now_db[query_type]
    if query_type == "text":
        index = all_index["text"][db][subsection_type]["index"]
        ids = all_index["text"][db][subsection_type]["ids"]

    else:
        index = all_index[query_type][db]["index"]
        ids = all_index[query_type][db]["ids"]

    # nprobe only applies to IVF indexes
    if check_index_ivf(query_type, subsection_type):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        else:
            index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    # (faiss returns results sorted by decreasing score)
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Remove inf values (faiss pads unreachable entries with sentinel scores)
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]
    # Convert raw inner products to temperature-scaled logits
    scores = scores / model.temperature.item()
    plot(scores)

    top_scores = scores[:topk]
    top_ranks = ranks[:topk]

    # ranks = [list(range(topk))]
    # ids = ["P12345"] * topk
    # scores = torch.randn(topk).tolist()

    # Write the results to a temporary file for downloading
    with open(tmp_file_path, "w") as w:
        w.write("Id\tMatching score\n")
        for i in range(topk):
            rank = top_ranks[i]
            w.write(f"{ids[rank]}\t{top_scores[i]}\n")

    # Get topk ids, rendered as markdown links where a public page exists
    topk_ids = []
    for rank in top_ranks:
        now_id = ids[rank]
        if query_type == "text":
            topk_ids.append(now_id)
        else:
            if db != "PDB":
                # Provide link to uniprot website
                topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
            else:
                # Provide link to pdb website (id format is "<pdb>-<chain>")
                pdb_id = now_id.split("-")[0]
                topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")

    # Only render the first 1000 rows in the UI; the file has everything
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[1000])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))
137
+
138
+
139
def change_input_type(choice: str):
    """
    Swap the example dataset and upload-button visibility after the input
    modality changes.

    Args:
        choice: Newly selected input modality ("sequence", "structure" or "text").

    Returns:
        Updates for (examples, input box, upload button, chain box).
    """
    # Change examples if input type is changed (module-level so load_example
    # picks up the refreshed list)
    global samples
    if choice == "text":
        samples = [
            ["Proteins with zinc bindings."],
            ["Proteins locating at cell membrane."],
            ["Protein that serves as an enzyme."]
        ]

    elif choice == "sequence":
        samples = [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
        ]

    elif choice == "structure":
        # Foldseek 3Di strings
        samples = [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
        ]

    # Set visibility of upload button (only file-backed modalities can upload)
    if choice == "text":
        visible = False
    else:
        visible = True

    return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)
170
+
171
+
172
+ # Load example from dataset
173
def load_example(example_id):
    """Return the example text at the clicked index (each sample is a 1-item list)."""
    selected = samples[example_id]
    return selected[0]
175
+
176
+
177
+ # Change the visibility of subsection type
178
def change_output_type(query_type: str, subsection_type: str):
    """
    Adjust the UI after the output (query) type changes.

    Args:
        query_type: Newly selected output modality ("sequence", "structure" or "text").
        subsection_type: Currently selected text subsection (used to locate
            the active index when query_type is "text").

    Returns:
        Updates for (subsection dropdown, nprobe slider, database dropdown):
        the subsection dropdown is only shown for text queries, the nprobe
        slider only when the active index is IVF, and the database dropdown
        is repopulated with the databases available for the new modality.
    """
    nprobe_visible = check_index_ivf(query_type, subsection_type)
    subsection_visible = query_type == "text"

    return (
        gr.update(visible=subsection_visible),
        gr.update(visible=nprobe_visible),
        gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type])
    )
187
+
188
+
189
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
    """
    Check if the currently selected index for a modality is of IVF type,
    i.e. supports the ``nprobe`` search parameter.

    Args:
        index_type: Type of index ("sequence", "structure" or "text").
        subsection_type: If the "index_type" is "text", get the index based on the subsection type.

    Returns:
        Whether the index is of IVF type or not.

    Raises:
        ValueError: If index_type is not a known modality (the original code
            failed here with an opaque UnboundLocalError).
    """
    db = now_db[index_type]
    if index_type in ("sequence", "structure"):
        index = all_index[index_type][db]["index"]
    elif index_type == "text":
        index = all_index["text"][db][subsection_type]["index"]
    else:
        raise ValueError(f"Unknown index type: {index_type}")

    # Only IVF-family faiss indexes expose an nprobe attribute
    return hasattr(index, "nprobe")
211
+
212
+
213
def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.
    Args:
        query_type: The output type ("sequence", "structure" or "text").
        subsection_type: Currently selected text subsection (used to check
            whether the newly active index is IVF).
        db_type: The database to search.
    """
    # Remember the selection per modality (module-level state read by search)
    now_db[query_type] = db_type

    # Text databases expose different subsections; repopulate the dropdown
    if query_type == "text":
        subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value="Function")
    else:
        subsection_update = gr.update(visible=False)

    # The nprobe slider is only meaningful for IVF indexes
    nprobe_visible = check_index_ivf(query_type, subsection_type)
    return subsection_update, gr.update(visible=nprobe_visible)
229
+
230
+
231
+ # Build the searching block
232
def build_search_module():
    """
    Build the gradio UI for database search: input controls (modality radio,
    database/subsection dropdowns, nprobe/topk sliders, pdb upload), the
    examples dataset, and the results pane with download button and score
    histogram. Wires all change/click events.
    """
    gr.Markdown(f"# Search from Swiss-Prot database (the whole UniProt database will be supported soon)")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Set input type
            input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")

            with gr.Row():
                # Set output type
                query_type = gr.Radio(
                    ["sequence", "structure", "text"],
                    label="Output type (e.g. 'sequence' means returning qualified sequences)",
                    value="sequence",
                    scale=2,
                )

                # If the output type is "text", provide an option to choose the subsection of text
                subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
                                              interactive=True, visible=False, scale=0)

                db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
                                      interactive=True, visible=True, scale=0)

            with gr.Row():
                # Input box
                input = gr.Text(label="Input")

                # Provide an upload button to upload a pdb file
                upload_btn, chain_box = upload_pdb_button(visible=False)
                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])

            # If the index is of IVF type, provide an option to choose the number of clusters.
            nprobe_visible = check_index_ivf(query_type.value)
            nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
                               label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")

            # Add event listener to output type
            query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
                              outputs=[subsection_type, nprobe, db_type])

            # Add event listener to db type
            db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
                           outputs=[subsection_type, nprobe])

            # Choose topk results
            topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")

            # Provide examples (type="index" so load_example receives an int)
            examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")

            # Add click event to examples
            examples.click(fn=load_example, inputs=[examples], outputs=input)

            # Change examples based on input type
            input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])

            with gr.Row():
                search_btn = gr.Button(value="Search")
                clear_btn = gr.Button(value="Clear")

    with gr.Row():
        with gr.Column():
            results = gr.Markdown(label="results", height=450)
            download_btn = gr.DownloadButton(label="Download results", visible=False)

        # Plot the distribution of scores
        histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)

    search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
                     outputs=[results, download_btn, histogram])

    clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
demo/modules/tmalign.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ from .blocks import upload_pdb_button
5
+ from utils.downloader import download_pdb, download_af2
6
+
7
+
8
# Project root; assumes POSIX separators ("<root>/demo/modules/tmalign.py")
root_dir = __file__.rsplit("/", 3)[0]
# Remote sources a structure id can be fetched from
structure_types = ["AlphaFoldDB", "PDB"]
10
+
11
+
12
def upload_structure(file: str):
    """Forward the uploaded file path into the structure textbox."""
    uploaded_path = file
    return uploaded_path
14
+
15
+
16
def get_structure_path(structure: str, structure_type: str) -> str:
    """
    Resolve the user input to a local structure file path, downloading the
    structure into the demo cache if necessary.

    Args:
        structure: Either an absolute path to an uploaded file, a Uniprot ID
            (for AlphaFoldDB) or a PDB ID (for PDB).
        structure_type: Where to fetch the structure from when an ID is given
            ("AlphaFoldDB" or "PDB"); ignored for manually uploaded files.

    Returns:
        Path to a local .pdb/.cif file.

    Raises:
        gr.Error: If the structure type is not recognized.
    """
    # If the structure is manually uploaded it arrives as an absolute path.
    # startswith also safely handles empty input, which structure[0] crashed on.
    if structure.startswith("/"):
        return structure

    # If the structure is a Uniprot ID, download the structure from AlphaFoldDB
    if structure_type == "AlphaFoldDB":
        save_path = f"{root_dir}/demo/cache/{structure}.pdb"
        if not os.path.exists(save_path):
            download_af2(structure, "pdb", save_path)
        return save_path

    # If the structure is a PDB ID, download the structure from PDB
    if structure_type == "PDB":
        save_path = f"{root_dir}/demo/cache/{structure}.cif"
        if not os.path.exists(save_path):
            download_pdb(structure, "cif", save_path)
        return save_path

    # Previously this fell through and returned None, which surfaced later as
    # a confusing TMalign failure; fail loudly with a clear message instead.
    raise gr.Error(f"Unknown structure type: {structure_type}")
34
+
35
+
36
def tmalign(structure_1: str, structure_type_1: str, structure_2: str, structure_type_2: str):
    """
    Run TMalign on two structures and return its raw text report.

    Args:
        structure_1: First structure (path or ID, see get_structure_path).
        structure_type_1: Source of the first structure when an ID is given.
        structure_2: Second structure (path or ID).
        structure_type_2: Source of the second structure when an ID is given.

    Returns:
        TMalign's stdout (contains the TM-score report).
    """
    # Local import: only this function needs subprocess
    import subprocess

    structure_path_1 = get_structure_path(structure_1, structure_type_1)
    structure_path_2 = get_structure_path(structure_2, structure_type_2)

    # Use an argument list instead of a shell string (os.popen) so paths with
    # spaces or shell metacharacters cannot break or inject into the command,
    # and so the pipe is properly closed when the process exits.
    result = subprocess.run(
        ["bin/TMalign", structure_path_1, structure_path_2],
        capture_output=True,
        text=True,
    )
    return result.stdout
45
+
46
+
47
+ # Build the block for computing protein-text similarity
48
def build_TMalign():
    """
    Build the gradio UI for computing the TM-score between two protein
    structures: two input rows (id/path textbox + source dropdown + upload
    button), a compute button and a read-only output area.
    """
    gr.Markdown(f"# Calculate TM-score between two protein structures")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Compute similarity score between sequence and text
            with gr.Row():
                structure_1 = gr.Textbox(label="Protein structure 1 (input Uniprot ID or PDB ID or upload a pdb file)")

                structure_type_1 = gr.Dropdown(structure_types, label="Structure type (if the structure is manually uploaded, ignore this field)",
                                               value="AlphaFoldDB", interactive=True, visible=True)

                # Provide an upload button to upload a pdb file
                upload_btn_1, _ = upload_pdb_button(visible=True, chain_visible=False)
                upload_btn_1.upload(upload_structure, inputs=[upload_btn_1], outputs=[structure_1])

            with gr.Row():
                structure_2 = gr.Textbox(label="Protein structure 2 (input Uniprot ID or PDB ID or upload a pdb file)")

                structure_type_2 = gr.Dropdown(structure_types, label="Structure type (if the structure is manually uploaded, ignore this field)",
                                               value="AlphaFoldDB", interactive=True, visible=True)

                # Provide an upload button to upload a pdb file
                upload_btn_2, _ = upload_pdb_button(visible=True, chain_visible=False)
                upload_btn_2.upload(upload_structure, inputs=[upload_btn_2], outputs=[structure_2])

            compute_btn = gr.Button(value="Compute TM-score")
            tmscore = gr.TextArea(label="TM-score", interactive=False)

            compute_btn.click(tmalign, inputs=[structure_1, structure_type_1, structure_2, structure_type_2],
                              outputs=[tmscore])
78
+
demo/run.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
# Make the repository root importable so "modules.*" (and its "utils.*"
# imports) resolve when this file is executed directly as a script.
# NOTE(review): assumes POSIX separators in __file__ — confirm on Windows.
root_dir = __file__.rsplit("/", 2)[0]
if root_dir not in sys.path:
    sys.path.append(root_dir)

import gradio as gr

from modules.search import build_search_module
from modules.compute_score import build_score_computation
from modules.tmalign import build_TMalign


# Build demo: a single page stacking the three tools vertically
with gr.Blocks() as demo:
    build_search_module()
    build_score_computation()
    build_TMalign()


if __name__ == '__main__':
    # Run the demo
    demo.launch()
model/ProTrek/protein_encoder.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch

from tqdm import tqdm
from torch.nn.functional import normalize
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer


class ProteinEncoder(torch.nn.Module):
    """ESM-based protein sequence encoder with a linear projection head."""

    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the ESM config/checkpoint directory
            out_dim: Output dimension of the protein representation
            load_pretrained: Whether to load pretrained weights
            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)

        # Either restore pretrained weights or build a fresh model from config.
        if load_pretrained:
            self.model = EsmForMaskedLM.from_pretrained(config_path)
        else:
            self.model = EsmForMaskedLM(config)

        # Projection from the backbone hidden size to the shared repr space.
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # The contact head is never used for representation learning.
        self.model.esm.contact_head = None

        # Rotary embeddings make learned position embeddings redundant.
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute L2-normalized representations for a list of protein sequences.
        Args:
            proteins: A list of protein sequences
            batch_size: Batch size for inference
            verbose: Whether to display a progress bar

        NOTE(review): no torch.no_grad() here — callers that only need
        embeddings should wrap this in an inference context themselves.
        """
        device = next(self.parameters()).device

        batch_starts = range(0, len(proteins), batch_size)
        if verbose:
            batch_starts = tqdm(batch_starts, desc="Computing protein embeddings")

        chunks = []
        for start in batch_starts:
            batch = proteins[start:start + batch_size]
            encoded = self.tokenizer.batch_encode_plus(batch,
                                                       return_tensors="pt",
                                                       padding=True)
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            repr_chunk, _ = self.forward(encoded)
            chunks.append(repr_chunk)

        return normalize(torch.cat(chunks, dim=0), dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein sequence into protein representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits : [batch, seq_len, vocab_size] or None
        """
        hidden = self.model.esm(**inputs).last_hidden_state
        # Pool with the first (CLS/BOS) token, then project.
        pooled = self.out(hidden[:, 0, :])

        # MLM logits are only computed when explicitly requested.
        mask_logits = self.model.lm_head(hidden) if get_mask_logits else None

        return pooled, mask_logits
model/ProTrek/protrek_trimodal_model.py ADDED
@@ -0,0 +1,874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import torchmetrics
4
+ import json
5
+ import math
6
+ import numpy as np
7
+ import os
8
+ import copy
9
+ import faiss
10
+ import time
11
+ import pandas as pd
12
+ import random
13
+
14
+ from tqdm import tqdm
15
+ from .protein_encoder import ProteinEncoder
16
+ from .structure_encoder import StructureEncoder
17
+ from .text_encoder import TextEncoder
18
+ from ..abstract_model import AbstractModel
19
+ from ..model_interface import register_model
20
+ from utils.mpr import MultipleProcessRunnerSimplifier
21
+ from torch.nn.functional import normalize, cross_entropy
22
+ from utils.constants import residue_level, sequence_level
23
+ from sklearn.metrics import roc_auc_score
24
+
25
+
26
def multilabel_cross_entropy(logits, labels):
    """
    Compute cross entropy loss for multilabel classification. See
    "https://arxiv.org/pdf/2208.02955.pdf" (ZLPR loss).

    For each sample the loss is log(1 + sum_{i in neg, j in pos} exp(s_i - s_j)),
    averaged over samples.

    Args:
        logits: [num_samples, num_classes] raw scores
        labels: [num_samples, num_classes] binary (0/1) label matrix

    Returns:
        Scalar tensor: mean ZLPR loss over the batch.

    Raises:
        ZeroDivisionError: if ``logits`` is empty (unchanged from before).
    """
    loss = 0
    for pred, label in zip(logits, labels):
        pos_logits = pred[label == 1]
        neg_logits = pred[label == 0]

        # diff[i, j] = neg_i - pos_j
        diff = neg_logits.unsqueeze(-1) - pos_logits

        # log(1 + sum(exp(diff))) == logsumexp([0, diff...]).
        # The logsumexp form is numerically stable: the previous
        # torch.log(1 + torch.exp(diff).sum()) overflowed to inf whenever a
        # negative logit exceeded a positive one by ~90+ in fp32.
        zero = diff.new_zeros(1)
        loss = loss + torch.logsumexp(torch.cat([zero, diff.flatten()]), dim=0)

    return loss / len(logits)
56
+
57
+
58
+ @register_model
59
+ class ProTrekTrimodalModel(AbstractModel):
60
    def __init__(self,
                 protein_config: str,
                 text_config: str,
                 structure_config: str = None,
                 repr_dim: int = 1024,
                 temperature: float = 0.07,
                 load_protein_pretrained: bool = True,
                 load_text_pretrained: bool = True,
                 use_mlm_loss: bool = False,
                 use_zlpr_loss: bool = False,
                 use_saprot: bool = False,
                 gradient_checkpointing: bool = False,
                 **kwargs):
        """
        Args:
            protein_config: Path to the config file for protein sequence encoder

            text_config: Path to the config file for text encoder

            structure_config: Path to the config file for structure encoder

            repr_dim: Output dimension of the protein and text representation

            temperature: Temperature for softmax

            load_protein_pretrained: Whether to load pretrained weights for protein encoder

            load_text_pretrained: Whether to load pretrained weights for text encoder

            use_mlm_loss: Whether to use masked language modeling loss

            use_zlpr_loss: Whether to use zlpr loss. See "https://arxiv.org/pdf/2208.02955.pdf"

            use_saprot: Whether to use SaProt as protein encoder

            gradient_checkpointing: Whether to use gradient checkpointing for protein encoder
        """
        # All settings are stored before calling the base constructor because
        # initialize_model()/initialize_metrics() read them.
        self.protein_config = protein_config
        self.structure_config = structure_config
        self.text_config = text_config
        self.repr_dim = repr_dim
        # Scalar here; initialize_model() replaces it with a learnable Parameter.
        self.temperature = temperature
        self.load_protein_pretrained = load_protein_pretrained
        self.load_text_pretrained = load_text_pretrained
        self.use_mlm_loss = use_mlm_loss
        self.use_zlpr_loss = use_zlpr_loss
        self.use_saprot = use_saprot
        self.gradient_checkpointing = gradient_checkpointing
        # NOTE(review): AbstractModel.__init__ presumably invokes
        # initialize_model() and initialize_metrics() — confirm in abstract_model.py.
        super().__init__(**kwargs)
109
+
110
+ def initialize_metrics(self, stage: str) -> dict:
111
+ return_dict = {
112
+ f"{stage}_protein_text_acc": torchmetrics.Accuracy(),
113
+ f"{stage}_text_protein_acc": torchmetrics.Accuracy(),
114
+ }
115
+
116
+ if self.use_mlm_loss:
117
+ return_dict[f"{stage}_protein_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
118
+ if self.structure_config is not None:
119
+ return_dict[f"{stage}_structure_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
120
+
121
+ if self.structure_config is not None:
122
+ return_dict[f"{stage}_structure_protein_acc"] = torchmetrics.Accuracy()
123
+ return_dict[f"{stage}_structure_text_acc"] = torchmetrics.Accuracy()
124
+ return_dict[f"{stage}_text_structure_acc"] = torchmetrics.Accuracy()
125
+ return_dict[f"{stage}_protein_structure_acc"] = torchmetrics.Accuracy()
126
+
127
+ return return_dict
128
+
129
    def initialize_model(self):
        """Instantiate the modality encoders, the learnable temperature, and
        the container used for checkpoint save/load."""
        # Initialize encoders
        self.protein_encoder = ProteinEncoder(self.protein_config,
                                              self.repr_dim,
                                              self.load_protein_pretrained,
                                              self.gradient_checkpointing)

        self.text_encoder = TextEncoder(self.text_config,
                                        self.repr_dim,
                                        self.load_text_pretrained,
                                        self.gradient_checkpointing)

        # Learnable temperature (replaces the scalar stored in __init__)
        self.temperature = torch.nn.Parameter(torch.tensor(self.temperature))

        # self.model is used for saving and loading
        # NOTE(review): ParameterList holding nn.Modules (not Parameters) is
        # unconventional — verify it registers submodules as intended.
        self.model = torch.nn.ParameterList([self.temperature,
                                             self.protein_encoder,
                                             self.text_encoder])

        # If the structure encoder is specified
        if self.structure_config is not None:
            self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim)
            self.model.append(self.structure_encoder)
153
+
154
    def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        # Thin delegation: batched embedding of a list of texts.
        return self.text_encoder.get_repr(texts, batch_size, verbose)
156
+
157
    def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        # Thin delegation: batched embedding of foldseek structure sequences.
        return self.structure_encoder.get_repr(proteins, batch_size, verbose)
159
+
160
    def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        # Thin delegation: batched, L2-normalized embedding of protein sequences.
        return self.protein_encoder.get_repr(proteins, batch_size, verbose)
162
+
163
+ def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None):
164
+ """
165
+ Args:
166
+ protein_inputs: A dictionary for protein encoder
167
+ structure_inputs: A dictionary for structure encoder
168
+ text_inputs : A dictionary for text encoder
169
+ """
170
+ protein_repr, protein_mask_logits = self.protein_encoder(protein_inputs, self.use_mlm_loss)
171
+ text_repr = self.text_encoder(text_inputs)
172
+
173
+ outputs = [text_repr, protein_repr, protein_mask_logits]
174
+
175
+ if self.structure_config is not None:
176
+ structure_repr, structure_mask_logits = self.structure_encoder(structure_inputs, self.use_mlm_loss)
177
+ outputs += [structure_repr, structure_mask_logits]
178
+
179
+ return outputs
180
+
181
    def loss_func(self, stage: str, outputs, labels):
        """Contrastive (InfoNCE) loss over all modality pairs, plus optional
        MLM loss. Representations are gathered across GPUs so every rank
        contrasts its local batch against the global batch."""
        if self.structure_config is not None:
            text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs
        else:
            text_repr, protein_repr, protein_mask_logits = outputs

        device = text_repr.device

        text_repr = normalize(text_repr, dim=-1)
        protein_repr = normalize(protein_repr, dim=-1)

        # Gather representations from all GPUs (detached: gradients only flow
        # through the local batch on each rank).
        all_protein_repr = self.all_gather(protein_repr).view(-1, protein_repr.shape[-1]).detach()
        all_text_repr = self.all_gather(text_repr).view(-1, text_repr.shape[-1]).detach()

        if self.structure_config is not None:
            structure_repr = normalize(structure_repr, dim=-1)
            all_structure_repr = self.all_gather(structure_repr).view(-1, structure_repr.shape[-1]).detach()

        # NOTE: an earlier multilabel (ZLPR) variant that matched proteins to
        # candidate-text sets lived here; it was disabled in favor of plain
        # cross entropy with diagonal (index-matched) positives.

        # Batch size
        rank = dist.get_rank()
        bs = text_repr.shape[0]

        # The positive for local sample i is global row rank*bs + i, i.e. the
        # same sample's other-modality embedding in the gathered matrix.
        bs_labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(device)

        # Every ordered modality pair gets its own contrastive term.
        if self.structure_config is not None:
            pairs = {
                "protein": ["structure", "text"],
                "structure": ["protein", "text"],
                "text": ["protein", "structure"]
            }
        else:
            pairs = {
                "protein": ["text"],
                "text": ["protein"]
            }

        loss_list = []
        for k, values in pairs.items():
            for v in values:
                # Local-batch vs. global-batch similarity, temperature-scaled.
                # eval() resolves e.g. "protein_repr" / "all_text_repr" from
                # the locals above — brittle but intentional here.
                sim = torch.matmul(eval(f"{k}_repr"), eval(f"all_{v}_repr").T).div(self.temperature)

                loss = cross_entropy(sim, bs_labels)
                self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
                loss_list.append(loss)

        # Masked language modeling loss
        if self.use_mlm_loss:
            k_label = [("protein", labels["seq_labels"])]
            if self.structure_config is not None:
                k_label.append(("structure", labels["struc_labels"]))

            for k, label in k_label:
                logits = eval(f"{k}_mask_logits")
                # merge the first and second dimension of logits
                logits = logits.view(-1, logits.shape[-1])
                label = label.flatten().to(device)
                # -1 marks unmasked positions; they contribute no loss.
                mlm_loss = cross_entropy(logits, label, ignore_index=-1)
                loss_list.append(mlm_loss)
                self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label)

        # Each contrastive/MLM term is weighted equally.
        loss = sum(loss_list) / len(loss_list)

        if stage == "train":
            log_dict = self.get_log_dict("train")
            log_dict["train_loss"] = loss
            self.log_info(log_dict)

            # Reset train metrics
            self.reset_metrics("train")

        return loss
325
+
326
+ def padded_gather(self, tensor: torch.Tensor):
327
+ """
328
+ Gather tensors from all GPUs, allowing different shapes at the batch dimension.
329
+ """
330
+
331
+ # Get the size of the tensor
332
+ size = tensor.shape[0]
333
+ all_sizes = self.all_gather(torch.tensor(size, device=tensor.device))
334
+ max_size = max(all_sizes)
335
+
336
+ # Pad the tensor
337
+ if size != max_size:
338
+ tmp = torch.zeros(max_size, tensor.shape[-1], dtype=tensor.dtype, device=tensor.device)
339
+ tmp[:size] = tensor
340
+ tensor = tmp
341
+
342
+ padded_tensor = self.all_gather(tensor).view(-1, tensor.shape[-1])
343
+ tensor = padded_tensor[:sum(all_sizes)]
344
+
345
+ return tensor
346
+
347
    def _get_protein_indices(self):
        """Embed every evaluation protein (sharded across ranks) and build a
        faiss inner-product index over the gathered embeddings.

        Embeddings come from ProteinEncoder.get_repr, which L2-normalizes, so
        inner product equals cosine similarity."""
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if self.use_saprot:
            # SaProt input interleaves amino-acid and foldseek tokens: a1b1a2b2...
            proteins = []
            for sub_dict in self.uniprot2label.values():
                aa_seq = sub_dict["seq"]
                foldseek_seq = sub_dict["foldseek"]
                assert len(aa_seq) == len(foldseek_seq)
                seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
                proteins.append(seq)

        else:
            proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]

        # Each rank embeds a contiguous shard; the last shard may be shorter.
        span = math.ceil(len(proteins) / world_size)
        sub_proteins = proteins[rank * span: (rank + 1) * span]

        # Display the progress bar on the rank 0 process
        verbose = self.trainer.local_rank == 0
        # Get protein representations (batch_size=1: sequences vary in length)
        sub_protein_repr = self.protein_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
        protein_repr = self.padded_gather(sub_protein_repr)

        # Construct faiss index
        d = protein_repr.shape[-1]
        protein_indices = faiss.IndexFlatIP(d)
        protein_indices.add(protein_repr.cpu().numpy())
        return protein_indices
377
+
378
    def _get_structure_indices(self):
        """Same as _get_protein_indices but embeds the foldseek structure
        sequences with the structure encoder."""
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        proteins = [sub_dict["foldseek"] for sub_dict in self.uniprot2label.values()]
        # Shard the workload across ranks, then gather the results.
        span = math.ceil(len(proteins) / world_size)
        sub_proteins = proteins[rank * span: (rank + 1) * span]

        # Display the progress bar on the rank 0 process
        verbose = self.trainer.local_rank == 0
        # Get protein representations
        sub_protein_repr = self.structure_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
        protein_repr = self.padded_gather(sub_protein_repr)

        # Construct faiss index (inner product over normalized vectors = cosine)
        d = protein_repr.shape[-1]
        structure_indices = faiss.IndexFlatIP(d)
        structure_indices.add(protein_repr.cpu().numpy())
        return structure_indices
397
+
398
    def _get_text_indices(self):
        """Embed the texts of every subsection (sharded across ranks) and
        build one faiss inner-product index per subsection, plus a combined
        "Total" index aggregated via the label2text["Total"] mapping."""
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Display the progress bar on the rank 0 process
        verbose = self.trainer.local_rank == 0
        if verbose:
            iterator = tqdm(self.label2text.keys(), desc="Get text representations")
        else:
            iterator = self.label2text.keys()

        text_embeddings = {}
        for subsection in iterator:
            # "Total" is derived from the per-subsection embeddings below.
            if subsection == "Total":
                continue

            texts = []
            for text_list in self.label2text[subsection].values():
                # Only use the first text for efficiency
                texts.append(text_list[0:1])

            # Shard the texts of this subsection across ranks.
            span = math.ceil(len(texts) / world_size)
            texts = texts[rank * span: (rank + 1) * span]
            embeddings = []
            for text_list in texts:
                # Mean-pool the (single-element) text list, then normalize.
                text_repr = self.text_encoder.get_repr(text_list)
                mean_repr = text_repr.mean(dim=0, keepdim=True)
                norm_repr = torch.nn.functional.normalize(mean_repr, dim=-1)
                embeddings.append(norm_repr)

            if len(embeddings) > 0:
                embeddings = torch.cat(embeddings, dim=0)
            else:
                # This rank got an empty shard; contribute a 0-row tensor so
                # padded_gather still works.
                embeddings = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device)

            text_repr = self.padded_gather(embeddings)
            text_embeddings[subsection] = text_repr

        # Aggregate text embeddings for global retrieval; entries in
        # label2text["Total"] are "subsection|index" pointers.
        total_embeddings = []
        for idx in self.label2text["Total"].values():
            subsection, i = idx.split("|")
            total_embeddings.append(text_embeddings[subsection][int(i)])

        text_embeddings["Total"] = torch.stack(total_embeddings)

        # Construct faiss index
        text_indices = {}
        for subsection, text_repr in text_embeddings.items():
            d = text_repr.shape[-1]
            text_indices[subsection] = faiss.IndexFlatIP(d)
            text_indices[subsection].add(text_repr.cpu().numpy())

        return text_indices
452
+
453
    def _protein2text(self, modality: str, protein_indices, text_indices: dict):
        """Protein->text retrieval evaluation over Swiss-Prot entries.

        For each (protein, subsection) pair, rank all texts of the subsection
        by similarity to the protein embedding and compute AP, MRR, and AUC;
        work is sharded across ranks and across worker processes."""
        def do(process_id, idx, row, writer):
            # One evaluation unit: a protein against one subsection's texts.
            subsection, uniprot_id, prob_idx, label = row

            # Retrieve ranking results (exhaustive search over the index)
            p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
            text_inds = text_indices[subsection]
            sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal)
            sim_scores, rank_inds = sim_scores[0], rank_inds[0]

            # Calculate Average Precision(AP)
            ranks = []
            label = set(label)
            for i, rk in enumerate(rank_inds):
                # Find the rank of this label in all labels
                if rk in label:
                    ranks.append(i + 1)

            ranks = np.array(ranks)
            ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])

            # Calculate Mean Reciprocal Rank(MRR)
            # NOTE(review): ranks[0] raises IndexError if no ground-truth text
            # was found in the index — presumably impossible by construction;
            # confirm.
            best_rank = ranks[0]
            mrr = 1 / best_rank

            # Calculate the AUC (degenerate all-pos/all-neg cases score 0)
            true_labels = np.zeros_like(sim_scores)
            true_labels[ranks - 1] = 1
            if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
                auc = 0
            else:
                auc = roc_auc_score(true_labels, sim_scores)

            output = json.dumps([ap, mrr, auc])
            writer.write(output + "\n")

        # Build evaluation rows: only Swiss-Prot proteins that carry the
        # subsection annotation.
        inputs = []
        swissprot_subsections = set()
        for subsection in text_indices.keys():
            for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()):
                if uniprot_id in self.swissprot_ids:
                    if subsection in labels:
                        swissprot_subsections.add(subsection)
                        label = labels[subsection]
                        inputs.append((subsection, uniprot_id, i, label))

        # Randomly shuffle the inputs (fixed seed: identical order on every rank,
        # which the rank-sharding below relies on)
        random.seed(20000812)
        random.shuffle(inputs)

        # Split inputs into chunks for parallel processing
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        span = math.ceil(len(inputs) / world_size)
        sub_inputs = inputs[rank * span: (rank + 1) * span]

        # Display the progress bar on the rank 0 process
        verbose = self.trainer.local_rank == 0
        if verbose:
            print("Evaluating on each subsection...")
        # NOTE(review): hard-coded machine-specific scratch path — should come
        # from config or tempfile.
        tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                              return_results=True)
        outputs = mpr.run()
        os.remove(tmp_path)

        # Aggregate results from all ranks into one tensor.
        tensor_outputs = []
        for output in outputs:
            ap, mrr, auc = json.loads(output)
            tensor_outputs.append([float(ap), float(mrr), float(auc)])

        tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
        tensor_outputs = self.padded_gather(tensor_outputs)

        # Record results per subsection (rows align with `inputs` because all
        # ranks shuffled with the same seed).
        avg_results = {}
        for subsection in swissprot_subsections:
            avg_results[subsection] = {"map": [],
                                       "mrr": [],
                                       "auc": []}

        for input, output in zip(inputs, tensor_outputs):
            ap, mrr, auc = output
            subsection, _, _, _ = input

            avg_results[subsection]["map"].append(ap.cpu().item())
            avg_results[subsection]["mrr"].append(mrr.cpu().item())
            avg_results[subsection]["auc"].append(auc.cpu().item())

        results = {
            f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
            f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]),
            f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]),
        }

        # Average the precision and recall for each level
        for level, labels in [("residue-level", residue_level),
                              ("sequence-level", sequence_level),
                              ("all", residue_level | sequence_level)]:

            mrrs = []
            maps = []
            aucs = []
            for subsection in labels:
                if subsection in avg_results:
                    mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                    maps.append(np.mean(avg_results[subsection]["map"]))
                    aucs.append(np.mean(avg_results[subsection]["auc"]))

            results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs)
            results[f"{modality}2Text_{level}_map"] = np.mean(maps)
            results[f"{modality}2Text_{level}_auc"] = np.mean(aucs)

        return results
569
+
570
    def _text2protein(self, modality: str, protein_indices, text_indices: dict):
        """Text->protein retrieval evaluation, the mirror of _protein2text.

        For each distinct text, rank all proteins by similarity and score AP,
        MRR, and AUC against the set of proteins annotated with that text."""
        def do(process_id, idx, row, writer):
            # One evaluation unit: a text against all protein embeddings.
            subsection, text_id, label = row

            # Retrieve ranking results (exhaustive search over the index)
            t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1)
            sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal)
            sim_scores, rank_inds = sim_scores[0], rank_inds[0]

            # Calculate Average Precision(AP)
            ranks = []
            label = set(label)
            for i, rk in enumerate(rank_inds):
                # Find the rank of this label in all labels
                if rk in label:
                    ranks.append(i + 1)

            ranks = np.array(ranks)
            ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])

            # Calculate Mean Reciprocal Rank(MRR)
            # NOTE(review): ranks[0] raises IndexError if no annotated protein
            # appears in the index — presumably impossible by construction.
            best_rank = ranks[0]
            mrr = 1 / best_rank

            # Calculate the AUC (degenerate all-pos/all-neg cases score 0)
            true_labels = np.zeros_like(sim_scores)
            true_labels[ranks - 1] = 1
            if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
                auc = 0
            else:
                auc = roc_auc_score(true_labels, sim_scores)

            output = json.dumps([ap, mrr, auc])
            writer.write(output + "\n")

        # Invert uniprot2label: for each subsection, map text_id -> list of
        # protein row indices annotated with it.
        text2label = {}
        swissprot_subsections = set()
        for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()):
            # Only evaluate the texts in Swiss-Prot
            if uniprot_id not in self.swissprot_ids:
                continue

            for subsection, text_ids in subsections.items():
                # "seq"/"foldseek" hold sequences, not annotations.
                if subsection == "seq" or subsection == "foldseek":
                    continue

                swissprot_subsections.add(subsection)
                if subsection not in text2label:
                    text2label[subsection] = {}

                for text_id in text_ids:
                    text2label[subsection][text_id] = text2label[subsection].get(text_id, []) + [i]

        inputs = []
        for subsection in swissprot_subsections:
            for i, (text_id, label) in enumerate(text2label[subsection].items()):
                inputs.append((subsection, text_id, label))

        # Randomly shuffle the inputs (fixed seed: identical order on all ranks)
        random.seed(20000812)
        random.shuffle(inputs)

        # Split inputs into chunks for parallel processing
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        span = math.ceil(len(inputs) / world_size)
        sub_inputs = inputs[rank * span: (rank + 1) * span]

        # Display the progress bar on the rank 0 process
        verbose = self.trainer.local_rank == 0
        if verbose:
            print("Evaluating on each text...")

        # Add time stamp to the temporary file name to avoid conflicts
        # NOTE(review): hard-coded machine-specific scratch path — should come
        # from config or tempfile.
        tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                              return_results=True)
        outputs = mpr.run()
        os.remove(tmp_path)

        # Aggregate results from all ranks into one tensor.
        tensor_outputs = []
        for output in outputs:
            ap, mrr, auc = json.loads(output)
            tensor_outputs.append([float(ap), float(mrr), float(auc)])

        tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
        tensor_outputs = self.padded_gather(tensor_outputs)

        # Record results per subsection (rows align with `inputs` because all
        # ranks shuffled with the same seed).
        avg_results = {}
        for subsection in swissprot_subsections:
            avg_results[subsection] = {"map": [],
                                       "mrr": [],
                                       "auc": []}

        for input, output in zip(inputs, tensor_outputs):
            ap, mrr, auc = output
            subsection, _, _ = input

            avg_results[subsection]["map"].append(ap.cpu().item())
            avg_results[subsection]["mrr"].append(mrr.cpu().item())
            avg_results[subsection]["auc"].append(auc.cpu().item())

        results = {
            f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
            f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]),
            f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]),
        }

        # Average the precision and recall for each level
        for level, labels in [("residue-level", residue_level),
                              ("sequence-level", sequence_level),
                              ("all", residue_level | sequence_level)]:

            mrrs = []
            maps = []
            aucs = []
            for subsection in labels:
                if subsection in avg_results:
                    mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                    maps.append(np.mean(avg_results[subsection]["map"]))
                    aucs.append(np.mean(avg_results[subsection]["auc"]))

            results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs)
            results[f"Text2{modality}_{level}_map"] = np.mean(maps)
            results[f"Text2{modality}_{level}_auc"] = np.mean(aucs)

        return results
700
+
701
+ def retrieval_eval(self) -> dict:
702
+ # Get protein representations
703
+ protein_indices = self._get_protein_indices()
704
+
705
+ # Get structure representations
706
+ # if self.structure_config is not None:
707
+ # structure_embeddings = self._get_structure_embeddings()
708
+
709
+ # Get text representations
710
+ text_indices = self._get_text_indices()
711
+
712
+ # Retrieve texts for each protein
713
+ results = {}
714
+ results.update(self._protein2text("Sequence", protein_indices, text_indices))
715
+ # if self.structure_config is not None:
716
+ # results.update(self._protein2text("Structure", structure_embeddings, text_embeddings))
717
+ # results.update(self._text2protein("Structure", structure_embeddings, text_embeddings))
718
+
719
+ # Retrieve proteins for each text
720
+ results.update(self._text2protein("Sequence", protein_indices, text_indices))
721
+
722
+ return results
723
+
724
    def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
        """Apply BERT-style 80/10/10 masking to a token list.

        Returns (masked_tokens, labels) where labels has length
        len(tokens) + 2 to line up with the CLS/EOS tokens the tokenizer adds,
        and is -1 at every unmasked position. Retries until at least one token
        is masked."""
        while True:
            masked_tokens = copy.copy(tokens)
            # +2 accounts for the special tokens added at encode time.
            labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)
            # NOTE(review): vocab includes special tokens, which can therefore
            # be drawn as "random" replacements — confirm this is intended.
            vocab = [k for k in tokenizer.get_vocab().keys()]

            for i in range(len(tokens)):
                token = tokens[i]

                prob = random.random()
                if prob < mask_ratio:
                    # Rescale prob to [0, 1) for the 80/10/10 split below.
                    prob /= mask_ratio
                    labels[i + 1] = tokenizer.convert_tokens_to_ids(token)

                    if prob < 0.8:
                        # 80% random change to mask token
                        if self.use_saprot:
                            # SaProt tokens are aa+structure pairs: mask only
                            # the amino-acid half, keep the structure letter.
                            token = "#" + token[-1]
                        else:
                            token = tokenizer.mask_token
                    elif prob < 0.9:
                        # 10% chance to change to random token
                        token = random.choice(vocab)
                    else:
                        # 10% chance to keep current token
                        pass

                    masked_tokens[i] = token

            # Check if there is at least one masked token
            if (labels != -1).any():
                return masked_tokens, labels
756
+
757
+ def mlm_eval(self) -> float:
758
+ world_size = dist.get_world_size()
759
+ rank = dist.get_rank()
760
+
761
+ if self.use_saprot:
762
+ proteins = []
763
+ for sub_dict in self.uniprot2label.values():
764
+ aa_seq = sub_dict["seq"]
765
+ foldseek_seq = sub_dict["foldseek"]
766
+ assert len(aa_seq) == len(foldseek_seq)
767
+ seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
768
+ proteins.append(seq)
769
+
770
+ else:
771
+ proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]
772
+
773
+ span = math.ceil(len(proteins) / world_size)
774
+ sub_proteins = proteins[rank * span: (rank + 1) * span]
775
+
776
+ # Display the progress bar on the rank 0 process
777
+ if self.trainer.local_rank == 0:
778
+ iterator = tqdm(sub_proteins, desc="Computing mlm...")
779
+ else:
780
+ iterator = sub_proteins
781
+
782
+ total = torch.tensor([0], dtype=torch.long, device=self.device)
783
+ correct = torch.tensor([0], dtype=torch.long, device=self.device)
784
+ for seq in iterator:
785
+ tokens = self.protein_encoder.tokenizer.tokenize(seq)
786
+ masked_tokens, labels = self._apply_bert_mask(tokens, self.protein_encoder.tokenizer, 0.15)
787
+ seq = " ".join(masked_tokens)
788
+
789
+ inputs = self.protein_encoder.tokenizer(seq, return_tensors="pt")
790
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
791
+ _, logits = self.protein_encoder(inputs, get_mask_logits=True)
792
+
793
+ logits = logits.squeeze(0)
794
+ labels = labels.to(self.device)
795
+
796
+ selecor = labels != -1
797
+ preds = logits.argmax(dim=-1)[selecor]
798
+ labels = labels[selecor]
799
+
800
+ total += len(preds)
801
+ correct += (preds == labels).sum()
802
+
803
+ # Gather all results
804
+ total = self.padded_gather(total).sum()
805
+ correct = self.padded_gather(correct).sum()
806
+
807
+ acc = correct / total
808
+ return acc.cpu().item()
809
+
810
+ def _load_eval_data(self, stage):
811
+ # Load the data
812
+ lmdb_dir = eval(f"self.trainer.datamodule.{stage}_lmdb")
813
+ uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
814
+ label2text_path = os.path.join(lmdb_dir, "label2text.json")
815
+ swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")
816
+
817
+ self.uniprot2label = json.load(open(uniprot2label_path, "r"))
818
+ self.label2text = json.load(open(label2text_path, "r"))
819
+ self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist())
820
+ self.k = 3
821
+
822
+ def on_test_start(self):
823
+ self._load_eval_data("test")
824
+
825
+ log_dict = self.retrieval_eval()
826
+ log_dict = {"test_" + k: v for k, v in log_dict.items()}
827
+ if self.use_mlm_loss:
828
+ log_dict["test_mask_acc"] = self.mlm_eval()
829
+ self.log_info(log_dict)
830
+ print(log_dict)
831
+
832
+ def on_validation_start(self):
833
+ # Clear the cache
834
+ torch.cuda.empty_cache()
835
+
836
+ self._load_eval_data("valid")
837
+
838
+ log_dict = self.retrieval_eval()
839
+ log_dict = {"valid_" + k: v for k, v in log_dict.items()}
840
+ if self.use_mlm_loss:
841
+ log_dict["valid_mask_acc"] = self.mlm_eval()
842
+ self.log_info(log_dict)
843
+
844
+ self.check_save_condition(self.step, mode="max")
845
+
846
+ def test_step(self, batch, batch_idx):
847
+ return
848
+
849
+ def validation_step(self, batch, batch_idx):
850
+ return
851
+
852
+ def on_train_epoch_end(self):
853
+ super().on_train_epoch_end()
854
+ # Re-sample the subset of the training data
855
+ if self.trainer.datamodule.train_dataset.fixed_dataset_num is not None:
856
+ self.trainer.datamodule.train_dataset.sample_subset()
857
+
858
+ # def test_epoch_end(self, outputs):
859
+ # log_dict = self.get_log_dict("test")
860
+ # log_dict["test_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
861
+ #
862
+ # print(log_dict)
863
+ # self.log_info(log_dict)
864
+ #
865
+ # self.reset_metrics("test")
866
+ #
867
+ # def validation_epoch_end(self, outputs):
868
+ # log_dict = self.get_log_dict("valid")
869
+ # log_dict["valid_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
870
+ #
871
+ # self.log_info(log_dict)
872
+ # self.reset_metrics("valid")
873
+ # self.check_save_condition(log_dict["valid_loss"], mode="min")
874
+
model/ProTrek/structure_encoder.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from tqdm import tqdm
4
+ from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
5
+ from torch.nn.functional import normalize
6
+
7
+
8
class StructureEncoder(torch.nn.Module):
    """ESM-based encoder mapping protein structural sequences to fixed-size embeddings."""

    def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim: Output dimension of the structure representation

            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Enable/disable gradient checkpointing on the ESM encoder
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # The contact head is not needed for representation learning
        self.model.esm.contact_head = None

        # Rotary embeddings make learned position embeddings redundant
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute L2-normalized structure representations for the given proteins.

        Args:
            proteins: A list of protein structural sequences
            batch_size: Batch size for inference
            verbose: Whether to show a progress bar
        """
        device = next(self.parameters()).device

        batch_starts = range(0, len(proteins), batch_size)
        if verbose:
            batch_starts = tqdm(batch_starts, desc="Computing protein embeddings")

        chunks = []
        for start in batch_starts:
            batch = self.tokenizer.batch_encode_plus(proteins[start: start + batch_size],
                                                     return_tensors="pt",
                                                     padding=True)
            batch = {k: v.to(device) for k, v in batch.items()}
            repr_batch, _ = self.forward(batch)
            chunks.append(repr_batch)

        return normalize(torch.cat(chunks, dim=0), dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein structure into a protein representation.

        Args:
            inputs: A dictionary containing the following keys:
                 - input_ids: [batch, seq_len]
                 - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            reprs: [batch, out_dim]
            mask_logits: [batch, seq_len, vocab_size] or None
        """
        hidden = self.model.esm(**inputs).last_hidden_state
        # Use the first (CLS) token as the sequence-level representation
        reprs = self.out(hidden[:, 0, :])

        mask_logits = self.model.lm_head(hidden) if get_mask_logits else None
        return reprs, mask_logits
model/ProTrek/text_encoder.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from tqdm import tqdm
4
+ from torch.nn.functional import normalize
5
+ from transformers import BertConfig, BertModel, BertTokenizer
6
+
7
+
8
class TextEncoder(torch.nn.Module):
    """BERT-based encoder mapping text descriptions to fixed-size embeddings."""

    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim: Output dimension of the text representation

            load_pretrained: Whether to load pretrained weights

            gradient_checkpointing: Whether to enable gradient checkpointing
        """
        super().__init__()
        config = BertConfig.from_pretrained(config_path)
        # The pooling layer is unused: the CLS hidden state is projected directly
        if load_pretrained:
            self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
        else:
            self.model = BertModel(config, add_pooling_layer=False)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Enable/disable gradient checkpointing on the BERT encoder
        self.model.encoder.gradient_checkpointing = gradient_checkpointing

        self.tokenizer = BertTokenizer.from_pretrained(config_path)

    def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute L2-normalized text representations for the given texts.

        Args:
            texts: A list of strings
            batch_size: Batch size for inference
            verbose: Whether to show a progress bar
        """
        device = next(self.parameters()).device

        batch_starts = range(0, len(texts), batch_size)
        if verbose:
            batch_starts = tqdm(batch_starts, desc="Computing text embeddings")

        chunks = []
        for start in batch_starts:
            batch = self.tokenizer.batch_encode_plus(texts[start: start + batch_size],
                                                     return_tensors="pt",
                                                     truncation=True,
                                                     max_length=512,
                                                     padding=True)
            batch = {k: v.to(device) for k, v in batch.items()}
            chunks.append(self(batch))

        return normalize(torch.cat(chunks, dim=0), dim=-1)

    def forward(self, inputs: dict):
        """
        Encode text into a text representation.

        Args:
            inputs: A dictionary containing the following keys:
                 - input_ids: [batch, seq_len]
                 - attention_mask: [batch, seq_len]
                 - token_type_ids: [batch, seq_len]

        Returns:
            text_repr: [batch, out_dim]
        """
        # Project the first (CLS) token hidden state into the shared space
        cls_state = self.model(**inputs).last_hidden_state[:, 0, :]
        return self.out(cls_state)
model/abstract_model.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import abc
3
+ import os
4
+ import copy
5
+
6
+ import pytorch_lightning as pl
7
+ from utils.lr_scheduler import *
8
+ from torch import distributed as dist
9
+
10
+
11
class AbstractModel(pl.LightningModule):
    """
    Base LightningModule for all models in this project.

    Subclasses must implement ``initialize_model``, ``forward``,
    ``initialize_metrics`` and ``loss_func``. This base class supplies
    optimizer/scheduler construction, checkpoint save/load (with or
    without DeepSpeed), manual step/epoch bookkeeping and metric logging.
    """

    def __init__(self,
                 lr_scheduler_kwargs: dict = None,
                 optimizer_kwargs: dict = None,
                 save_path: str = None,
                 from_checkpoint: str = None,
                 load_prev_scheduler: bool = False,
                 save_weights_only: bool = True,):
        """

        Args:
            lr_scheduler_kwargs: Kwargs for the learning-rate scheduler (expects keys "class" and "init_lr")
            optimizer_kwargs: Kwargs for the optimizer (expects keys "class" and "weight_decay")
            save_path: Where to save the trained model
            from_checkpoint: Load the model from this checkpoint path
            load_prev_scheduler: Whether to restore scheduler/optimizer state from the checkpoint
            save_weights_only: Whether to save only weights (vs. also optimizer and lr_scheduler state)

        """
        super().__init__()
        self.initialize_model()

        self.metrics = {}
        for stage in ["train", "valid", "test"]:
            stage_metrics = self.initialize_metrics(stage)
            # Register metrics as attributes so Lightning moves them to the right device
            for metric_name, metric in stage_metrics.items():
                setattr(self, metric_name, metric)

            self.metrics[stage] = stage_metrics

        if lr_scheduler_kwargs is None:
            # Default lr_scheduler: constant learning rate of 0
            self.lr_scheduler_kwargs = {
                "class": "ConstantLRScheduler",
                "init_lr": 0,
            }
            print("No lr_scheduler_kwargs provided. The default learning rate is 0.")

        else:
            self.lr_scheduler_kwargs = lr_scheduler_kwargs

        if optimizer_kwargs is None:
            # Default optimizer: AdamW with standard transformer betas
            self.optimizer_kwargs = {
                "class": "AdamW",
                "betas": (0.9, 0.98),
                "weight_decay": 0.01,
            }
            print("No optimizer_kwargs provided. The default optimizer is AdamW.")
        else:
            self.optimizer_kwargs = optimizer_kwargs
        self.init_optimizers()

        self.save_path = save_path
        self.save_weights_only = save_weights_only

        # temp_step is used for accumulating gradients: self.step only
        # advances once per accumulate_grad_batches optimizer steps
        self.temp_step = 0
        self.step = 0
        self.epoch = 0

        self.load_prev_scheduler = load_prev_scheduler
        self.from_checkpoint = from_checkpoint
        if from_checkpoint:
            self.load_checkpoint(from_checkpoint)

    @abc.abstractmethod
    def initialize_model(self) -> None:
        """
        All model initialization should be done here
        Note that the whole model must be named as "self.model" for model saving and loading
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """
        Forward propagation
        """
        raise NotImplementedError

    @abc.abstractmethod
    def initialize_metrics(self, stage: str) -> dict:
        """
        Initialize metrics for each stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric objects
        """
        raise NotImplementedError

    @abc.abstractmethod
    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        """

        Args:
            stage: "train", "valid" or "test"
            outputs: model outputs for calculating loss
            labels: labels for calculating loss

        Returns:
            loss

        """
        raise NotImplementedError

    @staticmethod
    def load_weights(model, weights):
        """
        Non-strict weight loading: copies matching keys into the model's
        state dict and prints (in red) any missed or unused parameters.
        """
        model_dict = model.state_dict()

        unused_params = []
        missed_params = list(model_dict.keys())

        for k, v in weights.items():
            if k in model_dict.keys():
                model_dict[k] = v
                missed_params.remove(k)

            else:
                unused_params.append(k)

        if len(missed_params) > 0:
            print(f"\033[31mSome weights of {type(model).__name__} were not "
                  f"initialized from the model checkpoint: {missed_params}\033[0m")

        if len(unused_params) > 0:
            print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")

        model.load_state_dict(model_dict)

    # NOTE(review): this signature targets a recent pytorch-lightning
    # release; a 1.9.5-compatible variant is kept commented out below.
    def optimizer_step(
            self,
            epoch: int,
            batch_idx: int,
            optimizer,
            optimizer_closure=None,
    ) -> None:
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

        # Advance the logical training step once per accumulation window
        self.temp_step += 1
        if self.temp_step == self.trainer.accumulate_grad_batches:
            self.step += 1
            self.temp_step = 0

    # For pytorch-lightning 1.9.5
    # def optimizer_step(
    #         self,
    #         epoch: int,
    #         batch_idx: int,
    #         optimizer,
    #         optimizer_idx: int = 0,
    #         optimizer_closure=None,
    #         on_tpu: bool = False,
    #         using_native_amp: bool = False,
    #         using_lbfgs: bool = False,
    # ) -> None:
    #     super().optimizer_step(
    #         epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
    #     )
    #     self.temp_step += 1
    #     if self.temp_step == self.trainer.accumulate_grad_batches:
    #         self.step += 1
    #         self.temp_step = 0

    def on_train_epoch_end(self):
        # Manual epoch counter, used for logging and checkpoint metadata
        self.epoch += 1

    def training_step(self, batch, batch_idx):
        """Standard training step: forward, compute loss, log it."""
        inputs, labels = batch

        # optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-4, weight_decay=0.01, betas=(0.9, 0.98))
        # for _ in range(1000):
        #     outputs = self(**inputs)
        #     loss = self.loss_func('train', outputs, labels)
        #     loss.backward()
        #     optimizer.step()
        #     optimizer.zero_grad()
        #
        # raise

        outputs = self(**inputs)
        loss = self.loss_func('train', outputs, labels)

        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: collect per-batch loss into self.valid_outputs."""
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('valid', outputs, labels)
        self.valid_outputs.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Test step: collect per-batch loss into self.test_outputs."""
        inputs, labels = batch
        outputs = self(**inputs)

        loss = self.loss_func('test', outputs, labels)
        self.test_outputs.append(loss)
        return loss

    def on_train_start(self) -> None:
        # Load previous scheduler (the attribute name contains a historic
        # typo, "prev_schechuler"; it is set by load_checkpoint)
        if getattr(self, "prev_schechuler", None) is not None:
            try:
                self.step = self.prev_schechuler["global_step"]
                self.epoch = self.prev_schechuler["epoch"]
                self.best_value = self.prev_schechuler["best_value"]
                self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
                print(f"Previous training global step: {self.step}")
                print(f"Previous training epoch: {self.epoch}")
                print(f"Previous best value: {self.best_value}")
                print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")

                # Load optimizer state
                if hasattr(self.trainer.strategy, "deepspeed_engine"):
                    # For DeepSpeed strategy: the engine restores optimizer state itself
                    try:
                        self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
                    except Exception as e:
                        print(e)

                else:
                    # For DDP strategy
                    self.optimizer.load_state_dict(self.prev_schechuler["optimizer"])

            except Exception as e:
                print(e)
                raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False")

    def on_validation_epoch_start(self) -> None:
        # Fresh container for per-batch validation losses
        setattr(self, "valid_outputs", [])

    def on_test_epoch_start(self) -> None:
        # Fresh container for per-batch test losses
        setattr(self, "test_outputs", [])

    def load_checkpoint(self, from_checkpoint: str) -> None:
        """
        Args:
            from_checkpoint: Path to checkpoint.
        """

        # If ``from_checkpoint`` is a directory, load the checkpoint in it
        # (expected layout: <dir>/<dirname>.pt, as written by check_save_condition)
        if os.path.isdir(from_checkpoint):
            basename = os.path.basename(from_checkpoint)
            from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

        state_dict = torch.load(from_checkpoint, map_location=self.device)
        self.load_weights(self.model, state_dict["model"])

        if self.load_prev_scheduler:
            # Everything except the weights (step/epoch/optimizer/scheduler)
            # is consumed later by on_train_start
            state_dict.pop("model")
            self.prev_schechuler = state_dict

    def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
        """
        Save model to save_path
        Args:
            save_path: Path to save model
            save_info: Other info to save
            save_weights_only: Whether only save model weights
        """
        dir = os.path.dirname(save_path)
        os.makedirs(dir, exist_ok=True)

        state_dict = {} if save_info is None else save_info
        state_dict["model"] = self.model.state_dict()

        # Convert model weights to fp32 (e.g. after mixed-precision training)
        for k, v in state_dict["model"].items():
            state_dict["model"][k] = v.float()

        if not save_weights_only:
            state_dict["global_step"] = self.step
            state_dict["epoch"] = self.epoch
            state_dict["best_value"] = getattr(self, f"best_value", None)
            state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()

            # If not using DeepSpeed, save optimizer state
            # (DeepSpeed saves its own sharded optimizer state separately)
            if not hasattr(self.trainer.strategy, "deepspeed_engine"):
                state_dict["optimizer"] = self.optimizers().optimizer.state_dict()

        torch.save(state_dict, save_path)

    def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
        """
        Check whether to save model. If save_path is not None and now_value is the best, save model.
        Args:
            now_value: Current metric value
            mode: "min" or "max", meaning whether the lower the better or the higher the better
            save_info: Other info to save
        """

        assert mode in ["min", "max"], "mode should be 'min' or 'max'"

        if self.save_path is not None:
            # In case there are variables to be included in the save path
            # NOTE(review): the path is evaluated as an f-string (it may
            # reference e.g. {self.step}); eval on an untrusted save_path is
            # unsafe — confirm paths come only from trusted config.
            save_path = eval(f"f'{self.save_path}'")

            dir = os.path.dirname(save_path)
            os.makedirs(dir, exist_ok=True)

            # Check whether to save model: skip if the current value does not
            # improve on the best value seen so far
            best_value = getattr(self, f"best_value", None)
            if best_value is not None:
                if mode == "min" and now_value >= best_value or mode == "max" and now_value <= best_value:
                    return

            setattr(self, "best_value", now_value)

            # For DeepSpeed strategy
            if hasattr(self.trainer.strategy, "deepspeed_engine"):
                if not self.save_weights_only:
                    self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")

                # Save a complete checkpoint (rank 0 only, inside save_path)
                if dist.get_rank() == 0:
                    basename = os.path.basename(save_path)
                    ckpt_path = os.path.join(save_path, f"{basename}.pt")
                    self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)

            # For normal situation
            else:
                if dist.get_rank() == 0:
                    self.save_checkpoint(save_path, save_info, self.save_weights_only)

    def reset_metrics(self, stage) -> None:
        """
        Reset metrics for given stage
        Args:
            stage: "train", "valid" or "test"
        """
        for metric in self.metrics[stage].values():
            metric.reset()

    def get_log_dict(self, stage: str) -> dict:
        """
        Get log dict for the stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric values

        """
        return {name: metric.compute() for name, metric in self.metrics[stage].items()}

    def log_info(self, info: dict) -> None:
        """
        Record metrics during training and testing (rank 0 only)
        Args:
            info: dict of metrics
        """
        if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
            info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            info["epoch"] = self.epoch
            self.logger.log_metrics(info, step=self.step)

    def init_optimizers(self):
        """Build self.optimizer and self.lr_scheduler from the kwargs dicts."""
        copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)

        # No decay for layer norm and bias
        no_decay = ['LayerNorm.weight', 'bias']
        weight_decay = copy_optimizer_kwargs.pop("weight_decay")

        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        # Resolve the optimizer class by name from torch.optim
        # NOTE(review): eval-based resolution; the class names come from
        # config — confirm configs are trusted.
        optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
        self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                       lr=self.lr_scheduler_kwargs['init_lr'],
                                       **copy_optimizer_kwargs)

        # The scheduler class is resolved by name, presumably from the
        # star-imported utils.lr_scheduler module
        tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
        lr_scheduler = tmp_kwargs.pop("class")
        self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)

    def configure_optimizers(self):
        # Step the scheduler every optimizer step (not every epoch)
        return {"optimizer": self.optimizer,
                "lr_scheduler": {"scheduler": self.lr_scheduler,
                                 "interval": "step",
                                 "frequency": 1}
                }
model/model_interface.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import glob
4
+
5
+
6
+ # register all available models through *_model.py files
7
+ # def construct_model():
8
+ # model_dir = os.path.dirname(__file__)
9
+ #
10
+ # # lists all model files
11
+ # model_list = []
12
+ # for root, _, names in os.walk(model_dir):
13
+ # for name in names:
14
+ # if name.endswith('_model.py'):
15
+ # sub_dirs = root.replace(model_dir, '').split(os.sep)
16
+ # model_list.append((sub_dirs, name[:-3]))
17
+ #
18
+ # # load model_config.yaml, controlling which models to be loaded
19
+ # model_config = yaml.safe_load(open(f"{model_dir}/model_config.yaml", "r"))
20
+ #
21
+ # if model_config["verbose"]:
22
+ # print("*" * 30 + f" Loading model " + "*" * 30)
23
+ #
24
+ # # register models
25
+ # for sub_dirs, name in model_list:
26
+ # if name in model_config["models"]:
27
+ # if len(sub_dirs) > 1:
28
+ # cmd = f"from {'.'.join(sub_dirs)} import {name}"
29
+ # else:
30
+ # cmd = f"from . import {name}"
31
+ #
32
+ # exec(cmd)
33
+ #
34
+ # if model_config["verbose"]:
35
+ # info = f"Loaded model: {name}"
36
+ # print(f"\033[32m{info}\033[0m")
37
+ # else:
38
+ # if model_config["verbose"]:
39
+ # info = f"Skipped model: {name}"
40
+ # print(f"\033[31m{info}\033[0m")
41
+ #
42
+ # if model_config["verbose"]:
43
+ # print("*" * 75)
44
+ #
45
+ #
46
+ # # register function as a wrapper for all models
47
+ # def register_model(cls):
48
+ # model_dict[cls.__name__] = cls
49
+ # return cls
50
+ #
51
+ #
52
+ # model_dict = {}
53
+ # construct_model()
54
+ #
55
+ #
56
+ # class ModelInterface:
57
+ # @classmethod
58
+ # def get_available_models(cls):
59
+ # return model_dict.keys()
60
+ #
61
+ # @classmethod
62
+ # def init_model(cls, model: str, **kwargs):
63
+ # """
64
+ #
65
+ # Args:
66
+ # model : Class name of model you want to use. Must be in model_dict.keys()
67
+ # **kwargs: Kwargs for model initialization
68
+ #
69
+ # Returns: Corresponding model
70
+ #
71
+ # """
72
+ # assert model in model_dict.keys(), f"class {model} doesn't exist!"
73
+ # return model_dict[model](**kwargs)
74
+
75
+
76
+ ########################################################################
77
+ # Version 2 #
78
+ ########################################################################
79
+ # register function as a wrapper for all models
80
def register_model(cls):
    """Class decorator: record *cls* as the most recently defined model class."""
    global now_cls
    now_cls = cls
    return cls
84
+
85
+
86
# Module-level handle pointing at the most recently registered model class
# (written by the ``register_model`` decorator, read by ModelInterface).
now_cls = None
87
+
88
+
89
class ModelInterface:
    """Factory that instantiates a model class given the path of its .py file."""

    @classmethod
    def init_model(cls, model_py_path: str, **kwargs):
        """

        Args:
            model_py_path: Py file Path of model you want to use.
            **kwargs: Kwargs for model initialization

        Returns: Corresponding model
        """
        # Importing the model module triggers its @register_model decorator,
        # which stores the model class in the module-level ``now_cls``.
        # NOTE(review): the path is split on os.sep and exec'd as a relative
        # import — assumes a platform-native, package-relative path; confirm
        # callers never pass absolute or mixed-separator paths.
        sub_dirs = model_py_path.split(os.sep)
        cmd = f"from {'.' + '.'.join(sub_dirs[:-1])} import {sub_dirs[-1]}"
        exec(cmd)

        # ``now_cls`` was just set as a side effect of the import above
        return now_cls(**kwargs)
utils/constants.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+
3
+
4
# The 20 standard amino acids (one-letter codes, alphabetical order)
aa_list = list("ACDEFGHIKLMNPQRSTVWY")
aa_set = set(aa_list)

# Foldseek vocabularies: amino-acid characters and 3Di structure characters,
# each ending with "#" as the special placeholder/mask symbol
foldseek_seq_vocab = "ACDEFGHIKLMNPQRSTVWY#"
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

# Alphabet used to compose structure-token strings
struc_unit = "abcdefghijklmnopqrstuvwxyz"


def create_vocab(size: int) -> dict:
    """
    Build a structure-token vocabulary of ``size`` tokens plus a trailing
    "#" token.

    Tokens are fixed-length strings over ``struc_unit``; the length is the
    smallest one whose combinations can cover ``size`` tokens.

    Args:
        size: Number of regular tokens in the vocabulary

    Returns:
        Mapping of index -> token string, with index ``size`` mapped to "#".
    """
    # Smallest token length with len(struc_unit) ** token_len >= size
    token_len = 1
    while len(struc_unit) ** token_len < size:
        token_len += 1

    vocab = {}
    for idx, chars in enumerate(itertools.product(struc_unit, repeat=token_len)):
        vocab[idx] = "".join(chars)
        if len(vocab) == size:
            # Append the special "#" token right after the last regular one
            vocab[idx + 1] = "#"
            return vocab
33
+
34
# ProTrek
# UniProt annotation subsections that describe specific residues or regions
# of a sequence (positional annotations)
residue_level = {"Active site", "Binding site", "Site", "DNA binding", "Natural variant", "Mutagenesis",
                 "Transmembrane", "Topological domain", "Intramembrane", "Signal peptide", "Propeptide",
                 "Transit peptide",
                 "Chain", "Peptide", "Modified residue", "Lipidation", "Glycosylation", "Disulfide bond",
                 "Cross-link",
                 "Domain", "Repeat", "Compositional bias", "Region", "Coiled coil", "Motif"}

# UniProt annotation subsections that describe the protein as a whole
# (non-positional annotations)
sequence_level = {"Function", "Miscellaneous", "Caution", "Catalytic activity", "Cofactor", "Activity regulation",
                  "Biophysicochemical properties", "Pathway", "Involvement in disease", "Allergenic properties",
                  "Toxic dose", "Pharmaceutical use", "Disruption phenotype", "Subcellular location",
                  "Post-translational modification", "Subunit", "Domain (non-positional annotation)",
                  "Sequence similarities", "RNA Editing", "Tissue specificity", "Developmental stage", "Induction",
                  "Biotechnology", "Polymorphism", "GO annotation", "Proteomes", "Protein names", "Gene names",
                  "Organism", "Taxonomic lineage", "Virus host"}

# Subset of subsections whose values are free-form descriptive text
raw_text_level = {"Function", "Subunit", "Tissue specificity", "Disruption phenotype", "Post-translational modification",
                  "Induction", "Miscellaneous", "Sequence similarities", "Developmental stage",
                  "Domain (non-positional annotation)", "Activity regulation", "Caution", "Polymorphism", "Toxic dose",
                  "Allergenic properties", "Pharmaceutical use", "Cofactor", "Biophysicochemical properties",
                  "Subcellular location", "RNA Editing"}
utils/downloader.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ from utils.mpr import MultipleProcessRunner
5
+ from tqdm import tqdm
6
+
7
+
8
class Downloader(MultipleProcessRunner):
    """
    Download files that have a unified resource locator.

    Each item of ``data`` is an identifier substituted into both ``base_url``
    and ``save_path`` via ``str.format``.
    """

    def __init__(self, base_url, save_path, overwrite=False, skip_error_info=False, **kwargs):
        """

        Args:
            base_url: URL template with one '{}' placeholder for the identifier
            save_path: saving-path template with one '{}' placeholder
            overwrite: whether overwrite existing files
            skip_error_info: if True, suppress per-file download error messages
        """
        super().__init__(**kwargs)

        self.base_url = base_url
        self.save_path = save_path
        self.overwrite = overwrite
        self.skip_error_info = skip_error_info

        if not overwrite:
            # remove existing files in data
            self.data = [uniprot for uniprot in tqdm(self.data, desc="Filtering out existing files...")
                         if not os.path.exists(self.save_path.format(uniprot))]

    def _aggregate(self, final_path: str, sub_paths):
        # Files are written directly to their final location; nothing to merge.
        pass

    # FIX: renamed from `_target_static` so it actually overrides the abstract
    # `_target` hook that MultipleProcessRunner.run() invokes; with the old
    # name, run() fell through to the abstract method and raised
    # NotImplementedError.
    def _target(self, process_id, data, sub_path, *args):
        for i, uniprot in enumerate(data):
            url = self.base_url.format(uniprot)
            save_path = self.save_path.format(uniprot)

            # shell cmd to download files; on failure remove the (possibly
            # partial) output file and report the broken URL
            wget = f"wget -q -o /dev/null {url} -O {save_path}"

            rm = f"rm {save_path}"
            err = f"echo 'Error: {url} cannot be downloaded!'"
            if self.skip_error_info:
                err += ">/dev/null"

            os.system(f"{wget} || ({rm} && {err})")

            self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} Downloading files...")

    # Backward-compatible alias for the previous hook name.
    _target_static = _target

    def run(self):
        """
        Run this function to download files
        """
        super().run()

    def __len__(self):
        return len(self.data)

    @staticmethod
    def clear_empty_files(path):
        """
        Remove zero-byte files (failed downloads) in ``path``.

        Returns:
            cnt: number of files removed
        """
        cnt = 0
        for file in tqdm(os.listdir(path), desc="Clearing empty files..."):
            if os.path.getsize(os.path.join(path, file)) == 0:
                os.remove(os.path.join(path, file))
                cnt += 1
        print(f"Removed {cnt} empty files")
        return cnt
72
+
73
+
74
class AlphaDBDownloader(Downloader):
    """
    Download predicted structure files from the AlphaFold2 database.
    """
    def __init__(self, uniprot_ids, type: str, save_dir: str, **kwargs):
        """

        Args:
            uniprot_ids: Uniprot ids
            type: Which type of files to download. Must be one of ['pdb', 'mmcif', 'plddt', "pae"]
            save_dir: Saving directory
            **kwargs:
        """

        # (url template, local file-name template) per download type
        templates = {
            "pdb": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.pdb", "{}.pdb"),
            "mmcif": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.cif", "{}.cif"),
            "plddt": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-confidence_v4.json", "{}.json"),
            "pae": ("https://alphafold.ebi.ac.uk/files/AF-{}-F1-predicted_aligned_error_v4.json", "{}.json"),
        }
        base_url, file_name = templates[type]

        super().__init__(data=uniprot_ids,
                         base_url=base_url,
                         save_path=os.path.join(save_dir, file_name),
                         **kwargs)
105
+
106
+
107
class PDBDownloader(Downloader):
    """
    Download structure files from PDB.
    """
    def __init__(self, pdb_ids, type: str, save_dir: str, **kwargs):
        """

        Args:
            pdb_ids: PDB ids
            type: Which type of files to download. Must be one of ['pdb', 'mmcif']
            save_dir: Saving directory
        """

        # (url template, local file-name template) per download type
        templates = {
            "pdb": ("https://files.rcsb.org/download/{}.pdb", "{}.pdb"),
            "mmcif": ("https://files.rcsb.org/download/{}.cif", "{}.cif"),
        }
        base_url, file_name = templates[type]

        super().__init__(data=pdb_ids,
                         base_url=base_url,
                         save_path=os.path.join(save_dir, file_name),
                         **kwargs)
134
+
135
+
136
class CATHDownloader(Downloader):
    def __init__(self, cath_ids, save_dir, **kwargs):
        """
        Download domain structure files from CATH.
        Args:
            cath_ids: CATH ids
            save_dir: Saving directory
        """

        super().__init__(
            data=cath_ids,
            base_url="http://www.cathdb.info/version/v4_3_0/api/rest/id/{}.pdb",
            save_path=os.path.join(save_dir, "{}.pdb"),
            **kwargs,
        )
149
+
150
+
151
def download_pdb(pdb_id: str, format: str, save_path: str):
    """
    Download a single structure file from PDB.
    Args:
        pdb_id: PDB id
        format: File format, must be one of ['pdb', 'cif']
        save_path: Saving path
    """

    url = f"https://files.rcsb.org/download/{pdb_id}.{format}"
    # Fetch quietly; if wget fails, drop the partial file and report the URL.
    fetch_cmd = f"wget -q -o /dev/null {url} -O {save_path}"
    cleanup_cmd = f"rm {save_path}"
    report_cmd = f"echo 'Error: {url} cannot be downloaded!'"
    os.system(f"{fetch_cmd} || ({cleanup_cmd} && {report_cmd})")
165
+
166
+
167
def download_af2(uniprot_id: str, format: str, save_path: str):
    """
    Download a single file from the AlphaFold2 database.
    Args:
        uniprot_id: Uniprot id
        format: File format, must be one of ['pdb', 'cif', 'plddt', 'pae']
        save_path: Saving path
    """

    # FIX: the format was previously appended directly to "...model_v4.",
    # which produced invalid URLs for 'plddt' and 'pae' (those live under
    # confidence_v4.json / predicted_aligned_error_v4.json — cf. the url_dict
    # in AlphaDBDownloader). 'pdb' and 'cif' URLs are unchanged.
    file_map = {
        "pdb": "model_v4.pdb",
        "cif": "model_v4.cif",
        "plddt": "confidence_v4.json",
        "pae": "predicted_aligned_error_v4.json",
    }
    url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-{file_map[format]}"
    wget = f"wget -q -o /dev/null {url} -O {save_path}"
    rm = f"rm {save_path}"
    err = f"echo 'Error: {url} cannot be downloaded!'"
    os.system(f"{wget} || ({rm} && {err})")
utils/foldseek_util.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import numpy as np
5
+ import re
6
+ import sys
7
+ sys.path.append(".")
8
+
9
+
10
+ # Get structural seqs from pdb file
11
# Get structural seqs from pdb file
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = False,
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Extract per-chain structural sequences from a structure file with foldseek.

    Args:
        foldseek: Binary executable file of foldseek

        path: Path to pdb file

        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.

        process_id: Process ID for temporary files. This is used for parallel processing.

        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.

        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
                  (seq, struc_seq, combined_seq).
    """
    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    # Unique temp name so concurrent processes don't clobber each other
    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"
    if foldseek_verbose:
        cmd = f"{foldseek} structureto3didescriptor --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
    else:
        cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
    os.system(cmd)

    # FIX: plddt extraction is loop-invariant; it was previously re-parsed
    # from the pdb file for every output line.
    plddts = extract_plddt(path) if plddt_mask else None

    seq_dict = {}
    name = os.path.basename(path)
    try:
        with open(tmp_save_path, "r") as r:
            for i, line in enumerate(r):
                desc, seq, struc_seq = line.split("\t")[:3]

                # Mask low plddt
                if plddt_mask:
                    # NOTE(review): plddts covers the whole file, so this only
                    # holds for single-chain structures (e.g. AlphaFold models)
                    assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                    # Mask regions with plddt < threshold
                    indices = np.where(plddts < plddt_threshold)[0]
                    np_seq = np.array(list(struc_seq))
                    np_seq[indices] = "#"
                    struc_seq = "".join(np_seq)

                # foldseek prefixes the chain with the file name, e.g. "file.pdb_A"
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    if chain not in seq_dict:
                        # Interleave amino acid and lower-cased 3Di letters
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Always clean up foldseek's temporary output, even if parsing fails
        if os.path.exists(tmp_save_path):
            os.remove(tmp_save_path)
        if os.path.exists(tmp_save_path + ".dbtype"):
            os.remove(tmp_save_path + ".dbtype")

    return seq_dict
77
+
78
+
79
def extract_plddt(pdb_path: str) -> np.ndarray:
    """
    Extract plddt scores from pdb file.

    Per-residue plddt is the mean of the values found in the second-to-last
    whitespace field of each ATOM record (the B-factor column in
    AlphaFold-style files). NOTE(review): this whitespace-based parsing is
    tailored to AlphaFold output — confirm before using on other PDB sources,
    where fixed-width columns can fuse differently.

    Args:
        pdb_path: Path to pdb file.

    Returns:
        plddts: plddt scores.
    """
    with open(pdb_path, "r") as r:
        # residue position -> list of per-atom plddt values
        plddt_dict = {}
        for line in r:
            # Collapse runs of spaces so records can be split on single blanks
            line = re.sub(' +', ' ', line).strip()
            splits = line.split(" ")

            if splits[0] == "ATOM":
                # If position < 1000
                if len(splits[4]) == 1:
                    pos = int(splits[5])

                # If position >= 1000, the blank will be removed, e.g. "A 999" -> "A1000"
                # So the length of splits[4] is not 1
                else:
                    pos = int(splits[4][1:])

                # assumes plddt sits in the second-to-last field — TODO confirm
                plddt = float(splits[-2])

                if pos not in plddt_dict:
                    plddt_dict[pos] = [plddt]
                else:
                    plddt_dict[pos].append(plddt)

        # Average atom-level values into one score per residue position
        plddts = np.array([np.mean(v) for v in plddt_dict.values()])
        return plddts
113
+
114
+
115
if __name__ == '__main__':
    # Smoke test with hard-coded local paths.
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # FIX: get_struc_seq has no `plddt_path` parameter (plddt values are read
    # from the pdb file itself), so the old call raised a TypeError. Enable
    # masking via `plddt_mask` instead.
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())
utils/lr_scheduler.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
4
+
5
+
6
class ConstantLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 ):
        """
        Scheduler that keeps the learning rate fixed at ``init_lr``.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate
        """

        self.init_lr = init_lr
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        # Exclude the optimizer reference so the state is serializable
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        # Same constant rate for every param group
        return [self.init_lr] * len(self.optimizer.param_groups)
43
+
44
+
45
class CosineAnnealingLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 cosine_steps: int = 10000,
                 ):
        """
        This is an implementation of cosine annealing learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            cosine_steps: Number of steps for cosine annealing
        """

        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        super(CosineAnnealingLRScheduler, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        # Exclude the optimizer reference so the state is serializable
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        # Strict '<' avoids a 0/0 when warmup_steps == 0; at
        # step_no == warmup_steps the cosine branch also yields exactly max_lr
        # (cos(0) == 1), so boundary values are unchanged.
        if step_no < self.warmup_steps:
            # Linear warmup from init_lr to max_lr
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        else:
            # Cosine annealing from max_lr down to final_lr.
            # FIX: clamp progress at 1.0 — previously the cosine kept evolving
            # past `cosine_steps`, so the lr climbed back towards max_lr
            # instead of settling at final_lr.
            progress = min((step_no - self.warmup_steps) / self.cosine_steps, 1.0)
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
                 * (1 + math.cos(math.pi * progress))

        return [lr for group in self.optimizer.param_groups]
107
+
108
+
109
class Esm2LRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 start_decay_after_n_steps: int = 500000,
                 end_decay_after_n_steps: int = 5000000,
                 on_use: bool = True,
                 ):
        """
        ESM2's learning rate schedule: linear warmup to ``max_lr``, a plateau,
        then linear decay to ``final_lr``.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            start_decay_after_n_steps: Start decay after this number of steps

            end_decay_after_n_steps: End decay after this number of steps

            on_use: Whether to use this scheduler. If ``False``, the scheduler will not change the learning rate
                    and will only use the ``init_lr``. Default: ``True``
        """

        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        super(Esm2LRScheduler, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        # Exclude the optimizer reference so the state is serializable
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        # Disabled: fall back to the optimizer's base learning rates
        if not self.on_use:
            return list(self.base_lrs)

        step_no = self.last_epoch
        n_groups = len(self.optimizer.param_groups)

        # Phase 1: linear warmup from init_lr to max_lr
        if step_no <= self.warmup_steps:
            rate = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
            return [rate] * n_groups

        # Phase 2: plateau at max_lr
        if step_no <= self.start_decay_after_n_steps:
            return [self.max_lr] * n_groups

        # Phase 3: linear decay towards final_lr
        if step_no <= self.end_decay_after_n_steps:
            decay_span = self.end_decay_after_n_steps - self.start_decay_after_n_steps
            portion = (step_no - self.start_decay_after_n_steps) / decay_span
            return [self.max_lr - portion * (self.max_lr - self.final_lr)] * n_groups

        # Phase 4: hold at final_lr
        return [self.final_lr] * n_groups
utils/mpr.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import os
3
+ import time
4
+ import sys
5
+
6
+
7
+ from tqdm import tqdm
8
+ from math import ceil
9
+
10
+
11
class MultipleProcessRunner:
    """
    Abstract class for running tasks with multiple processes.
    There are three abstract methods that should be implemented:
        1. __len__()    : return the length of data
        2. _target()    : target function for each process
        3. _aggregate() : aggregate results from each process
    """

    def __init__(self,
                 data,
                 save_path=None,
                 n_process=1,
                 verbose=True,
                 total_only=True,
                 log_step=1,
                 start_method='fork'):
        """
        Args:
            data      : data to be processed that can be sliced

            save_path : final output path (None if nothing is written to disk)

            n_process : number of process

            verbose   : if True, display progress bar

            total_only: If True, only total progress bar is displayed

            log_step  : For total progress bar, Next log will be printed when
                        ``current iteration`` - ``last log iteration`` >= log_step

            start_method: start method for multiprocessing
        """
        self.data = data
        self.save_path = save_path
        self.n_process = n_process
        self.verbose = verbose
        self.total_only = total_only
        self.log_step = log_step
        self.start_method = start_method

        # get terminal width to format output; fall back to None when no
        # terminal is attached (e.g. output piped to a file)
        try:
            self.terminal_y = os.get_terminal_size()[0]

        except Exception as e:
            print(e)
            print("Can't get terminal size, set terminal_y = None")
            self.terminal_y = None

    def _s2hms(self, seconds: float):
        """
        Convert seconds into "hh:mm:ss" format.
        """
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)

        return "%02d:%02d:%02d" % (h, m, s)

    def _display_time(self, st_time, now, total):
        # Elapsed / estimated-remaining time plus iteration rate
        ed_time = time.time()
        running_time = ed_time - st_time
        rest_time = running_time * (total - now) / now
        iter_sec = f"{now / running_time:.2f}it/s" if now > running_time else f"{running_time / now:.2f}s/it"

        return f' [{self._s2hms(running_time)} < {self._s2hms(rest_time)}, {iter_sec}]'

    def _display_bar(self, now, total, length):
        # Render "[####____]" with `length` slots; `now` is clamped to `total`
        now = now if now <= total else total
        num = now * length // total
        progress_bar = '[' + '#' * num + '_' * (length - num) + ']'
        return progress_bar

    def _display_all(self, now, total, desc, st_time):
        # make a progress bar
        length = 50
        progress_bar = self._display_bar(now, total, length)
        time_display = self._display_time(st_time, now, total)

        display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'

        # Clean a line: pad to terminal width, or shrink the bar to fit
        width = self.terminal_y if self.terminal_y is not None else 100
        num_space = width - len(display)
        if num_space > 0:
            display += ' ' * num_space
        else:
            length += num_space
            progress_bar = self._display_bar(now, total, length)
            display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'

        # Set color (red via ANSI escape)
        display = f"\033[31m{display}\033[0m"

        return display

    # Print progress bar at specific position in terminal
    def terminal_progress_bar(self,
                              process_id: int,
                              now: int,
                              total: int,
                              desc: str = ''):
        """
        Args:
            process_id: process id
            now: now iteration number
            total: total iteration number
            desc: description
        """
        st_time = self.process_st_time[process_id]

        # Aggregate total information
        self.counts[process_id] = now
        self._total_display(self.process_st_time["total"])

        if not self.total_only:
            process_display = self._display_all(now, total, desc, st_time)
            if self.terminal_y is not None:
                # \x1b7 / \x1b8 save and restore the cursor around the write
                sys.stdout.write(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8")
                sys.stdout.flush()
            else:
                print(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8", flush=True)

    # Print global information
    def _total_display(self, st_time):
        # total_display_callable acts as a best-effort lock so only one
        # process writes the total line at a time
        if self.total_display_callable.value == 1:
            self.total_display_callable.value = 0

            cnt = sum([self.counts[i] for i in range(self.n_process)])
            if cnt - self.last_cnt.value >= self.log_step:
                total_display = self._display_all(cnt, self.__len__(), f"Total: ", st_time)
                self.last_cnt.value = cnt

                x = self.n_process + 1 if not self.total_only else 0
                # if self.terminal_y is not None:
                #     sys.stdout.write(f"\x1b7\x1b[{x};{0}f{total_display}\x1b8")
                #     sys.stdout.flush()
                # else:
                #     print(f"\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True)
                print(f"\r\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True, end="")

            self.total_display_callable.value = 1

    def run(self):
        """
        The function is used to run a multi-process task
        Returns: return the result of function '_aggregate()'
        """

        import multiprocess as mp
        mp.set_start_method(self.start_method, force=True)

        # total number of data that is already processed
        self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})

        # record start time for each process
        self.process_st_time = {"total": time.time()}

        # set a lock to call total number display
        self.total_display_callable = mp.Value('d', 1)

        # Save last log iteration number
        self.last_cnt = mp.Value('d', 0)

        num_per_process = ceil(self.__len__() / self.n_process)

        if self.save_path is not None:
            file_name, suffix = os.path.splitext(self.save_path)

        process_list = []
        sub_paths = []
        for i in range(self.n_process):
            st = i * num_per_process
            ed = st + num_per_process

            # construct slice and sub path for sub process
            data_slice = self.data[st: ed]

            sub_path = None
            # Create a directory to save sub-results
            if self.save_path is not None:
                save_dir = f"{file_name}{suffix}_temp"
                os.makedirs(save_dir, exist_ok=True)
                sub_path = f"{save_dir}/temp_{i}{suffix}"

            # construct sub process
            input_args = (i, data_slice, sub_path)
            self.process_st_time[i] = time.time()
            p = mp.Process(target=self._target, args=input_args)
            p.start()

            process_list.append(p)
            sub_paths.append(sub_path)

        for p in process_list:
            p.join()

        # aggregate results and remove temporary directory
        results = self._aggregate(self.save_path, sub_paths)
        if self.save_path is not None:
            save_dir = f"{file_name}{suffix}_temp"
            os.rmdir(save_dir)

        return results

    def parallel_run(self):
        """
        Same as run(), but dispatches the work through joblib instead of
        spawning raw processes.
        """
        import multiprocess as mp
        from joblib import Parallel, delayed

        # total number of data that is already processed
        self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})

        # record start time for each process
        self.process_st_time = {"total": time.time()}

        # set a lock to call total number display
        self.total_display_callable = mp.Value('d', 1)

        # Save last log iteration number
        self.last_cnt = mp.Value('d', 0)

        num_per_process = ceil(self.__len__() / self.n_process)

        if self.save_path is not None:
            file_name, suffix = os.path.splitext(self.save_path)

        sub_paths = []
        input_arg_list = []
        for i in range(self.n_process):
            st = i * num_per_process
            ed = st + num_per_process

            # construct slice and sub path for sub process
            data_slice = self.data[st: ed]

            sub_path = None
            # Create a directory to save sub-results
            if self.save_path is not None:
                save_dir = f"{file_name}{suffix}_temp"
                os.makedirs(save_dir, exist_ok=True)
                sub_path = f"{save_dir}/temp_{i}{suffix}"

            # construct sub process
            input_args = (i, data_slice, sub_path)
            self.process_st_time[i] = time.time()

            sub_paths.append(sub_path)
            input_arg_list.append(input_args)

        # Start parallel processing.
        # FIX: unpack the argument tuple — previously the whole tuple was
        # passed as a single positional argument, so _target(process_id,
        # data, sub_path) raised a TypeError.
        Parallel(n_jobs=self.n_process)(delayed(self._target)(*input_args) for input_args in input_arg_list)

        # aggregate results and remove temporary directory
        results = self._aggregate(self.save_path, sub_paths)
        if self.save_path is not None:
            save_dir = f"{file_name}{suffix}_temp"
            os.rmdir(save_dir)

        return results

    @abc.abstractmethod
    def _aggregate(self, final_path: str, sub_paths):
        """
        This function is used to aggregate results from sub processes into a file

        Args:
            final_path: path to save final results
            sub_paths : list of sub paths

        Returns: None or desirable results specified by user
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _target(self, process_id, data, sub_path):
        """
        The main body to operate data in one process

        Args:
            process_id: process id
            data      : data slice
            sub_path  : sub path to save results
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __len__(self):
        raise NotImplementedError
+
306
+
307
class MultipleProcessRunnerSimplifier(MultipleProcessRunner):
    """
    A simplified version of MultipleProcessRunner.

    The user only implements the callable ``do``; it is invoked automatically
    for every item once ``run`` is called. If ``save_path`` is specified, each
    process opens a writer on its own sub path, results are written there, and
    everything is aggregated into ``save_path`` at the end.

    The procedure is equivalent to:
        ...
        with open(sub_path, 'w') as w:
            for i, d in enumerate(data):
                self.do(process_id, i, d, w)  # You can write results into the file.
        ...

    The 'do' callable has the signature:
        def do(process_id, idx, data, writer):
            ...

    If 'save_path' is None, the argument 'writer' will be set to None.
    """

    def __init__(self,
                 data,
                 do,
                 save_path=None,
                 n_process=1,
                 verbose=True,
                 total_only=True,
                 log_step=1,
                 return_results=False,
                 start_method='fork'):

        super().__init__(data=data,
                         save_path=save_path,
                         n_process=n_process,
                         verbose=verbose,
                         total_only=total_only,
                         log_step=log_step,
                         start_method=start_method)
        self.do = do
        self.return_results = return_results

    def run(self):
        # Timestamp keys the per-process temp file names used when results
        # must be returned without a save_path.
        self.start_time = time.time()
        return super().run()

    def _tmp_path(self, idx):
        # Temp file shared between _target (writer) and _aggregate (reader)
        return f"MultipleProcessRunnerSimplifier_{self.start_time}_{idx}.tmp"

    def _aggregate(self, final_path: str, sub_paths):
        collected = []

        writer = open(final_path, 'w') if final_path is not None else None

        iterator = enumerate(sub_paths)
        if self.verbose:
            iterator = tqdm(iterator, "Aggregating results...")

        for idx, sub_path in iterator:
            if sub_path is None and self.return_results:
                sub_path = self._tmp_path(idx)

            if sub_path is None:
                continue

            with open(sub_path, 'r') as reader:
                for line in reader:
                    if writer is not None:
                        writer.write(line)

                    if self.return_results:
                        # Drop the trailing newline
                        collected.append(line[:-1])

            os.remove(sub_path)

        return collected

    def _target(self, process_id, data, sub_path):
        if sub_path is None and self.return_results:
            sub_path = self._tmp_path(process_id)

        writer = open(sub_path, 'w') if sub_path is not None else None
        for idx, item in enumerate(data):
            self.do(process_id, idx, item, writer)
            if self.verbose:
                self.terminal_progress_bar(process_id, idx + 1, len(data), f"Process{process_id} running...")

        if writer is not None:
            writer.close()

    def __len__(self):
        return len(self.data)
397
+