oucgc1996 commited on
Commit
1d866ee
1 Parent(s): 2bce059

Upload VolumeMaker.py

Browse files
Files changed (1) hide show
  1. VolumeMaker.py +591 -0
VolumeMaker.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # Copyright 2021 Gabriele Orlando
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import torch,math
19
+ from pyuul.sources.globalVariables import *
20
+
21
+ import numpy as np
22
+ import random
23
+
24
def setup_seed(seed):
    """
    Seed every random number generator used by this module and make cuDNN reproducible.

    Parameters
    ----------
    seed : int
        Seed applied to the Python ``random`` module, NumPy and PyTorch (CPU and all CUDA devices)

    Returns
    -------
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels are not sufficient on their own: if benchmark mode was switched on
    # elsewhere, cuDNN may still select a different algorithm from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# NOTE: executed at import time, so importing this module reseeds the global RNGs.
setup_seed(100)
31
+
32
class Voxels(torch.nn.Module):
    """Voxel (3D grid) representation of a batch of macromolecules."""

    def __init__(self, device=torch.device("cpu"), sparse=True):
        """
        Constructor for the Voxels class, which builds the main PyUUL object.

        Parameters
        ----------

        device : torch.device
            The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
        sparse : bool
            Use sparse tensors calculation when possible

        Returns
        -------
        """
        super(Voxels, self).__init__()

        self.sparse = sparse
        self.boxsize = None
        self.dev = device

    def __transform_coordinates(self, coords, radius=None):
        """
        Private function that transforms the coordinates to fit them in the 3d box. It also takes care of the resolution.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 )
        radius : torch.Tensor or None
            Radius of the atoms. Shape ( batch, numberOfAtoms )

        Returns
        -------
        coords : torch.Tensor
            transformed coordinates. When radius is given, the tuple (coords, radius) is returned instead,
            with the radius rescaled by the same dilatation factor

        """
        coords = (coords * self.dilatation) - self.translation
        if radius is None:
            return coords
        return coords, radius * self.dilatation

    def __define_spatial_conformation(self, mincoords, cubes_around_atoms_dim, resolution):
        """
        Private function that defines the space of the volume. Takes resolution and margins into consideration.

        Parameters
        ----------
        mincoords : torch.Tensor
            minimum coordinates of each macromolecule of the batch. Shape ( batch, 3 )
        cubes_around_atoms_dim : int
            maximum distance in number of voxels to check for atom contribution to occupancy of a voxel
        resolution : float
            side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
        Returns
        -------
        """
        # the margin leaves room for the cube of voxels evaluated around the atoms closest to the origin
        self.translation = (mincoords - cubes_around_atoms_dim).unsqueeze(1)
        self.dilatation = 1.0 / resolution

    def __build_standard_cube(self, cubes_around_atoms_dim):
        """
        Private function that builds the integer offsets of the cube of voxels evaluated around every atom.

        Sets self.lato (side of the cube, 2*cubes_around_atoms_dim+1) and self.standard_cube, an int16 tensor of
        shape ( 1, 1, lato, lato, lato, 3 ) where entry [0, 0, i, j, k] is the offset (g[i], g[j], g[k]) with
        g = -cubes_around_atoms_dim ... cubes_around_atoms_dim.

        Parameters
        ----------
        cubes_around_atoms_dim : int
            maximum distance in number of voxels to check for atom contribution to occupancy of a voxel
        Returns
        -------
        """
        offsets = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=torch.int16)
        side = offsets.shape[0]
        self.lato = side

        per_axis = [offsets.view(side, 1, 1), offsets.view(1, side, 1), offsets.view(1, 1, side)]
        grid = torch.stack([axis.expand(side, side, side) for axis in per_axis], dim=-1)
        self.standard_cube = grid.unsqueeze(0).unsqueeze(0)

    def forward(self, coords, radius, channels, numberchannels=None, resolution=1, cubes_around_atoms_dim=5, steepness=10, function="sigmoid"):
        """
        Voxels representation of the macromolecules

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
        radius : torch.Tensor
            Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
        channels: torch.LongTensor
            channels of the atoms. Atoms of the same type should belong to the same channel. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToChannels
        numberchannels : int or None
            maximum number of channels. if None, max(channels) + 1 over the non-padding atoms is used

        cubes_around_atoms_dim : int
            maximum distance in number of voxels for which the contribution to occupancy is taken into consideration. Every atom that is farther than cubes_around_atoms_dim voxels from the center of a voxel does not give any contribution to the relative voxel occupancy
        resolution : float
            side in A of a voxel. The lower this value is the higher the resolution of the final representation will be

        steepness : float or int
            steepness of the sigmoid occupancy function.

        function : "sigmoid" or "gaussian"
            occupancy function to use. Can be sigmoid (every atom has a sigmoid shaped occupancy function) or gaussian (based on Li et al. 2014)
        Returns
        -------
        volume : torch.Tensor
            voxel representation of the macromolecules in the batch. Shape ( batch, channels, x,y,z), where x,y,z are the size of the 3D volume in which the macromolecules have been represented

        Raises
        ------
        ValueError
            if function is neither "sigmoid" nor "gaussian"
        """
        if function not in ("sigmoid", "gaussian"):
            raise ValueError('function must be "sigmoid" or "gaussian", got ' + repr(function))

        # PADDING_INDEX is expected to come from the star import of pyuul.sources.globalVariables
        padding_mask = ~channels.eq(PADDING_INDEX)
        if numberchannels is None:
            numberchannels = int(channels[padding_mask].max().cpu().data + 1)
        self.featureVectorSize = numberchannels
        self.function = function

        self.__build_standard_cube(cubes_around_atoms_dim)

        ### definition of the box ###
        # maximum and minimum coordinate on each dimension (every molecule in the batch shares the same box),
        # scaled by resolution, plus the cube in which the occupancy is evaluated: one margin at the beginning
        # and one at the end of every axis
        mincoords = torch.min(coords[:, :, :], dim=1)[0]
        mincoords = torch.trunc(mincoords / resolution)

        real_atoms = coords[padding_mask]
        margin = 2 * cubes_around_atoms_dim + 1
        box_size_x = (math.ceil(torch.max(real_atoms[:, 0]) / resolution) - mincoords[:, 0].min()) + margin
        box_size_y = (math.ceil(torch.max(real_atoms[:, 1]) / resolution) - mincoords[:, 1].min()) + margin
        box_size_z = (math.ceil(torch.max(real_atoms[:, 2]) / resolution) - mincoords[:, 2].min()) + margin
        #############################

        # define the spatial transforms, then move coordinates and radii into voxel units
        self.__define_spatial_conformation(mincoords, cubes_around_atoms_dim, resolution)
        coords, radius = self.__transform_coordinates(coords, radius)

        boxsize = (int(box_size_x), int(box_size_y), int(box_size_z))
        self.boxsize = boxsize

        # smallest integer type able to hold the voxel indices
        if max(boxsize) < 256:
            self.dtype_indices = torch.uint8
        else:
            self.dtype_indices = torch.int16

        if function == "sigmoid":
            return self.__forward_actual_calculation(coords, boxsize, radius, channels, padding_mask, steepness, resolution)
        return self.__forward_actual_calculationGaussian(coords, boxsize, radius, channels, padding_mask, resolution)

    def __atom_cubes(self, coords_scaled, radius, atNameHashing, padding_mask, resolution):
        """
        Private function that builds, for every non-padding atom, the cube of voxels surrounding it.

        Parameters
        ----------
        coords_scaled : torch.Tensor
            Coordinates of the atoms in voxel units. Shape ( batch, numberOfAtoms, 3 )
        radius : torch.Tensor
            Radius of the atoms in voxel units. Shape ( batch, numberOfAtoms )
        atNameHashing: torch.LongTensor
            channels of the atoms. Shape ( batch, numberOfAtoms )
        padding_mask : torch.BoolTensor
            tensor to mask the padding. Shape (batch, numberOfAtoms)
        resolution : float
            side in A of a voxel
        Returns
        -------
        distances : torch.Tensor
            distance between each atom and the voxels of its cube. Shape ( N, lato, lato, lato ), N = number of non-padding atoms
        radius : torch.Tensor
            radius of the atoms, broadcastable against distances. Shape ( N, 1, 1, 1 )
        batch_index : torch.LongTensor
            index in the batch of the molecule each voxel belongs to. Shape ( N, lato, lato, lato )
        channel_index : torch.LongTensor
            channel of the atom each voxel belongs to. Shape ( N, lato, lato, lato )
        voxel_index : torch.Tensor
            integer x,y,z indices of the voxels. Shape ( N, lato, lato, lato, 3 )
        """
        batch = coords_scaled.shape[0]
        natoms = coords_scaled.shape[1]
        side = self.lato

        discrete_coordinates = torch.trunc(coords_scaled.data).to(self.dtype_indices)

        # three singleton dimensions let the per-atom tensors broadcast against the standard cube
        voxel_index = discrete_coordinates[:, :, None, None, None, :] + self.standard_cube + 1
        voxel_reference = voxel_index + 0.5 * resolution
        distances = torch.norm(coords_scaled[:, :, None, None, None, :] - voxel_reference, dim=-1).to(coords_scaled.dtype)

        batch_index = torch.arange(batch, device=self.dev).view(batch, 1, 1, 1, 1).expand(batch, natoms, side, side, side)
        channel_index = atNameHashing.long()[padding_mask].view(-1, 1, 1, 1).expand(-1, side, side, side)

        return (distances[padding_mask],
                radius[padding_mask].view(-1, 1, 1, 1),
                batch_index[padding_mask],
                channel_index,
                voxel_index[padding_mask])

    def __accumulate(self, batch_index, channel_index, voxel_index, values, boxsize, batch):
        """
        Private function that sums the per-atom contributions falling in the same voxel and channel.

        Parameters
        ----------
        batch_index : torch.LongTensor
            molecule index of every contribution. Shape ( M )
        channel_index : torch.LongTensor
            channel of every contribution. Shape ( M )
        voxel_index : torch.Tensor
            integer x,y,z voxel indices of every contribution. Shape ( M, 3 )
        values : torch.Tensor
            contributions to sum. Shape ( M )
        boxsize : tuple of int
            The size of the box in which the macromolecules are represented
        batch : int
            number of macromolecules in the batch
        Returns
        -------
        volume : torch.Tensor
            summed contributions. Coalesced sparse tensor of shape ( batch, channels, x, y, z ) if self.sparse,
            otherwise dense tensor of shape ( batch, channels, x+1, y+1, z+1 )
        """
        x_index = voxel_index[:, 0].long()
        y_index = voxel_index[:, 1].long()
        z_index = voxel_index[:, 2].long()

        if self.sparse:
            indices = torch.stack([batch_index, channel_index, x_index, y_index, z_index])
            size = [batch, self.featureVectorSize, boxsize[0], boxsize[1], boxsize[2]]
            # coalesce sums the values that share the same index
            return torch.sparse_coo_tensor(indices=indices, values=values, size=size).coalesce()

        # NOTE(review): the dense volume is one voxel larger per axis than the sparse one; kept as it was
        volume = torch.zeros(batch, boxsize[0] + 1, boxsize[1] + 1, boxsize[2] + 1, self.featureVectorSize, device=self.dev, dtype=torch.float)
        volume = volume.index_put((batch_index, x_index, y_index, z_index, channel_index), values, accumulate=True)
        return volume.permute(0, 4, 1, 2, 3)

    def __forward_actual_calculationGaussian(self, coords_scaled, boxsize, radius, atNameHashing, padding_mask, resolution):
        """
        private function for the calculation of the gaussian voxel occupancy

        Every atom contributes n = exp(-d^2 / (sigma^2 * r^2)) to the voxels of its cube (doi: 10.1142/S0219633614400021 eq 1)
        and the contributions to one voxel are combined as 1 - prod(1 - n). The product is evaluated as the
        exponential of a sum of logarithms, both for the sparse and for the dense representation.

        Parameters
        ----------
        coords_scaled : torch.Tensor
            Coordinates of the atoms in voxel units. Shape ( batch, numberOfAtoms, 3 )
        boxsize : tuple of int
            The size of the box in which the macromolecules are represented
        radius : torch.Tensor
            Radius of the atoms in voxel units. Shape ( batch, numberOfAtoms )
        atNameHashing: torch.LongTensor
            channels of the atoms. Atoms of the same type should belong to the same channel. Shape ( batch, numberOfAtoms )
        padding_mask : torch.BoolTensor
            tensor to mask the padding. Shape (batch, numberOfAtoms)
        resolution : float
            side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
        Returns
        -------
        volume : torch.Tensor
            voxel representation of the macromolecules in the batch with Gaussian occupancy function. Shape ( batch, channels, x,y,z), where x,y,z are the size of the 3D volume in which the macromolecules have been represented

        """
        batch = coords_scaled.shape[0]
        distances, radius, batch_index, channel_index, voxel_index = self.__atom_cubes(coords_scaled, radius, atNameHashing, padding_mask, resolution)

        sigma = 0.93
        # the exponent is never positive, so no contribution needs to be masked out
        volume_cubes = torch.exp(-distances ** 2 / (sigma ** 2 * radius ** 2))
        del distances

        log_complement = torch.log(1 - volume_cubes).reshape(-1)
        summed = self.__accumulate(batch_index.reshape(-1), channel_index.reshape(-1), voxel_index.reshape(-1, 3), log_complement, boxsize, batch)

        if self.sparse:
            # exp is applied to the stored values only: the logarithms have already been summed by coalesce
            return torch.sparse_coo_tensor(summed.indices(), 1 - torch.exp(summed.values()), summed.shape)
        return 1 - torch.exp(summed)

    def __sparseClamp(self, volume, minv, maxv):
        """
        Private function that clamps the stored values of a coalesced sparse tensor in [minv, maxv].

        Parameters
        ----------
        volume : torch.Tensor
            coalesced sparse tensor
        minv : float or int
            lower bound
        maxv : float or int
            upper bound
        Returns
        -------
        volume : torch.Tensor
            coalesced sparse tensor with the same indices and shape and clamped values
        """
        clamped = volume.values().clamp(minv, maxv)
        return torch.sparse_coo_tensor(indices=volume.indices(), values=clamped, size=volume.shape).coalesce()

    def __forward_actual_calculation(self, coords_scaled, boxsize, radius, atNameHashing, padding_mask, steepness, resolution):
        """
        private function for the calculation of the sigmoid voxel occupancy

        Every atom contributes 1 / (1 + exp(steepness * (d - r))) to the voxels of its cube. Contributions to one
        voxel are summed and the result is capped at 1.

        Parameters
        ----------
        coords_scaled : torch.Tensor
            Coordinates of the atoms in voxel units. Shape ( batch, numberOfAtoms, 3 )
        boxsize : tuple of int
            The size of the box in which the macromolecules are represented
        radius : torch.Tensor
            Radius of the atoms in voxel units. Shape ( batch, numberOfAtoms )
        atNameHashing: torch.LongTensor
            channels of the atoms. Atoms of the same type should belong to the same channel. Shape ( batch, numberOfAtoms )
        padding_mask : torch.BoolTensor
            tensor to mask the padding. Shape (batch, numberOfAtoms)
        steepness : float
            steepness of the sigmoid function (coefficient of the exponent)
        resolution : float
            side in A of a voxel. The lower this value is the higher the resolution of the final representation will be

        Returns
        -------
        volume : torch.Tensor
            voxel representation of the macromolecules in the batch with Sigmoid occupancy function. Shape ( batch, channels, x,y,z), where x,y,z are the size of the 3D volume in which the macromolecules have been represented

        """
        batch = coords_scaled.shape[0]
        distances, radius, batch_index, channel_index, voxel_index = self.__atom_cubes(coords_scaled, radius, atNameHashing, padding_mask, resolution)

        exponent = steepness * (distances - radius)
        del distances
        # contributions with an exponent of 10 or more are negligible and are not stored
        keep = ~exponent.ge(10)
        volume_cubes = 1.0 / (1.0 + torch.exp(exponent[keep]))

        summed = self.__accumulate(batch_index[keep], channel_index[keep], voxel_index[keep], volume_cubes, boxsize, batch)

        if self.sparse:
            return self.__sparseClamp(summed, 0, 1)
        # caps the occupancy at 1
        return -torch.nn.functional.threshold(-summed, -1, -1)
407
+ '''
408
+ mesh will be added as soon as pytorch3d becomes a little more stable
409
+ def mesh(self,coords, radius,threshSurface = 0.01):
410
+
411
+ atNameHashing= torch.zeros(radius.shape).to(self.dev)
412
+ mask = radius.eq(PADDING_INDEX)
413
+ atNameHashing = atNameHashing.masked_fill_(mask,PADDING_INDEX)
414
+ vol = self(coords,radius,atNameHashing).to_dense()
415
+ mesh = cubifyNOALIGN(vol.sum(-1),thresh=threshSurface)# creates pytorch 3d mesh from cubes. It uses a MODIFIED version of pytorch3d with no align
416
+ return mesh
417
+ '''
418
+
419
class PointCloudSurface(torch.nn.Module):
    """Surface point cloud representation of a batch of macromolecules."""

    def __init__(self, device="cpu"):
        """
        Constructor for the PointCloudSurface class, which builds the main PyUUL object for cloud surface.

        Parameters
        ----------
        device : torch.device
            The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")


        Returns
        -------
        """
        super(PointCloudSurface, self).__init__()

        self.device = device

    def __buildStandardSphere(self, npoints=50):
        """
        Private function that places npoints points on the unit sphere with a Fibonacci lattice.

        Parameters
        ----------
        npoints : int
            number of points of the sphere
        Returns
        -------
        coords : torch.Tensor
            coordinates of the points. Shape ( npoints, 3 )
        """
        goldenRatio = (1 + 5 ** 0.5) / 2
        i = torch.arange(0, npoints, device=self.device)
        theta = 2 * math.pi * i / goldenRatio
        phi = torch.acos(1 - 2 * (i + 0.5) / npoints)

        sin_phi = torch.sin(phi)
        x = torch.cos(theta) * sin_phi
        y = torch.sin(theta) * sin_phi
        z = torch.cos(phi)

        return torch.stack([x, y, z], dim=-1)

    def forward(self, coords, radius, maxpoints=5000, external_radius_factor=1.4):
        """
        Function to calculate the surface cloud point representation of macromolecules

        Points are placed on a sphere around every atom. A point is kept when the gaussian occupancy generated
        on it by the other atoms closer than 5 A to its own atom is at most 0.5, and maxpoints of the kept
        points are then drawn at random (with replacement) for every macromolecule.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
        radius : torch.Tensor
            Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
        maxpoints : int
            number of points per macromolecule in the batch
        external_radius_factor=1.4
            multiplicative factor of the radius in order to define the place to sample the points around each atom. The higher this value is, the smoother the surface will be
        Returns
        -------
        surfacePointCloud : torch.Tensor
            surface point cloud representation of the macromolecules in the batch. The maxpoints points of each macromolecule are concatenated one after the other. Shape ( batch * maxpoints, 3 )

        """
        # PADDING_INDEX is expected to come from the star import of pyuul.sources.globalVariables
        padding_mask = ~radius.eq(PADDING_INDEX)

        batch = coords.shape[0]
        # the smallest molecule gets about 2*maxpoints candidate points; at least one point per atom is
        # always generated, otherwise there would be nothing to sample from when maxpoints is small
        smallest = int(padding_mask.sum(-1).min())
        npoints = max(1, (maxpoints // (smallest + 1)) * 2)

        sphere = self.__buildStandardSphere(npoints).unsqueeze(0).unsqueeze(1)
        external_radius = radius * external_radius_factor
        finalPoints = []

        for b in range(batch):
            real = padding_mask[b]
            atoms = coords[b][real]
            atom_radius = radius[b][real]
            probe_radius = external_radius[b][real]
            L = atoms.shape[0]

            # pairs of distinct atoms that are close enough to bury each other's points
            distmat = torch.cdist(atoms.unsqueeze(0), atoms.unsqueeze(0))[0]
            not_self = ~torch.eye(L, dtype=torch.bool, device=self.device)
            todoMask = distmat.le(5) & not_self

            # candidate points around every atom. Shape ( 1, L, npoints, 3 )
            points = atoms.unsqueeze(0).unsqueeze(-2) - sphere * probe_radius.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            # entry [i, j] of the expanded tensors pairs the points of atom j with the atom i that may bury them
            p = points.expand(L, L, npoints, 3)[todoMask]
            c = atoms.unsqueeze(1).unsqueeze(-2).expand(L, L, npoints, 3)[todoMask]
            r = atom_radius.unsqueeze(1).unsqueeze(-2).expand(L, L, npoints)[todoMask]
            occupancy = self.__occupancy(p, c, r)

            point_index = torch.arange(0, L * npoints, device=self.device).view(1, L, npoints).expand(L, L, npoints)[todoMask]
            point_occupancy = torch.zeros(L * npoints, dtype=torch.float, device=self.device)
            # sum of the logarithms = logarithm of the product of the (1 - occupancy) terms
            point_occupancy = point_occupancy.index_put_([point_index.reshape(-1)], occupancy.reshape(-1), accumulate=True)
            point_occupancy = 1 - torch.exp(point_occupancy)

            points_on_surfaceMask = point_occupancy.le(0.5)

            surface_points = points.reshape(-1, 3)[points_on_surfaceMask]
            random_indices = torch.randint(0, surface_points.shape[0], [maxpoints], device=self.device)
            finalPoints.append(surface_points[random_indices, :])

        return torch.cat(finalPoints, dim=0)

    def __occupancy(self, points, coords, radius):
        """
        Private function that calculates the logarithm of (1 - gaussian occupancy) generated by atoms on points.

        Parameters
        ----------
        points : torch.Tensor
            coordinates of the points. Shape ( ..., 3 )
        coords : torch.Tensor
            coordinates of the atom acting on each point. Same shape as points
        radius : torch.Tensor
            radius of the atom acting on each point. Shape ( ... )
        Returns
        -------
        log_complement : torch.Tensor
            log(1 - exp(-d^2 / (sigma^2 * r^2))) with sigma = 0.93, where d is the point-atom distance. Shape ( ... )
        """
        dist = torch.norm(points - coords, dim=-1)

        sigma = 0.93
        # the exponent is never positive, so the exponential cannot overflow
        occupancy_on_points = torch.exp(-dist ** 2 / (sigma ** 2 * radius ** 2))
        return torch.log(1 - occupancy_on_points)
533
+
534
class PointCloudVolume(torch.nn.Module):
    """Volumetric point cloud representation of a batch of macromolecules."""

    def __init__(self, device="cpu"):
        """
        Constructor for the PointCloudVolume class, which builds the main PyUUL object for volumetric point cloud.

        Parameters
        ----------
        device : torch.device
            The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")


        Returns
        -------
        """
        super(PointCloudVolume, self).__init__()

        self.device = device

    def forward(self, coords, radius, maxpoints=500):
        """
        Function to calculate the volumetric cloud point representation of macromolecules

        One point is drawn around every non-padding atom from a normal distribution centered on the atom, with
        standard deviation equal to the square root of its radius. maxpoints of these points are then drawn at
        random (with replacement) for every macromolecule.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
        radius : torch.Tensor
            Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
        maxpoints : int
            number of points per macromolecule in the batch

        Returns
        -------
        PointCloudVolume : torch.Tensor
            volume point cloud representation of the macromolecules in the batch. Shape ( batch, maxpoints, 3 )

        """
        # PADDING_INDEX is expected to come from the star import of pyuul.sources.globalVariables
        padding_mask = ~radius.eq(PADDING_INDEX)

        batched = []
        for i in range(coords.shape[0]):
            real = padding_mask[i]
            mean = coords[i][real]
            spread = radius[i][real].sqrt().unsqueeze(-1)

            sampled = spread * torch.randn(mean.size(), device=self.device) + mean
            p = sampled.view(-1, 3)
            # sampling with replacement gives every molecule exactly maxpoints points, whatever its size
            random_indices = torch.randint(0, p.shape[0], [maxpoints], device=self.device)
            batched.append(p[random_indices].unsqueeze(0))

        return torch.cat(batched, dim=0)
591
+