oucgc1996 commited on
Commit
9259a4a
1 Parent(s): 1d866ee

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +403 -0
utils.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 Gabriele Orlando
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os,torch
16
+ from pyuul.sources.globalVariables import *
17
+ from pyuul.sources import hashings
18
+
19
+ import numpy as np
20
+ import random
21
+
22
+ def setup_seed(seed):
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ np.random.seed(seed)
26
+ random.seed(seed)
27
+ torch.backends.cudnn.deterministic = True
28
+ setup_seed(100)
29
+
30
+ def parseSDF(SDFFile):
31
+ """
32
+ function to parse pdb files. It can be used to parse a single file or all the pdb files in a folder. In case a folder is given, the coordinates are gonna be padded
33
+
34
+ Parameters
35
+ ----------
36
+ SDFFile : str
37
+ path of the PDB file or of the folder containing multiple PDB files
38
+
39
+ Returns
40
+ -------
41
+ coords : torch.Tensor
42
+ coordinates of the atoms in the pdb file(s). Shape ( batch, numberOfAtoms, 3)
43
+
44
+ atomNames : list
45
+ a list of the atom identifier. It encodes atom type, residue type, residue position and chain
46
+
47
+ """
48
+ if not os.path.isdir(SDFFile):
49
+ fil = SDFFile
50
+ totcoords=[]
51
+ totaname=[]
52
+ coords = []
53
+ atomNames = []
54
+ for line in open(fil).readlines():
55
+ a=line.strip().split()
56
+ if len(a)==16: ## atom
57
+ element = a[3]
58
+ x = float(a[0])
59
+ y = float(a[1])
60
+ z = float(a[2])
61
+ coords += [[x,y,z]]
62
+ #aname = line[17:20].strip()+"_"+str(resnum)+"_"+line[12:16].strip()+"_"+line[21]
63
+ aname = "MOL"+"_"+"0"+"_"+element+"_"+"A"
64
+
65
+ atomNames += [aname]
66
+ elif "$$$$" in line:
67
+ totcoords+=[torch.tensor(coords)]
68
+ totaname += [atomNames]
69
+ coords=[]
70
+ atomNames=[]
71
+ return torch.torch.nn.utils.rnn.pad_sequence(totcoords, batch_first=True, padding_value=PADDING_INDEX),totaname
72
+ else:
73
+ totcoords = []
74
+ totaname = []
75
+ for fil in sorted(os.listdir(SDFFile)):
76
+ coords = []
77
+ atomNames = []
78
+ for line in open(SDFFile+fil).readlines():
79
+ a = line.strip().split()
80
+ if len(a) == 16: ## atom
81
+ element = a[3]
82
+ x = float(a[0])
83
+ y = float(a[1])
84
+ z = float(a[2])
85
+ coords += [[x, y, z]]
86
+ aname = "MOL"+"_"+"0"+"_"+element+"_"+"A"
87
+
88
+ atomNames += [aname]
89
+ elif "$$$$" in line:
90
+ totcoords += [torch.tensor(coords)]
91
+ totaname += [atomNames]
92
+ coords = []
93
+ atomNames = []
94
+ return torch.torch.nn.utils.rnn.pad_sequence(totcoords, batch_first=True, padding_value=PADDING_INDEX),totaname
95
+
96
+
97
+ def parsePDB(PDBFile,keep_only_chains=None,keep_hetatm=True,bb_only=False):
98
+
99
+ """
100
+ function to parse pdb files. It can be used to parse a single file or all the pdb files in a folder. In case a folder is given, the coordinates are gonna be padded
101
+
102
+ Parameters
103
+ ----------
104
+ PDBFile : str
105
+ path of the PDB file or of the folder containing multiple PDB files
106
+ bb_only : bool
107
+ if True ignores all the atoms but backbone N, C and CA
108
+ keep_only_chains : str or None
109
+ ignores all the chain but the one given. If None it keeps all chains
110
+ keep_hetatm : bool
111
+ if False it ignores heteroatoms
112
+ Returns
113
+ -------
114
+ coords : torch.Tensor
115
+ coordinates of the atoms in the pdb file(s). Shape ( batch, numberOfAtoms, 3)
116
+
117
+ atomNames : list
118
+ a list of the atom identifier. It encodes atom type, residue type, residue position and chain
119
+
120
+ """
121
+
122
+ bbatoms = ["N", "CA", "C"]
123
+ if not os.path.isdir(PDBFile):
124
+ fil = PDBFile
125
+ coords = []
126
+ atomNames = []
127
+ cont = -1
128
+ oldres=-999
129
+ for line in open(fil).readlines():
130
+
131
+
132
+ if line[:4] == "ATOM":
133
+ if keep_only_chains is not None and (not line[21] in keep_only_chains):
134
+ continue
135
+ if bb_only and not line[12:16].strip() in bbatoms:
136
+ continue
137
+ if oldres != int(line[22:26]):
138
+ cont+=1
139
+ oldres=int(line[22:26])
140
+ resnum = int(line[22:26])
141
+ atomNames += [line[17:20].strip()+"_"+str(resnum)+"_"+line[12:16].strip()+"_"+line[21]]
142
+
143
+ x = float(line[30:38])
144
+ y = float(line[38:46])
145
+ z = float(line[47:54])
146
+ coords+=[[x,y,z]]
147
+
148
+ elif line[:6] == "HETATM" and keep_hetatm:
149
+
150
+ resname_het = line[17:20].strip()
151
+ resnum = int(line[22:26])
152
+ x = float(line[30:38])
153
+ y = float(line[38:46])
154
+ z = float(line[47:54])
155
+ coords += [[x, y, z]]
156
+ atnameHet = line[12:16].strip()
157
+ atomNames += [resname_het+"_"+str(resnum)+"_"+atnameHet+"_"+line[21]]
158
+ return torch.tensor(coords).unsqueeze(0), [atomNames]
159
+ else:
160
+ coords = []
161
+ atomNames = []
162
+ pdbname = []
163
+ pdb_num = 0
164
+ for fil in sorted(os.listdir(PDBFile)):
165
+ # print(pdb_num)
166
+ pdb_num +=1
167
+ pdbname.append(fil)
168
+ atomNamesTMP = []
169
+ coordsTMP = []
170
+ cont = -1
171
+ oldres=-999
172
+ for line in open(PDBFile+"/"+fil).readlines():
173
+
174
+ if line[:4] == "ATOM":
175
+ if keep_only_chains is not None and (not line[21] in keep_only_chains):
176
+ continue
177
+ if bb_only and not line[12:16].strip() in bbatoms:
178
+ continue
179
+ if oldres != int(line[22:26]):
180
+ cont += 1
181
+ oldres = int(line[22:26])
182
+
183
+ resnum = int(line[22:26])
184
+ atomNamesTMP += [line[17:20].strip()+"_"+str(resnum)+"_"+line[12:16].strip()+"_"+line[21]]
185
+
186
+ x = float(line[30:38])
187
+ y = float(line[38:46])
188
+ z = float(line[47:54])
189
+ coordsTMP+=[[x,y,z]]
190
+
191
+ elif line[:6] == "HETATM" and keep_hetatm:
192
+ if line[17:20].strip()!="GTP":
193
+ continue
194
+ x = float(line[30:38])
195
+ y = float(line[38:46])
196
+ z = float(line[47:54])
197
+ resnum = int(line[22:26])
198
+ coordsTMP += [[x, y, z]]
199
+ atnameHet = line[12:16].strip()
200
+ atomNamesTMP += ["HET_"+str(resnum)+"_"+atnameHet+"_"+line[21]]
201
+ coords+=[torch.tensor(coordsTMP)]
202
+ atomNames += [atomNamesTMP]
203
+
204
+ return torch.torch.nn.utils.rnn.pad_sequence(coords, batch_first=True, padding_value=PADDING_INDEX), atomNames, pdbname, pdb_num
205
+
206
+
207
+ def atomlistToChannels(atomNames, hashing="Element_Hashing", device="cpu"):
208
+ """
209
+ function to get channels from atom names (obtained parsing the pdb files with the parsePDB function)
210
+
211
+ Parameters
212
+ ----------
213
+ atomNames : list
214
+ atom names obtained parsing the pdb files with the parsePDB function
215
+
216
+ hashing : "TPL_Hashing" or "Element_Hashing" or dict
217
+ define which atoms are grouped together. You can use two default hashings or build your own hashing:
218
+
219
+ TPL_Hashing: uses the hashing of torch protein library (https://github.com/lupoglaz/TorchProteinLibrary)
220
+ Element_Hashing: groups atoms in accordnce with the element only: C -> 0, N -> 1, O ->2, P ->3, S- >4, H ->5, everything else ->6
221
+
222
+ Alternatively, if you are not happy with the default hashings, you can build a dictionary of dictionaries that defines the channel of every atom type in the pdb.
223
+ the first dictionary has the residue tag (three letters amino acid code) as key (3 letters compound name for hetero atoms, as written in the PDB file)
224
+ every residue key is associated to a dictionary, which the atom tags (as written in the PDB files) as keys and the channel (int) as value
225
+
226
+ for example, you can define the channels just based on the atom element as following:
227
+ {
228
+ 'CYS': {'N': 1, 'O': 2, 'C': 0, 'SG': 3, 'CB': 0, 'CA': 0}, # channels for cysteine atoms
229
+ 'GLY': {'N': 1, 'O': 2, 'C': 0, 'CA': 0}, # channels for glycine atom
230
+ ...
231
+ 'GOL': {'O1':2,'O2':2,'O3':2,'C1':0,'C2':0,'C3':0}, # channels for glycerol atom
232
+ ...
233
+ }
234
+
235
+ The default encoding is the one that assigns a different channel to each element
236
+
237
+ other encodings can be found in sources/hashings.py
238
+
239
+ device : torch.device
240
+ The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
241
+ Returns
242
+ -------
243
+ coords : torch.Tensor
244
+ coordinates of the atoms in the pdb file(s). Shape ( batch, numberOfAtoms, 3)
245
+
246
+ channels : torch.tensor
247
+ the channel of every atom. Shape (batch,numberOfAtoms)
248
+
249
+ """
250
+ if hashing == "TPL_Hashing":
251
+ hashing = hashings.TPLatom_hash
252
+
253
+ elif hashing == "Element_Hashing":
254
+ hashing = hashings.elements_hash
255
+ else:
256
+ assert type(hashing) is dict
257
+
258
+ if type(hashing[list(hashing.keys())[0]]) == dict:
259
+ useResName = True
260
+ else:
261
+ useResName = False
262
+ assert type(hashing[list(hashing.keys())[0]]) == int
263
+ channels = []
264
+ for singleAtomList in atomNames:
265
+ haTMP = []
266
+ for i in singleAtomList:
267
+ resname = i.split("_")[0]
268
+ atName = i.split("_")[2]
269
+ # if resname=="HET":
270
+ # atName="HET"
271
+ if useResName:
272
+ if resname in hashing and atName in hashing[resname]:
273
+ haTMP += [hashing[resname][atName]]
274
+ else:
275
+ haTMP += [PADDING_INDEX]
276
+ print("missing ", resname, atName)
277
+ else:
278
+ if atName in hashing:
279
+ haTMP += [hashing[atName]]
280
+ elif atName[0] in hashing:
281
+ haTMP += [hashing[atName[0]]]
282
+ elif hashing == "Element_Hashing":
283
+ haTMP += [6]
284
+ else:
285
+ haTMP += [PADDING_INDEX]
286
+ print("missing ", resname, atName)
287
+
288
+ channels += [torch.tensor(haTMP, dtype=torch.float, device=device)]
289
+ channels = torch.torch.nn.utils.rnn.pad_sequence(channels, batch_first=True, padding_value=PADDING_INDEX)
290
+ return channels
291
+
292
+
293
+ def atomlistToRadius(atomList, hashing="FoldX_radius", device="cpu"):
294
+ """
295
+ function to get radius from atom names (obtained parsing the pdb files with the parsePDB function)
296
+
297
+
298
+
299
+ Parameters
300
+ ----------
301
+ atomNames : list
302
+ atom names obtained parsing the pdb files with the parsePDB function
303
+ hashing : FoldX_radius or dict
304
+ "FoldX_radius" provides the radius used by the FoldX force field
305
+
306
+ Alternatively, if you are not happy with the foldX radius, you can build a dictionary of dictionaries that defines the radius of every atom type in the pdb.
307
+ The first dictionary has the residue tag (three letters amino acid code) as key (3 letters compound name for hetero atoms, as written in the PDB file)
308
+ every residue key is associated to a dictionary, which the atom tags (as written in the PDB files) as keys and the radius (float) as value
309
+
310
+ for example, you can define the radius as following:
311
+ {
312
+ 'CYS': {'N': 1.45, 'O': 1.37, 'C': 1.7, 'SG': 1.7, 'CB': 1.7, 'CA': 1.7}, # radius for cysteine atoms
313
+ 'GLY': {'N': 1.45, 'O': 1.37, 'C': 1.7, 'CA': 1.7}, # radius for glycine atoms
314
+ ...
315
+ 'GOL': {'O1':1.37,'O2':1.37,'O3':1.37,'C1':1.7,'C2':1.7,'C3':1.7}, # radius for glycerol atoms
316
+ ...
317
+ }
318
+
319
+ The default radius are the ones defined in FoldX
320
+
321
+ Radius default dictionary can be found in sources/hashings.py
322
+
323
+ device : torch.device
324
+ The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
325
+ Returns
326
+ -------
327
+ coords : torch.Tensor
328
+ coordinates of the atoms in the pdb file(s). Shape ( batch, numberOfAtoms, 3)
329
+
330
+ radius : torch.tensor
331
+ The radius of every atom. Shape (batch,numberOfAtoms)
332
+
333
+ """
334
+ if hashing == "FoldX_radius":
335
+ hashing = hashings.radius
336
+ hahsingSomgleAtom = hashings.radiusSingleAtom
337
+ else:
338
+ assert type(hashing) is dict
339
+
340
+ radius = []
341
+ for singleAtomList in atomList:
342
+ haTMP = []
343
+ for i in singleAtomList:
344
+ resname = i.split("_")[0]
345
+ atName = i.split("_")[2]
346
+ if resname in hashing and atName in hashing[resname]:
347
+ haTMP += [hashing[resname][atName]]
348
+ elif atName[0] in hahsingSomgleAtom:
349
+ haTMP += [hahsingSomgleAtom[atName[0]]]
350
+ else:
351
+ haTMP += [1.0]
352
+ print("missing ", resname, atName)
353
+ radius += [torch.tensor(haTMP, dtype=torch.float, device=device)]
354
+ radius = torch.torch.nn.utils.rnn.pad_sequence(radius, batch_first=True, padding_value=PADDING_INDEX)
355
+ return radius
356
+
357
+
358
+ '''
359
+ def write_pdb(batchedCoords, atomNames , name=None, output_folder="outpdb/"): #I need to add the chain id
360
+
361
+ if name is None:
362
+ name = range(len(batchedCoords))
363
+
364
+ for struct in range(len(name)):
365
+ f = open(output_folder + str(name[struct]) + ".pdb", "w")
366
+
367
+ coords=batchedCoords[struct].data.numpy()
368
+ atname=atomNames[struct]
369
+ for i in range(len(coords)):
370
+
371
+ rnName = atname[i].split("_")[0]#hashings.resi_hash_inverse[resi_list[i]]
372
+ atName = atname[i].split("_")[2]#hashings.atom_hash_inverse[resi_list[i]][atom_list[i]]
373
+ pos = atname[i].split("_")[1]
374
+ chain = "A"
375
+
376
+ num = " " * (5 - len(str(i))) + str(i)
377
+ a_name = atName + " " * (4 - len(atName))
378
+ numres = " " * (4 - len(str(pos))) + str(pos)
379
+
380
+ x = round(float(coords[i][0]), 3)
381
+ sx = str(x)
382
+ while len(sx.split(".")[1]) < 3:
383
+ sx += "0"
384
+ x = " " * (8 - len(sx)) + sx
385
+
386
+ y = round(float(coords[i][1]), 3)
387
+ sy = str(y)
388
+ while len(sy.split(".")[1]) < 3:
389
+ sy += "0"
390
+ y = " " * (8 - len(sy)) + sy
391
+
392
+ z = round(float(coords[i][2]), 3)
393
+ sz = str(z)
394
+ while len(sz.split(".")[1]) < 3:
395
+ sz += "0"
396
+ z = " " * (8 - len(sz)) + sz
397
+ chain = " " * (2 - len(chain)) + chain
398
+
399
+ if rnName !="HET":
400
+ f.write("ATOM " + num + " " + a_name + "" + rnName + chain + numres + " " + x + y + z + " 1.00 64.10 " + atName[0] + "\n")
401
+ else:
402
+ f.write("HETATM" + num + " " + a_name + "" + rnName + chain + numres + " " + x + y + z + " 1.00 64.10 " + atName[0] + "\n")
403
+ '''