LTEnjoy committed on
Commit
4e47952
1 Parent(s): 0e99665

Update utils/foldseek_util.py

Browse files
Files changed (1) hide show
  1. utils/foldseek_util.py +120 -120
utils/foldseek_util.py CHANGED
@@ -1,121 +1,121 @@
1
- import os
2
- import time
3
- import json
4
- import numpy as np
5
- import re
6
- import sys
7
- sys.path.append(".")
8
-
9
-
10
- # Get structural seqs from pdb file
11
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = False,
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Run ``foldseek structureto3didescriptor`` on a structure file and collect
    the amino-acid / structural sequences of its chains.

    Args:
        foldseek: Binary executable file of foldseek

        path: Path to pdb file

        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.

        process_id: Process ID for temporary files. This is used for parallel processing.

        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.

        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq). When a chain ID appears more than once in the
        foldseek output, only its first occurrence is kept.

    Raises:
        AssertionError: If ``foldseek`` or ``path`` does not exist, or if the number of
            plddt scores differs from the length of a structural sequence.
        FileNotFoundError: If foldseek did not produce its output file.
        OSError: If the foldseek binary cannot be executed.
    """
    # Imported here so the function stays self-contained.
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"

    # Argument list (no shell): paths containing spaces or shell metacharacters
    # reach foldseek unchanged.
    cmd = [foldseek, "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", path, tmp_save_path]

    seq_dict = {}
    name = os.path.basename(path)
    try:
        completed = subprocess.run(cmd)
        if not os.path.exists(tmp_save_path):
            raise FileNotFoundError(
                f"Foldseek produced no output file (exit code {completed.returncode}): {tmp_save_path}"
            )

        # The plddt scores depend only on the input file, so read them once
        # instead of once per output row.
        # NOTE(review): extract_plddt returns one array for the whole file, so
        # masking presumably only works for single-chain inputs - confirm.
        plddts = extract_plddt(path) if plddt_mask else None

        with open(tmp_save_path, "r") as r:
            for line in r:
                if not line.strip():
                    continue

                # Drop the line terminator first so a three-column row does not
                # leave "\n" at the end of struc_seq.
                desc, seq, struc_seq = line.rstrip("\n").split("\t")[:3]

                # Mask low plddt
                if plddt_mask:
                    assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                    # Mask regions with plddt < threshold
                    indices = np.where(plddts < plddt_threshold)[0]
                    np_seq = np.array(list(struc_seq))
                    np_seq[indices] = "#"
                    struc_seq = "".join(np_seq)

                # The chain ID is taken as the text after the last "_" of the
                # first space-separated token, once the file name is removed.
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    if chain not in seq_dict:
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Always clean up, including when parsing fails half-way.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)

    return seq_dict
77
-
78
-
79
def extract_plddt(pdb_path: str) -> np.ndarray:
    """
    Extract plddt scores from pdb file.

    ATOM records are read by fixed column position: the residue sequence
    number from columns 23-26 and the temperature-factor field, which this
    function treats as the plddt score, from columns 61-66. Reading by column
    keeps the parse correct when neighbouring fields touch (e.g. occupancy and
    a score of 100.00), when the chain ID is blank, and when the trailing
    element column is missing.

    Args:
        pdb_path: Path to pdb file.

    Returns:
        plddts: plddt scores, one per distinct residue number in order of first
        appearance. Each score is the mean over all ATOM records that share
        that residue number.

    Raises:
        ValueError: If an ATOM record is too short or holds non-numeric text
            in the residue-number or score columns.
    """
    plddt_dict = {}
    with open(pdb_path, "r") as r:
        for line in r:
            # The record name fills columns 1-6; only ATOM records are used.
            if not line.startswith("ATOM  "):
                continue

            pos = int(line[22:26])
            plddt = float(line[60:66])

            # NOTE(review): residues are keyed by number only, so equal numbers
            # in different chains (or with insertion codes) are averaged
            # together - confirm inputs are single-chain.
            plddt_dict.setdefault(pos, []).append(plddt)

    plddts = np.array([np.mean(v) for v in plddt_dict.values()])
    return plddts
113
-
114
-
115
if __name__ == '__main__':
    # Manual smoke test against files on the author's machine.
    foldseek = "/sujin/bin/foldseek"
    # Alternative input: test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # get_struc_seq has no `plddt_path` parameter; plddt scores are read from
    # the structure file itself, so masking is enabled with `plddt_mask`.
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())
 
1
+ import os
2
+ import time
3
+ import json
4
+ import numpy as np
5
+ import re
6
+ import sys
7
+ sys.path.append(".")
8
+
9
+
10
+ # Get structural seqs from pdb file
11
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = False,
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Run ``foldseek structureto3didescriptor`` on a structure file and collect
    the amino-acid / structural sequences of its chains.

    Args:
        foldseek: Binary executable file of foldseek

        path: Path to pdb file

        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.

        process_id: Process ID for temporary files. This is used for parallel processing.

        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.

        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq). When a chain ID appears more than once in the
        foldseek output, only its first occurrence is kept.

    Raises:
        AssertionError: If ``foldseek`` or ``path`` does not exist, or if the number of
            plddt scores differs from the length of a structural sequence.
        FileNotFoundError: If foldseek did not produce its output file.
        OSError: If the foldseek binary cannot be executed.
    """
    # Imported here so the function stays self-contained.
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    # Temporary output goes under /tmp rather than the working directory.
    tmp_save_path = f"/tmp/get_struc_seq_{process_id}_{time.time()}.tsv"

    # Argument list (no shell): paths containing spaces or shell metacharacters
    # reach foldseek unchanged.
    cmd = [foldseek, "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", path, tmp_save_path]

    seq_dict = {}
    name = os.path.basename(path)
    try:
        completed = subprocess.run(cmd)
        if not os.path.exists(tmp_save_path):
            raise FileNotFoundError(
                f"Foldseek produced no output file (exit code {completed.returncode}): {tmp_save_path}"
            )

        # The plddt scores depend only on the input file, so read them once
        # instead of once per output row.
        # NOTE(review): extract_plddt returns one array for the whole file, so
        # masking presumably only works for single-chain inputs - confirm.
        plddts = extract_plddt(path) if plddt_mask else None

        with open(tmp_save_path, "r") as r:
            for line in r:
                if not line.strip():
                    continue

                # Drop the line terminator first so a three-column row does not
                # leave "\n" at the end of struc_seq.
                desc, seq, struc_seq = line.rstrip("\n").split("\t")[:3]

                # Mask low plddt
                if plddt_mask:
                    assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                    # Mask regions with plddt < threshold
                    indices = np.where(plddts < plddt_threshold)[0]
                    np_seq = np.array(list(struc_seq))
                    np_seq[indices] = "#"
                    struc_seq = "".join(np_seq)

                # The chain ID is taken as the text after the last "_" of the
                # first space-separated token, once the file name is removed.
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    if chain not in seq_dict:
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Always clean up, including when parsing fails half-way.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)

    return seq_dict
