Upload VolumeMaker.py
Browse files- VolumeMaker.py +591 -0
VolumeMaker.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
#
|
4 |
+
# Copyright 2021 Gabriele Orlando
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import torch,math
|
19 |
+
from pyuul.sources.globalVariables import *
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import random
|
23 |
+
|
24 |
+
def setup_seed(seed):
|
25 |
+
torch.manual_seed(seed)
|
26 |
+
torch.cuda.manual_seed_all(seed)
|
27 |
+
np.random.seed(seed)
|
28 |
+
random.seed(seed)
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
setup_seed(100)
|
31 |
+
|
32 |
+
class Voxels(torch.nn.Module):
|
33 |
+
|
34 |
+
def __init__(self, device=torch.device("cpu"),sparse=True):
|
35 |
+
"""
|
36 |
+
Constructor for the Voxels class, which builds the main PyUUL object.
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
|
41 |
+
device : torch.device
|
42 |
+
The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
|
43 |
+
sparse : bool
|
44 |
+
Use sparse tensors calculation when possible
|
45 |
+
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
"""
|
49 |
+
super(Voxels, self).__init__()
|
50 |
+
|
51 |
+
self.sparse=sparse
|
52 |
+
self.boxsize = None
|
53 |
+
self.dev = device
|
54 |
+
|
55 |
+
def __transform_coordinates(self,coords,radius=None):
|
56 |
+
"""
|
57 |
+
Private function that transform the coordinates to fit them in the 3d box. It also takes care of the resolution.
|
58 |
+
|
59 |
+
Parameters
|
60 |
+
----------
|
61 |
+
coords : torch.Tensor
|
62 |
+
Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 )
|
63 |
+
radius : torch.Tensor or None
|
64 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms )
|
65 |
+
|
66 |
+
Returns
|
67 |
+
-------
|
68 |
+
coords : torch.Tensor
|
69 |
+
transformed coordinates
|
70 |
+
|
71 |
+
"""
|
72 |
+
coords = (coords*self.dilatation)- self.translation
|
73 |
+
if not radius is None:
|
74 |
+
radius = radius*self.dilatation
|
75 |
+
return coords,radius
|
76 |
+
else:
|
77 |
+
return coords
|
78 |
+
'''
|
79 |
+
def get_coords_voxel(self, voxel_indices, resolution):
|
80 |
+
"""
|
81 |
+
returns the coordinates of the center of the voxel provided its indices.
|
82 |
+
|
83 |
+
Parameters
|
84 |
+
----------
|
85 |
+
voxel_indices : torch.Tensor
|
86 |
+
Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 )
|
87 |
+
resolution : torch.Tensor or None
|
88 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms )
|
89 |
+
|
90 |
+
Returns
|
91 |
+
-------
|
92 |
+
"""
|
93 |
+
#voxel_indices is a n,3 long tensor
|
94 |
+
centersCoords = voxel_indices + 0.5*resolution
|
95 |
+
return (centersCoords + self.translation)/self.dilatation
|
96 |
+
'''
|
97 |
+
def __define_spatial_conformation(self,mincoords,cubes_around_atoms_dim,resolution):
|
98 |
+
"""
|
99 |
+
Private function that defines the space of the volume. Takes resolution and margins into consideration.
|
100 |
+
|
101 |
+
Parameters
|
102 |
+
----------
|
103 |
+
mincoords : torch.Tensor
|
104 |
+
minimum coordinates of each macromolecule of the batch. Shape ( batch, 3 )
|
105 |
+
cubes_around_atoms_dim : int
|
106 |
+
maximum distance in number of voxels to check for atom contribution to occupancy of a voxel
|
107 |
+
resolution : float
|
108 |
+
side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
|
109 |
+
Returns
|
110 |
+
-------
|
111 |
+
"""
|
112 |
+
self.translation=(mincoords-(cubes_around_atoms_dim)).unsqueeze(1)
|
113 |
+
self.dilatation = 1.0/resolution
|
114 |
+
|
115 |
+
'''
|
116 |
+
def find_cubes_indices(self,coords):
|
117 |
+
coords_scaled = self.transform_coordinates(coords)
|
118 |
+
return torch.trunc(coords_scaled.data).long()
|
119 |
+
'''
|
120 |
+
|
121 |
+
def forward( self,coords, radius,channels,numberchannels=None,resolution=1, cubes_around_atoms_dim=5, steepness=10,function="sigmoid"):
|
122 |
+
"""
|
123 |
+
Voxels representation of the macromolecules
|
124 |
+
|
125 |
+
Parameters
|
126 |
+
----------
|
127 |
+
coords : torch.Tensor
|
128 |
+
Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
|
129 |
+
radius : torch.Tensor
|
130 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
|
131 |
+
channels: torch.LongTensor
|
132 |
+
channels of the atoms. Atoms of the same type shold belong to the same channel. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToChannels
|
133 |
+
numberchannels : int or None
|
134 |
+
maximum number of channels. if None, max(atNameHashing) + 1 is used
|
135 |
+
|
136 |
+
cubes_around_atoms_dim : int
|
137 |
+
maximum distance in number of voxels for which the contribution to occupancy is taken into consideration. Every atom that is farer than cubes_around_atoms_dim voxels from the center of a voxel does no give any contribution to the relative voxel occupancy
|
138 |
+
resolution : float
|
139 |
+
side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
|
140 |
+
|
141 |
+
steepness : float or int
|
142 |
+
steepness of the sigmoid occupancy function.
|
143 |
+
|
144 |
+
function : "sigmoid" or "gaussian"
|
145 |
+
occupancy function to use. Can be sigmoid (every atom has a sigmoid shaped occupancy function) or gaussian (based on Li et al. 2014)
|
146 |
+
Returns
|
147 |
+
-------
|
148 |
+
volume : torch.Tensor
|
149 |
+
voxel representation of the macromolecules in the batch. Shape ( batch, channels, x,y,z), where x,y,z are the size of the 3D volume in which the macromolecules have been represented
|
150 |
+
|
151 |
+
"""
|
152 |
+
padding_mask = ~channels.eq(PADDING_INDEX)
|
153 |
+
if numberchannels is None:
|
154 |
+
numberchannels = int(channels[padding_mask].max().cpu().data+1)
|
155 |
+
self.featureVectorSize = numberchannels
|
156 |
+
self.function = function
|
157 |
+
|
158 |
+
arange_type = torch.int16
|
159 |
+
|
160 |
+
gx = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=arange_type)
|
161 |
+
gy = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=arange_type)
|
162 |
+
gz = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=arange_type)
|
163 |
+
self.lato = gx.shape[0]
|
164 |
+
|
165 |
+
x1 = gx.unsqueeze(1).expand(self.lato, self.lato).unsqueeze(-1)
|
166 |
+
x2 = gy.unsqueeze(0).expand(self.lato, self.lato).unsqueeze(-1)
|
167 |
+
|
168 |
+
xy = torch.cat([x1, x2], dim=-1).unsqueeze(2).expand(self.lato, self.lato, self.lato, 2)
|
169 |
+
x3 = gz.unsqueeze(0).unsqueeze(1).expand(self.lato, self.lato, self.lato).unsqueeze(-1)
|
170 |
+
|
171 |
+
del gx, gy, gz, x1, x2
|
172 |
+
|
173 |
+
self.standard_cube = torch.cat([xy, x3], dim=-1).unsqueeze(0).unsqueeze(0)
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
### definition of the box ###
|
178 |
+
# you take the maximum and min coord on each dimension (every prot in the batch shares the same box. In the future we can pack, but I think this is not the bottleneck)
|
179 |
+
# I scale by resolution
|
180 |
+
# I add the cubes in which I define the gradient. One in the beginning and one at the end --> 2*
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
mincoords = torch.min(coords[:, :, :], dim=1)[0]
|
185 |
+
mincoords = torch.trunc(mincoords / resolution)
|
186 |
+
|
187 |
+
|
188 |
+
box_size_x = (math.ceil(torch.max(coords[padding_mask][:,0])/resolution)-mincoords[:,0].min())+(2*cubes_around_atoms_dim+1)
|
189 |
+
box_size_y = (math.ceil(torch.max(coords[padding_mask][:,1])/resolution)-mincoords[:,1].min())+(2*cubes_around_atoms_dim+1)
|
190 |
+
box_size_z = (math.ceil(torch.max(coords[padding_mask][:,2])/resolution)-mincoords[:,2].min())+(2*cubes_around_atoms_dim+1)
|
191 |
+
#############################
|
192 |
+
|
193 |
+
self.__define_spatial_conformation(mincoords,cubes_around_atoms_dim,resolution) #define the spatial transforms to coordinates
|
194 |
+
coords,radius = self.__transform_coordinates(coords,radius)
|
195 |
+
|
196 |
+
boxsize = (int(box_size_x),int(box_size_y),int(box_size_z))
|
197 |
+
self.boxsize=boxsize
|
198 |
+
|
199 |
+
#selecting best types for indexing
|
200 |
+
if max(boxsize)<256: # i can use byte tensor
|
201 |
+
self.dtype_indices=torch.uint8
|
202 |
+
else:
|
203 |
+
self.dtype_indices = torch.int16
|
204 |
+
|
205 |
+
if self.function=="sigmoid":
|
206 |
+
volume = self.__forward_actual_calculation(coords, boxsize, radius, channels,padding_mask,steepness,resolution)
|
207 |
+
elif self.function=="gaussian":
|
208 |
+
volume = self.__forward_actual_calculationGaussian(coords, boxsize, radius, channels, padding_mask,resolution)
|
209 |
+
return volume
|
210 |
+
|
211 |
+
def __forward_actual_calculationGaussian(self, coords_scaled, boxsize, radius, atNameHashing, padding_mask,resolution):
|
212 |
+
"""
|
213 |
+
private function for the calculation of the gaussian voxel occupancy
|
214 |
+
|
215 |
+
Parameters
|
216 |
+
----------
|
217 |
+
coords_scaled : torch.LongTensor
|
218 |
+
Discrete Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 )
|
219 |
+
boxsize : torch.LongTensor
|
220 |
+
The size of the box in which the macromolecules are represented
|
221 |
+
radius : torch.Tensor
|
222 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
|
223 |
+
atNameHashing: torch.LongTensor
|
224 |
+
channels of the atoms. Atoms of the same type shold belong to the same channel. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToChannels
|
225 |
+
resolution : float
|
226 |
+
side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
|
227 |
+
padding_mask : torch.BoolTensor
|
228 |
+
tensor to mask the padding. Shape (batch, numberOfAtoms)
|
229 |
+
Returns
|
230 |
+
-------
|
231 |
+
volume : torch.Tensor
|
232 |
+
voxel representation of the macromolecules in the batch with Gaussian occupancy function. Shape ( batch, channels, x,y,z), where x,y,z are the size of the 3D volume in which the macromolecules have been represented
|
233 |
+
|
234 |
+
"""
|
235 |
+
batch = coords_scaled.shape[0]
|
236 |
+
dev = self.dev
|
237 |
+
L = coords_scaled.shape[1]
|
238 |
+
|
239 |
+
discrete_coordinates = torch.trunc(coords_scaled.data).to(self.dtype_indices)
|
240 |
+
|
241 |
+
#### making everything in the volume shape
|
242 |
+
|
243 |
+
# implicit_cube_formation
|
244 |
+
radius = radius.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
245 |
+
atNameHashing = atNameHashing.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
246 |
+
coords_scaled = coords_scaled.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
247 |
+
discrete_coordinates = discrete_coordinates.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
248 |
+
distmat_standard_cube = torch.norm(
|
249 |
+
coords_scaled - ((discrete_coordinates + self.standard_cube + 1) + 0.5 * resolution), dim=-1).to(
|
250 |
+
coords_scaled.dtype)
|
251 |
+
|
252 |
+
atNameHashing = atNameHashing.long()
|
253 |
+
#### old sigmoid stuff
|
254 |
+
'''
|
255 |
+
exponent = self.steepness*(distmat_standard_cube-radius)
|
256 |
+
|
257 |
+
exp_mask = exponent.ge(10)
|
258 |
+
exponent = torch.masked_fill(exponent,exp_mask, 10)
|
259 |
+
|
260 |
+
volume_cubes = 1.0/(1.0+torch.exp(exponent))
|
261 |
+
'''
|
262 |
+
### from doi: 10.1142/S0219633614400021 eq 1
|
263 |
+
sigma = 0.93
|
264 |
+
exponent = -distmat_standard_cube[padding_mask] ** 2 / (sigma ** 2 * radius[padding_mask] ** 2)
|
265 |
+
exp_mask = exponent.ge(10)
|
266 |
+
exponent = torch.masked_fill(exponent, exp_mask, 10)
|
267 |
+
volume_cubes = torch.exp(exponent)
|
268 |
+
|
269 |
+
#### index_put everything ###
|
270 |
+
batch_list = torch.arange(batch,device=dev).unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).expand(batch,L,self.lato,self.lato,self.lato)
|
271 |
+
|
272 |
+
cubes_coords = (discrete_coordinates[padding_mask] + self.standard_cube.squeeze(0) + 1)[~exp_mask]
|
273 |
+
atNameHashing = atNameHashing[padding_mask].expand(-1,self.lato,self.lato,self.lato)
|
274 |
+
if self.sparse:
|
275 |
+
|
276 |
+
index_tens = torch.cat(
|
277 |
+
[batch_list[padding_mask][~exp_mask].view(-1).unsqueeze(0),
|
278 |
+
atNameHashing[~exp_mask].unsqueeze(0),
|
279 |
+
cubes_coords[:,0].unsqueeze(0),
|
280 |
+
cubes_coords[:,1].unsqueeze(0),
|
281 |
+
cubes_coords[:,2].unsqueeze(0),
|
282 |
+
])
|
283 |
+
#index_tens = torch.cat(index)
|
284 |
+
|
285 |
+
volume_cubes = volume_cubes[~exp_mask].view(-1)
|
286 |
+
volume_cubes = torch.log(1 - volume_cubes.contiguous())
|
287 |
+
#powOrExpIsNotImplementedInSparse
|
288 |
+
volume = torch.sparse_coo_tensor(indices=index_tens, values=volume_cubes.exp(), size=[batch, self.featureVectorSize, boxsize[0] , boxsize[1] , boxsize[2] ]).coalesce()
|
289 |
+
volume = torch.sparse_coo_tensor(volume.indices(),1 - volume.values(), volume.shape)
|
290 |
+
|
291 |
+
else:
|
292 |
+
volume = torch.zeros(batch,boxsize[0]+1,boxsize[1]+1,boxsize[2]+1,self.featureVectorSize,device=dev,dtype=torch.float)
|
293 |
+
#index = (batch_list[padding_mask].view(-1),cubes_coords[padding_mask][:,:,:,:,0].view(-1), cubes_coords[padding_mask][:,:,:,:,1].view(-1), cubes_coords[padding_mask][:,:,:,:,2].view(-1), atNameHashing[padding_mask].view(-1) )
|
294 |
+
index = (batch_list[padding_mask][~exp_mask].view(-1).long(),
|
295 |
+
cubes_coords[:,0].long(),
|
296 |
+
cubes_coords[:,1].long(),
|
297 |
+
cubes_coords[:,2].long(),
|
298 |
+
atNameHashing[~exp_mask])
|
299 |
+
volume_cubes=volume_cubes[~exp_mask].view(-1)
|
300 |
+
|
301 |
+
volume_cubes = torch.log(1 - volume_cubes.contiguous())
|
302 |
+
volume = 1- torch.exp(volume.index_put(index,volume_cubes,accumulate=True))
|
303 |
+
#volume = 1 - torch.exp(volume.index_put(index, torch.log(1 - volume_cubes.contiguous().view(-1)), accumulate=True))
|
304 |
+
volume=volume.permute(0,4,1,2,3)
|
305 |
+
#volume = -torch.nn.functional.threshold(-volume,-1,-1)
|
306 |
+
|
307 |
+
return volume
|
308 |
+
|
309 |
+
|
310 |
+
|
311 |
+
return volume
|
312 |
+
|
313 |
+
def __sparseClamp(self,volume, minv, maxv):
|
314 |
+
vals = volume.values()
|
315 |
+
ind = volume.indices()
|
316 |
+
|
317 |
+
vals = vals.clamp(minv, maxv)
|
318 |
+
volume = torch.sparse_coo_tensor(indices=ind, values=vals, size=volume.shape).coalesce()
|
319 |
+
return volume
|
320 |
+
|
321 |
+
def __forward_actual_calculation(self, coords_scaled, boxsize, radius,atNameHashing,padding_mask,steepness,resolution):
|
322 |
+
"""
|
323 |
+
private function for the calculation of the gaussian voxel occupancy
|
324 |
+
|
325 |
+
Parameters
|
326 |
+
----------
|
327 |
+
coords_scaled : torch.LongTensor
|
328 |
+
Discrete Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 )
|
329 |
+
boxsize : torch.LongTensor
|
330 |
+
The size of the box in which the macromolecules are represented
|
331 |
+
radius : torch.Tensor
|
332 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
|
333 |
+
atNameHashing: torch.LongTensor
|
334 |
+
channels of the atoms. Atoms of the same type shold belong to the same channel. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToChannels
|
335 |
+
resolution : float
|
336 |
+
side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
|
337 |
+
padding_mask : torch.BoolTensor
|
338 |
+
tensor to mask the padding. Shape (batch, numberOfAtoms)
|
339 |
+
steepness : float
|
340 |
+
steepness of the sigmoid function (coefficient of the exponent)
|
341 |
+
|
342 |
+
Returns
|
343 |
+
-------
|
344 |
+
volume : torch.Tensor
|
345 |
+
voxel representation of the macromolecules in the batch with Sigmoid occupancy function. Shape ( batch, channels, x,y,z), where x,y,z are the size of the 3D volume in which the macromolecules have been represented
|
346 |
+
|
347 |
+
"""
|
348 |
+
batch = coords_scaled.shape[0]
|
349 |
+
dev=self.dev
|
350 |
+
L = coords_scaled.shape[1]
|
351 |
+
|
352 |
+
discrete_coordinates = torch.trunc(coords_scaled.data).to(self.dtype_indices)
|
353 |
+
|
354 |
+
#### making everything in the volume shape
|
355 |
+
|
356 |
+
#implicit_cube_formation
|
357 |
+
radius = radius.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
358 |
+
atNameHashing = atNameHashing.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
359 |
+
coords_scaled = coords_scaled.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
360 |
+
discrete_coordinates = discrete_coordinates.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
361 |
+
distmat_standard_cube = torch.norm(coords_scaled-((discrete_coordinates + self.standard_cube + 1) + 0.5 * resolution), dim=-1).to(coords_scaled.dtype)
|
362 |
+
|
363 |
+
atNameHashing = atNameHashing.long()
|
364 |
+
|
365 |
+
exponent = steepness*(distmat_standard_cube[padding_mask]-radius[padding_mask])
|
366 |
+
del distmat_standard_cube
|
367 |
+
exp_mask = exponent.ge(10)
|
368 |
+
exponent = torch.masked_fill(exponent,exp_mask, 10)
|
369 |
+
|
370 |
+
volume_cubes = 1.0/(1.0+torch.exp(exponent))
|
371 |
+
|
372 |
+
#### index_put everything ###
|
373 |
+
batch_list = torch.arange(batch,device=dev).unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).expand(batch,L,self.lato,self.lato,self.lato)
|
374 |
+
|
375 |
+
#cubes_coords = coords_scaled + self.standard_cube + 1
|
376 |
+
cubes_coords = (discrete_coordinates[padding_mask] + self.standard_cube.squeeze(0) + 1)[~exp_mask]
|
377 |
+
atNameHashing = atNameHashing[padding_mask].expand(-1,self.lato,self.lato,self.lato)
|
378 |
+
if self.sparse:
|
379 |
+
|
380 |
+
index_tens = torch.cat(
|
381 |
+
[batch_list[padding_mask][~exp_mask].view(-1).unsqueeze(0),
|
382 |
+
atNameHashing[~exp_mask].unsqueeze(0),
|
383 |
+
cubes_coords[:,0].unsqueeze(0),
|
384 |
+
cubes_coords[:,1].unsqueeze(0),
|
385 |
+
cubes_coords[:,2].unsqueeze(0),
|
386 |
+
])
|
387 |
+
#index_tens = torch.cat(index)
|
388 |
+
volume = torch.sparse_coo_tensor(indices=index_tens, values=volume_cubes[~exp_mask].view(-1), size=[batch, self.featureVectorSize, boxsize[0] , boxsize[1] , boxsize[2] ]).coalesce()
|
389 |
+
volume = self.__sparseClamp(volume,0,1)
|
390 |
+
|
391 |
+
else:
|
392 |
+
volume = torch.zeros(batch,boxsize[0]+1,boxsize[1]+1,boxsize[2]+1,self.featureVectorSize,device=dev,dtype=torch.float)
|
393 |
+
#index = (batch_list[padding_mask].view(-1),cubes_coords[padding_mask][:,:,:,:,0].view(-1), cubes_coords[padding_mask][:,:,:,:,1].view(-1), cubes_coords[padding_mask][:,:,:,:,2].view(-1), atNameHashing[padding_mask].view(-1) )
|
394 |
+
index = (batch_list[padding_mask][~exp_mask].view(-1).long(),
|
395 |
+
cubes_coords[:,0].long(),
|
396 |
+
cubes_coords[:,1].long(),
|
397 |
+
cubes_coords[:,2].long(),
|
398 |
+
atNameHashing[~exp_mask])
|
399 |
+
volume_cubes=volume_cubes[~exp_mask].view(-1)
|
400 |
+
|
401 |
+
volume = volume.index_put(index,volume_cubes.view(-1),accumulate=True)
|
402 |
+
|
403 |
+
volume = -torch.nn.functional.threshold(-volume,-1,-1)
|
404 |
+
volume = volume.permute(0,4,1,2,3)
|
405 |
+
|
406 |
+
return volume
|
407 |
+
'''
|
408 |
+
mesh will be added as soon as pytorch3d becomes a little more stable
|
409 |
+
def mesh(self,coords, radius,threshSurface = 0.01):
|
410 |
+
|
411 |
+
atNameHashing= torch.zeros(radius.shape).to(self.dev)
|
412 |
+
mask = radius.eq(PADDING_INDEX)
|
413 |
+
atNameHashing = atNameHashing.masked_fill_(mask,PADDING_INDEX)
|
414 |
+
vol = self(coords,radius,atNameHashing).to_dense()
|
415 |
+
mesh = cubifyNOALIGN(vol.sum(-1),thresh=threshSurface)# creates pytorch 3d mesh from cubes. It uses a MODIFIED version of pytorch3d with no align
|
416 |
+
return mesh
|
417 |
+
'''
|
418 |
+
|
419 |
+
class PointCloudSurface(torch.nn.Module):
|
420 |
+
def __init__(self,device="cpu"):
|
421 |
+
"""
|
422 |
+
Constructor for the CloudPointSurface class, which builds the main PyUUL object for cloud surface.
|
423 |
+
|
424 |
+
Parameters
|
425 |
+
----------
|
426 |
+
device : torch.device
|
427 |
+
The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
|
428 |
+
|
429 |
+
|
430 |
+
Returns
|
431 |
+
-------
|
432 |
+
"""
|
433 |
+
super(PointCloudSurface, self).__init__()
|
434 |
+
|
435 |
+
self.device=device
|
436 |
+
|
437 |
+
def __buildStandardSphere(self,npoints=50): # Fibonacci lattice
|
438 |
+
|
439 |
+
goldenRatio = (1 + 5 ** 0.5) / 2
|
440 |
+
i = torch.arange(0, npoints,device=self.device)
|
441 |
+
theta = 2 * math.pi * i / goldenRatio
|
442 |
+
phi = torch.acos(1 - 2 * (i + 0.5) / npoints)
|
443 |
+
|
444 |
+
x, y, z = torch.cos(theta) * torch.sin(phi), torch.sin(theta) * torch.sin(phi), torch.cos(phi)
|
445 |
+
|
446 |
+
coords=torch.cat([x.unsqueeze(-1),y.unsqueeze(-1),z.unsqueeze(-1)],dim=-1)
|
447 |
+
#plot_volume(False,20*coords.unsqueeze(0))
|
448 |
+
|
449 |
+
return coords
|
450 |
+
|
451 |
+
def forward(self, coords, radius, maxpoints=5000,external_radius_factor=1.4):
|
452 |
+
"""
|
453 |
+
Function to calculate the surface cloud point representation of macromolecules
|
454 |
+
|
455 |
+
Parameters
|
456 |
+
----------
|
457 |
+
coords : torch.Tensor
|
458 |
+
Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
|
459 |
+
radius : torch.Tensor
|
460 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
|
461 |
+
maxpoints : int
|
462 |
+
number of points per macromolecule in the batch
|
463 |
+
external_radius_factor=1.4
|
464 |
+
multiplicative factor of the radius in order ot define the place to sample the points around each atom. The higher this value is, the smoother the surface will be
|
465 |
+
Returns
|
466 |
+
-------
|
467 |
+
surfacePointCloud : torch.Tensor
|
468 |
+
surface point cloud representation of the macromolecules in the batch. Shape ( batch, channels, numberOfAtoms, 3)
|
469 |
+
|
470 |
+
"""
|
471 |
+
padding_mask = ~radius.eq(PADDING_INDEX)
|
472 |
+
|
473 |
+
batch = coords.shape[0]
|
474 |
+
npoints = torch.div(maxpoints,(padding_mask.sum(-1).min() + 1), rounding_mode="floor") * 2 # we ensure that the smallest protein has at least maxpoints points
|
475 |
+
|
476 |
+
sphere = self.__buildStandardSphere(npoints)
|
477 |
+
finalPoints=[]
|
478 |
+
|
479 |
+
for b in range(batch):
|
480 |
+
|
481 |
+
distmat = torch.cdist(coords[b][padding_mask[b]].unsqueeze(0), coords[b][padding_mask[b]].unsqueeze(0))
|
482 |
+
L=distmat.shape[1]
|
483 |
+
AtomSelfContributionMask = torch.eye(L, dtype=torch.bool, device=self.device).unsqueeze(0)
|
484 |
+
triangular_mask = ~torch.tril(torch.ones((L, L), dtype=torch.bool, device=self.device), diagonal=-1).unsqueeze(0)
|
485 |
+
|
486 |
+
#todoMask = (distmat[b].le(5) & (~AtomSelfContributionMask) & triangular_mask).squeeze(0)
|
487 |
+
external_radius = radius * external_radius_factor
|
488 |
+
todoMask = (distmat[0].le(5) & (~AtomSelfContributionMask)).squeeze(0)
|
489 |
+
points = coords[b][padding_mask[b]].unsqueeze(0).unsqueeze(-2) - sphere.unsqueeze(0).unsqueeze(1) * external_radius[b][padding_mask[b]].unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
490 |
+
|
491 |
+
p = points.expand( L, L, npoints, 3)[todoMask]
|
492 |
+
c = coords[b][padding_mask[b]].unsqueeze(1).unsqueeze(-2).expand( L, L, points.shape[2], 3)[todoMask]
|
493 |
+
r = radius[b][padding_mask[b]].unsqueeze(1).unsqueeze(-2).expand( L, L, points.shape[2])[todoMask]
|
494 |
+
occupancy = self.__occupancy(p, c, r)
|
495 |
+
|
496 |
+
point_index = torch.arange(0,L*npoints,device=self.device).view(L,npoints).unsqueeze(0).expand(L,L,npoints)[todoMask]
|
497 |
+
point_occupancy =torch.zeros((L*npoints),dtype=torch.float,device=self.device)
|
498 |
+
point_occupancy = point_occupancy.index_put_([point_index.view(-1)], occupancy.view(-1), accumulate=True)
|
499 |
+
point_occupancy = (1- torch.exp(point_occupancy))
|
500 |
+
|
501 |
+
points_on_surfaceMask = point_occupancy.le(0.5)
|
502 |
+
|
503 |
+
points=points.permute(0,3,1,2).view(3,-1).transpose(0,1)[points_on_surfaceMask]
|
504 |
+
random_indices = torch.randint(0, points.shape[0], [maxpoints], device=self.device)
|
505 |
+
sampled_points = points[random_indices,:]
|
506 |
+
|
507 |
+
finalPoints +=[sampled_points]
|
508 |
+
|
509 |
+
return torch.cat(finalPoints,dim=0)
|
510 |
+
|
511 |
+
def __occupancy(self, points, coords, radius):
|
512 |
+
|
513 |
+
dist = torch.norm(points-coords,dim=-1)
|
514 |
+
|
515 |
+
|
516 |
+
sigma=0.93
|
517 |
+
exponent = -dist**2/(sigma**2 * radius**2)
|
518 |
+
exp_mask = exponent.ge(10)
|
519 |
+
exponent = torch.masked_fill(exponent, exp_mask, 10)
|
520 |
+
|
521 |
+
occupancy_on_points = torch.exp(exponent)
|
522 |
+
return torch.log(1-occupancy_on_points)
|
523 |
+
return occupancy_on_points
|
524 |
+
del exponent
|
525 |
+
|
526 |
+
AtomSelfContributionMask = torch.eye(L,dtype=torch.bool,device=self.device).unsqueeze(0).expand(batch,L,L)
|
527 |
+
occupancy_on_points[AtomSelfContributionMask]=0.0
|
528 |
+
|
529 |
+
occupancy = (1-torch.exp(torch.log(1-occupancy_on_points).sum(2)))#.sum(dim=-1)/npoints
|
530 |
+
#if log_correction:
|
531 |
+
# occupancy = -torch.log(occupancy + 1) # log scaling
|
532 |
+
return occupancy
|
533 |
+
|
534 |
+
class PointCloudVolume(torch.nn.Module):
|
535 |
+
def __init__(self, device="cpu"):
|
536 |
+
"""
|
537 |
+
Constructor for the CloudPointSurface class, which builds the main PyUUL object for volumetric point cloud.
|
538 |
+
|
539 |
+
Parameters
|
540 |
+
----------
|
541 |
+
device : torch.device
|
542 |
+
The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
|
543 |
+
|
544 |
+
|
545 |
+
Returns
|
546 |
+
-------
|
547 |
+
"""
|
548 |
+
super(PointCloudVolume, self).__init__()
|
549 |
+
|
550 |
+
self.device = device
|
551 |
+
|
552 |
+
def forward(self, coords, radius, maxpoints=500):
|
553 |
+
|
554 |
+
"""
|
555 |
+
Function to calculate the volumetric cloud point representation of macromolecules
|
556 |
+
|
557 |
+
Parameters
|
558 |
+
----------
|
559 |
+
coords : torch.Tensor
|
560 |
+
Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
|
561 |
+
radius : torch.Tensor
|
562 |
+
Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
|
563 |
+
maxpoints : int
|
564 |
+
number of points per macromolecule in the batch
|
565 |
+
|
566 |
+
Returns
|
567 |
+
-------
|
568 |
+
PointCloudVolume : torch.Tensor
|
569 |
+
volume point cloud representation of the macromolecules in the batch. Shape ( batch, channels, numberOfAtoms, 3)
|
570 |
+
|
571 |
+
"""
|
572 |
+
|
573 |
+
padding_mask = ~radius.eq(PADDING_INDEX)
|
574 |
+
|
575 |
+
#npoints = torch.div(maxpoints, padding_mask.sum(-1).min()) + 1 # we ensure that the smallest protein has at least 5000 points
|
576 |
+
|
577 |
+
batch = coords.shape[0]
|
578 |
+
L = coords.shape[1]
|
579 |
+
|
580 |
+
batched = []
|
581 |
+
for i in range(batch):
|
582 |
+
mean = coords[i][padding_mask[i]]
|
583 |
+
|
584 |
+
sampled = radius[i][padding_mask[i]].sqrt().unsqueeze(-1) * torch.randn((mean.size()), device=self.device) + mean
|
585 |
+
p = sampled.view(-1,3)
|
586 |
+
random_indices = torch.randint(0, p.shape[0], [maxpoints], device=self.device)
|
587 |
+
batched+=[p[random_indices].unsqueeze(0)]
|
588 |
+
|
589 |
+
batched = torch.cat(batched,dim=0)
|
590 |
+
return batched
|
591 |
+
|