Update utils/foldseek_util.py
Browse files- utils/foldseek_util.py +120 -120
utils/foldseek_util.py
CHANGED
@@ -1,121 +1,121 @@
|
|
1 |
-
import os
|
2 |
-
import time
|
3 |
-
import json
|
4 |
-
import numpy as np
|
5 |
-
import re
|
6 |
-
import sys
|
7 |
-
sys.path.append(".")
|
8 |
-
|
9 |
-
|
10 |
-
# Get structural seqs from pdb file
|
11 |
-
def get_struc_seq(foldseek,
|
12 |
-
path,
|
13 |
-
chains: list = None,
|
14 |
-
process_id: int = 0,
|
15 |
-
plddt_mask: bool = False,
|
16 |
-
plddt_threshold: float = 70.,
|
17 |
-
foldseek_verbose: bool = False) -> dict:
|
18 |
-
"""
|
19 |
-
|
20 |
-
Args:
|
21 |
-
foldseek: Binary executable file of foldseek
|
22 |
-
|
23 |
-
path: Path to pdb file
|
24 |
-
|
25 |
-
chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
|
26 |
-
|
27 |
-
process_id: Process ID for temporary files. This is used for parallel processing.
|
28 |
-
|
29 |
-
plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.
|
30 |
-
|
31 |
-
plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.
|
32 |
-
|
33 |
-
foldseek_verbose: If True, foldseek will print verbose messages.
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
|
37 |
-
(seq, struc_seq, combined_seq).
|
38 |
-
"""
|
39 |
-
assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
|
40 |
-
assert os.path.exists(path), f"PDB file not found: {path}"
|
41 |
-
|
42 |
-
tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"
|
43 |
-
if foldseek_verbose:
|
44 |
-
cmd = f"{foldseek} structureto3didescriptor --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
|
45 |
-
else:
|
46 |
-
cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
|
47 |
-
os.system(cmd)
|
48 |
-
|
49 |
-
seq_dict = {}
|
50 |
-
name = os.path.basename(path)
|
51 |
-
with open(tmp_save_path, "r") as r:
|
52 |
-
for i, line in enumerate(r):
|
53 |
-
desc, seq, struc_seq = line.split("\t")[:3]
|
54 |
-
|
55 |
-
# Mask low plddt
|
56 |
-
if plddt_mask:
|
57 |
-
plddts = extract_plddt(path)
|
58 |
-
assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"
|
59 |
-
|
60 |
-
# Mask regions with plddt < threshold
|
61 |
-
indices = np.where(plddts < plddt_threshold)[0]
|
62 |
-
np_seq = np.array(list(struc_seq))
|
63 |
-
np_seq[indices] = "#"
|
64 |
-
struc_seq = "".join(np_seq)
|
65 |
-
|
66 |
-
name_chain = desc.split(" ")[0]
|
67 |
-
chain = name_chain.replace(name, "").split("_")[-1]
|
68 |
-
|
69 |
-
if chains is None or chain in chains:
|
70 |
-
if chain not in seq_dict:
|
71 |
-
combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
|
72 |
-
seq_dict[chain] = (seq, struc_seq, combined_seq)
|
73 |
-
|
74 |
-
os.remove(tmp_save_path)
|
75 |
-
os.remove(tmp_save_path + ".dbtype")
|
76 |
-
return seq_dict
|
77 |
-
|
78 |
-
|
79 |
-
def extract_plddt(pdb_path: str) -> np.ndarray:
|
80 |
-
"""
|
81 |
-
Extract plddt scores from pdb file.
|
82 |
-
Args:
|
83 |
-
pdb_path: Path to pdb file.
|
84 |
-
|
85 |
-
Returns:
|
86 |
-
plddts: plddt scores.
|
87 |
-
"""
|
88 |
-
with open(pdb_path, "r") as r:
|
89 |
-
plddt_dict = {}
|
90 |
-
for line in r:
|
91 |
-
line = re.sub(' +', ' ', line).strip()
|
92 |
-
splits = line.split(" ")
|
93 |
-
|
94 |
-
if splits[0] == "ATOM":
|
95 |
-
# If position < 1000
|
96 |
-
if len(splits[4]) == 1:
|
97 |
-
pos = int(splits[5])
|
98 |
-
|
99 |
-
# If position >= 1000, the blank will be removed, e.g. "A 999" -> "A1000"
|
100 |
-
# So the length of splits[4] is not 1
|
101 |
-
else:
|
102 |
-
pos = int(splits[4][1:])
|
103 |
-
|
104 |
-
plddt = float(splits[-2])
|
105 |
-
|
106 |
-
if pos not in plddt_dict:
|
107 |
-
plddt_dict[pos] = [plddt]
|
108 |
-
else:
|
109 |
-
plddt_dict[pos].append(plddt)
|
110 |
-
|
111 |
-
plddts = np.array([np.mean(v) for v in plddt_dict.values()])
|
112 |
-
return plddts
|
113 |
-
|
114 |
-
|
115 |
-
if __name__ == '__main__':
|
116 |
-
foldseek = "/sujin/bin/foldseek"
|
117 |
-
# test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
|
118 |
-
test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
|
119 |
-
plddt_path = "/sujin/Datasets/FLIP/meltome/af2_plddts/A0A061ACX4.json"
|
120 |
-
res = get_struc_seq(foldseek, test_path, plddt_path=plddt_path, plddt_threshold=70.)
|
121 |
print(res["A"][1].lower())
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import re
|
6 |
+
import sys
|
7 |
+
sys.path.append(".")
|
8 |
+
|
9 |
+
|
10 |
+
# Get structural seqs from pdb file
|
11 |
+
def get_struc_seq(foldseek,
|
12 |
+
path,
|
13 |
+
chains: list = None,
|
14 |
+
process_id: int = 0,
|
15 |
+
plddt_mask: bool = False,
|
16 |
+
plddt_threshold: float = 70.,
|
17 |
+
foldseek_verbose: bool = False) -> dict:
|
18 |
+
"""
|
19 |
+
|
20 |
+
Args:
|
21 |
+
foldseek: Binary executable file of foldseek
|
22 |
+
|
23 |
+
path: Path to pdb file
|
24 |
+
|
25 |
+
chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
|
26 |
+
|
27 |
+
process_id: Process ID for temporary files. This is used for parallel processing.
|
28 |
+
|
29 |
+
plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.
|
30 |
+
|
31 |
+
plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.
|
32 |
+
|
33 |
+
foldseek_verbose: If True, foldseek will print verbose messages.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
|
37 |
+
(seq, struc_seq, combined_seq).
|
38 |
+
"""
|
39 |
+
assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
|
40 |
+
assert os.path.exists(path), f"PDB file not found: {path}"
|
41 |
+
|
42 |
+
tmp_save_path = f"/tmp/get_struc_seq_{process_id}_{time.time()}.tsv"
|
43 |
+
if foldseek_verbose:
|
44 |
+
cmd = f"{foldseek} structureto3didescriptor --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
|
45 |
+
else:
|
46 |
+
cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
|
47 |
+
os.system(cmd)
|
48 |
+
|
49 |
+
seq_dict = {}
|
50 |
+
name = os.path.basename(path)
|
51 |
+
with open(tmp_save_path, "r") as r:
|
52 |
+
for i, line in enumerate(r):
|
53 |
+
desc, seq, struc_seq = line.split("\t")[:3]
|
54 |
+
|
55 |
+
# Mask low plddt
|
56 |
+
if plddt_mask:
|
57 |
+
plddts = extract_plddt(path)
|
58 |
+
assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"
|
59 |
+
|
60 |
+
# Mask regions with plddt < threshold
|
61 |
+
indices = np.where(plddts < plddt_threshold)[0]
|
62 |
+
np_seq = np.array(list(struc_seq))
|
63 |
+
np_seq[indices] = "#"
|
64 |
+
struc_seq = "".join(np_seq)
|
65 |
+
|
66 |
+
name_chain = desc.split(" ")[0]
|
67 |
+
chain = name_chain.replace(name, "").split("_")[-1]
|
68 |
+
|
69 |
+
if chains is None or chain in chains:
|
70 |
+
if chain not in seq_dict:
|
71 |
+
combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
|
72 |
+
seq_dict[chain] = (seq, struc_seq, combined_seq)
|
73 |
+
|
74 |
+
os.remove(tmp_save_path)
|
75 |
+
os.remove(tmp_save_path + ".dbtype")
|
76 |
+
return seq_dict
|
77 |
+
|
78 |
+
|
79 |
+
def extract_plddt(pdb_path: str) -> np.ndarray:
|
80 |
+
"""
|
81 |
+
Extract plddt scores from pdb file.
|
82 |
+
Args:
|
83 |
+
pdb_path: Path to pdb file.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
plddts: plddt scores.
|
87 |
+
"""
|
88 |
+
with open(pdb_path, "r") as r:
|
89 |
+
plddt_dict = {}
|
90 |
+
for line in r:
|
91 |
+
line = re.sub(' +', ' ', line).strip()
|
92 |
+
splits = line.split(" ")
|
93 |
+
|
94 |
+
if splits[0] == "ATOM":
|
95 |
+
# If position < 1000
|
96 |
+
if len(splits[4]) == 1:
|
97 |
+
pos = int(splits[5])
|
98 |
+
|
99 |
+
# If position >= 1000, the blank will be removed, e.g. "A 999" -> "A1000"
|
100 |
+
# So the length of splits[4] is not 1
|
101 |
+
else:
|
102 |
+
pos = int(splits[4][1:])
|
103 |
+
|
104 |
+
plddt = float(splits[-2])
|
105 |
+
|
106 |
+
if pos not in plddt_dict:
|
107 |
+
plddt_dict[pos] = [plddt]
|
108 |
+
else:
|
109 |
+
plddt_dict[pos].append(plddt)
|
110 |
+
|
111 |
+
plddts = np.array([np.mean(v) for v in plddt_dict.values()])
|
112 |
+
return plddts
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
|
116 |
+
foldseek = "/sujin/bin/foldseek"
|
117 |
+
# test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
|
118 |
+
test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
|
119 |
+
plddt_path = "/sujin/Datasets/FLIP/meltome/af2_plddts/A0A061ACX4.json"
|
120 |
+
res = get_struc_seq(foldseek, test_path, plddt_path=plddt_path, plddt_threshold=70.)
|
121 |
print(res["A"][1].lower())
|