77
+
78
+
79
def extract_plddt(pdb_path: str) -> np.ndarray:
    """
    Extract plddt scores from pdb file.

    ATOM records are read by fixed column position: the residue sequence
    number from columns 23-26 and the temperature-factor field, which this
    function treats as the plddt score, from columns 61-66. Reading by column
    keeps the parse correct when neighbouring fields touch (e.g. occupancy and
    a score of 100.00), when the chain ID is blank, and when the trailing
    element column is missing.

    Args:
        pdb_path: Path to pdb file.

    Returns:
        plddts: plddt scores, one per distinct residue number in order of first
        appearance. Each score is the mean over all ATOM records that share
        that residue number.

    Raises:
        ValueError: If an ATOM record is too short or holds non-numeric text
            in the residue-number or score columns.
    """
    plddt_dict = {}
    with open(pdb_path, "r") as r:
        for line in r:
            # The record name fills columns 1-6; only ATOM records are used.
            if not line.startswith("ATOM  "):
                continue

            pos = int(line[22:26])
            plddt = float(line[60:66])

            # NOTE(review): residues are keyed by number only, so equal numbers
            # in different chains (or with insertion codes) are averaged
            # together - confirm inputs are single-chain.
            plddt_dict.setdefault(pos, []).append(plddt)

    plddts = np.array([np.mean(v) for v in plddt_dict.values()])
    return plddts
113
+
114
+
115
if __name__ == '__main__':
    # Manual smoke test against files on the author's machine.
    foldseek = "/sujin/bin/foldseek"
    # Alternative input: test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # get_struc_seq has no `plddt_path` parameter; plddt scores are read from
    # the structure file itself, so masking is enabled with `plddt_mask`.
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())