Spaces:
Running
Running
add codes
Browse files- models.py +392 -0
- new_dataloader.py +349 -0
- requirements.txt +8 -0
- trainer.py +892 -0
- training_data.py +50 -0
- utils.py +462 -0
models.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from layers import TransformerEncoder, TransformerDecoder
|
5 |
+
|
6 |
+
class Generator(nn.Module):
    """Transformer-based graph generator.

    Embeds node noise ``z_n`` and edge noise ``z_e`` with small MLPs, runs them
    through a TransformerEncoder, and reads out per-atom / per-bond type
    distributions for a molecular graph.
    """

    def __init__(self, z_dim, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio, submodel):
        """
        Args:
            z_dim: latent noise dimension.
            act: activation name — one of "relu", "leaky", "sigmoid", "tanh".
            vertexes: maximum number of atoms per graph.
            edges: number of bond (edge) types.
            nodes: number of atom (node) types.
            dropout: dropout rate for the embedding MLPs and encoder.
            dim: transformer hidden width.
            depth: number of encoder layers.
            heads: number of attention heads.
            mlp_ratio: encoder MLP expansion ratio.
            submodel: submodel tag (stored, not used here).

        Raises:
            ValueError: if ``act`` is not a recognized activation name.
        """
        super().__init__()

        self.submodel = submodel
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.z_dim = z_dim

        # Resolve the activation by name.  Fail fast on an unknown key instead
        # of silently keeping the string, which previously surfaced later as a
        # confusing TypeError inside nn.Sequential.
        activations = {"relu": nn.ReLU(), "leaky": nn.LeakyReLU(),
                       "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh()}
        if act not in activations:
            raise ValueError(f"Unknown activation '{act}'; expected one of {sorted(activations)}")
        act = activations[act]

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
        self.pos_enc_dim = 5
        #self.pos_enc = nn.Linear(self.pos_enc_dim, self.dim)

        self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
        self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))

        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act=act,
                                                     mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) causal mask: 0.0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def laplacian_positional_enc(self, adj):
        """Eigenvector-based positional encoding of an adjacency matrix.

        Currently unused — the call sites in ``forward`` are commented out.

        NOTE(review): this computes ``L = I - D * A * D`` with D the raw degree
        diagonal; the symmetric normalized Laplacian is
        ``I - D^{-1/2} A D^{-1/2}``.  Confirm the intended formula before
        re-enabling this path.
        """
        A = adj
        D = torch.diag(torch.count_nonzero(A, dim=-1))
        L = torch.eye(A.shape[0], device=A.device) - D * A * D

        EigVal, EigVec = torch.linalg.eig(L)

        # Sort eigenvectors by (real part of) eigenvalue; skip the first
        # (trivial) eigenvector and keep the next pos_enc_dim of them.
        idx = torch.argsort(torch.real(EigVal))
        EigVal, EigVec = EigVal[idx], torch.real(EigVec[:, idx])
        pos_enc = EigVec[:, 1:self.pos_enc_dim + 1]

        return pos_enc

    def forward(self, z_e, z_n):
        """Generate a graph from noise.

        Args:
            z_e: edge noise, shape (batch, n, n, edge_feats).
            z_n: node noise, shape (batch, n, node_feats).

        Returns:
            Tuple ``(node, edge, node_sample, edge_sample)`` — encoder features
            and their softmax read-outs over atom / bond types.
        """
        b, n, c = z_n.shape
        _, _, _, d = z_e.shape

        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)

        # Symmetrize edge features so the generated adjacency is undirected.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node, edge)

        node_sample = self.softmax(self.readout_n(node))
        edge_sample = self.softmax(self.readout_e(edge))

        return node, edge, node_sample, edge_sample
97 |
+
|
98 |
+
|
99 |
+
class Generator2(nn.Module):
    """Second-stage generator.

    Projects molecule node/edge logits and target (drug or protein) features
    into a shared decoder width, cross-decodes them with a TransformerDecoder,
    and reads out softmax distributions over drug atom / bond types.
    """

    def __init__(self, dim, dec_dim, depth, heads, mlp_ratio, drop_rate, drugs_m_dim, drugs_b_dim, submodel):
        super().__init__()
        self.submodel = submodel
        self.depth = depth
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.heads = heads
        self.dropout_rate = drop_rate
        self.drugs_m_dim = drugs_m_dim
        self.drugs_b_dim = drugs_b_dim
        self.pos_enc_dim = 5

        if self.submodel == "Prot":
            # Exact dimensions of the protein feature tensors: 3822 -> 45 node
            # slots, 298116 -> 45*45 edge slots, then lift the scalar channel
            # to the decoder width.
            self.prot_n = torch.nn.Linear(3822, 45)
            self.prot_e = torch.nn.Linear(298116, 2025)
            self.protn_dim = torch.nn.Linear(1, dec_dim)
            self.prote_dim = torch.nn.Linear(1, dec_dim)

        self.mol_nodes = nn.Linear(dim, dec_dim)
        self.mol_edges = nn.Linear(dim, dec_dim)
        self.drug_nodes = nn.Linear(self.drugs_m_dim, dec_dim)
        self.drug_edges = nn.Linear(self.drugs_b_dim, dec_dim)

        self.TransformerDecoder = TransformerDecoder(dec_dim, depth, heads, mlp_ratio, drop_rate=self.dropout_rate)

        self.nodes_output_layer = nn.Linear(dec_dim, self.drugs_m_dim)
        self.edges_output_layer = nn.Linear(dec_dim, self.drugs_b_dim)
        self.softmax = nn.Softmax(dim=-1)

    def laplacian_positional_enc(self, adj):
        """Eigenvector positional encoding (currently unused; callers are commented out)."""
        degree = torch.diag(torch.count_nonzero(adj, dim=-1))
        lap = torch.eye(adj.shape[0], device=adj.device) - degree * adj * degree

        eigval, eigvec = torch.linalg.eig(lap)
        order = torch.argsort(torch.real(eigval))
        eigvec = torch.real(eigvec[:, order])
        # Drop the first (trivial) eigenvector, keep the next pos_enc_dim.
        return eigvec[:, 1:self.pos_enc_dim + 1]

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) causal mask: 0.0 on/below the diagonal, -inf above."""
        lower = torch.tril(torch.ones(sz, sz)) == 1
        return lower.float().masked_fill(~lower, float('-inf')).masked_fill(lower, float(0.0))

    def forward(self, edges_logits, nodes_logits, akt1_adj, akt1_annot):
        """Decode molecule logits conditioned on the target's features.

        Returns:
            Tuple ``(edges_logits, nodes_logits)`` of softmax distributions
            over drug bond / atom types.
        """
        nodes_logits = self.mol_nodes(nodes_logits)
        edges_logits = self.mol_edges(edges_logits)

        if self.submodel == "Prot":
            # Reshape the protein projections to a single graph of 45 nodes.
            akt1_adj = self.prote_dim(self.prot_e(akt1_adj).view(1, 45, 45, 1))
            akt1_annot = self.protn_dim(self.prot_n(akt1_annot).view(1, 45, 1))
        else:
            akt1_annot = self.drug_nodes(akt1_annot)
            akt1_adj = self.drug_edges(akt1_adj)

        nodes_logits, akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(
            nodes_logits, akt1_annot, edges_logits, akt1_adj)

        nodes_logits = self.softmax(self.nodes_output_layer(nodes_logits))
        edges_logits = self.softmax(self.edges_output_layer(edges_logits))

        return edges_logits, nodes_logits
|
182 |
+
|
183 |
+
class simple_disc(nn.Module):
    """Simple MLP discriminator over flattened graph features.

    Scores a flattened (node one-hot + adjacency one-hot) vector with a stack
    of Linear layers and returns one unbounded logit per input.
    """

    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Args:
            act: activation name — one of "relu", "leaky", "sigmoid", "tanh".
            m_dim: number of atom (node) types.
            vertexes: maximum number of atoms per graph.
            b_dim: number of bond (edge) types.

        Raises:
            ValueError: if ``act`` is not a recognized activation name.
        """
        super().__init__()

        # Resolve the activation by name.  Fail fast on an unknown key instead
        # of silently keeping the string, which previously surfaced later as a
        # confusing TypeError inside nn.Sequential.
        activations = {"relu": nn.ReLU(), "leaky": nn.LeakyReLU(),
                       "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh()}
        if act not in activations:
            raise ValueError(f"Unknown activation '{act}'; expected one of {sorted(activations)}")
        act = activations[act]

        # Flattened input size: per-vertex atom one-hots plus per-pair bond one-hots.
        features = vertexes * m_dim + vertexes * vertexes * b_dim

        self.predictor = nn.Sequential(nn.Linear(features, 256), act, nn.Linear(256, 128), act,
                                       nn.Linear(128, 64), act, nn.Linear(64, 32), act,
                                       nn.Linear(32, 16), act, nn.Linear(16, 1))

    def forward(self, x):
        """Return the raw discriminator logit for ``x`` of shape (..., features)."""
        prediction = self.predictor(x)
        return prediction
|
208 |
+
"""class Discriminator(nn.Module):
|
209 |
+
|
210 |
+
def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
|
211 |
+
super(Discriminator, self).__init__()
|
212 |
+
self.degree = deg
|
213 |
+
self.aggregators = agg
|
214 |
+
self.scalers = sca
|
215 |
+
self.pna_in_channels = pna_in_ch
|
216 |
+
self.pna_out_channels = pna_out_ch
|
217 |
+
self.edge_dimension = edge_dim
|
218 |
+
self.towers = towers
|
219 |
+
self.pre_layers_num = pre_lay
|
220 |
+
self.post_layers_num = post_lay
|
221 |
+
self.pna_layer_num = pna_layer_num
|
222 |
+
self.graph_add = graph_add
|
223 |
+
self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
|
224 |
+
pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
|
225 |
+
towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
|
226 |
+
pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
|
227 |
+
|
228 |
+
def forward(self, x, edge_index, edge_attr, batch, activation=None):
|
229 |
+
|
230 |
+
h = self.PNA_layer(x, edge_index, edge_attr, batch)
|
231 |
+
|
232 |
+
h = activation(h) if activation is not None else h
|
233 |
+
|
234 |
+
return h"""
|
235 |
+
|
236 |
+
"""class Discriminator2(nn.Module):
|
237 |
+
|
238 |
+
def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
|
239 |
+
super(Discriminator2, self).__init__()
|
240 |
+
self.degree = deg
|
241 |
+
self.aggregators = agg
|
242 |
+
self.scalers = sca
|
243 |
+
self.pna_in_channels = pna_in_ch
|
244 |
+
self.pna_out_channels = pna_out_ch
|
245 |
+
self.edge_dimension = edge_dim
|
246 |
+
self.towers = towers
|
247 |
+
self.pre_layers_num = pre_lay
|
248 |
+
self.post_layers_num = post_lay
|
249 |
+
self.pna_layer_num = pna_layer_num
|
250 |
+
self.graph_add = graph_add
|
251 |
+
self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
|
252 |
+
pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
|
253 |
+
towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
|
254 |
+
pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
|
255 |
+
|
256 |
+
def forward(self, x, edge_index, edge_attr, batch, activation=None):
|
257 |
+
|
258 |
+
h = self.PNA_layer(x, edge_index, edge_attr, batch)
|
259 |
+
|
260 |
+
h = activation(h) if activation is not None else h
|
261 |
+
|
262 |
+
return h"""
|
263 |
+
|
264 |
+
|
265 |
+
"""class Discriminator_old(nn.Module):
|
266 |
+
|
267 |
+
def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
|
268 |
+
super(Discriminator_old, self).__init__()
|
269 |
+
|
270 |
+
graph_conv_dim, aux_dim, linear_dim = conv_dim
|
271 |
+
|
272 |
+
# discriminator
|
273 |
+
self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout,gcn_depth)
|
274 |
+
self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
|
275 |
+
|
276 |
+
# multi dense layer
|
277 |
+
layers = []
|
278 |
+
for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
|
279 |
+
layers.append(nn.Linear(c0,c1))
|
280 |
+
layers.append(nn.Dropout(dropout))
|
281 |
+
self.linear_layer = nn.Sequential(*layers)
|
282 |
+
|
283 |
+
self.output_layer = nn.Linear(linear_dim[-1], 1)
|
284 |
+
|
285 |
+
def forward(self, adj, hidden, node, activation=None):
|
286 |
+
|
287 |
+
adj = adj[:,:,:,1:].permute(0,3,1,2)
|
288 |
+
|
289 |
+
annotations = torch.cat((hidden, node), -1) if hidden is not None else node
|
290 |
+
|
291 |
+
h = self.gcn_layer(annotations, adj)
|
292 |
+
annotations = torch.cat((h, hidden, node) if hidden is not None\
|
293 |
+
else (h, node), -1)
|
294 |
+
|
295 |
+
h = self.agg_layer(annotations, torch.tanh)
|
296 |
+
h = self.linear_layer(h)
|
297 |
+
|
298 |
+
# Need to implement batch discriminator #
|
299 |
+
#########################################
|
300 |
+
|
301 |
+
output = self.output_layer(h)
|
302 |
+
output = activation(output) if activation is not None else output
|
303 |
+
|
304 |
+
return output, h"""
|
305 |
+
|
306 |
+
"""class Discriminator_old2(nn.Module):
|
307 |
+
|
308 |
+
def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
|
309 |
+
super(Discriminator_old2, self).__init__()
|
310 |
+
|
311 |
+
graph_conv_dim, aux_dim, linear_dim = conv_dim
|
312 |
+
|
313 |
+
# discriminator
|
314 |
+
self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout, gcn_depth)
|
315 |
+
self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
|
316 |
+
|
317 |
+
# multi dense layer
|
318 |
+
layers = []
|
319 |
+
for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
|
320 |
+
layers.append(nn.Linear(c0,c1))
|
321 |
+
layers.append(nn.Dropout(dropout))
|
322 |
+
self.linear_layer = nn.Sequential(*layers)
|
323 |
+
|
324 |
+
self.output_layer = nn.Linear(linear_dim[-1], 1)
|
325 |
+
|
326 |
+
def forward(self, adj, hidden, node, activation=None):
|
327 |
+
|
328 |
+
adj = adj[:,:,:,1:].permute(0,3,1,2)
|
329 |
+
|
330 |
+
annotations = torch.cat((hidden, node), -1) if hidden is not None else node
|
331 |
+
|
332 |
+
h = self.gcn_layer(annotations, adj)
|
333 |
+
annotations = torch.cat((h, hidden, node) if hidden is not None\
|
334 |
+
else (h, node), -1)
|
335 |
+
|
336 |
+
h = self.agg_layer(annotations, torch.tanh)
|
337 |
+
h = self.linear_layer(h)
|
338 |
+
|
339 |
+
# Need to implement batch discriminator #
|
340 |
+
#########################################
|
341 |
+
|
342 |
+
output = self.output_layer(h)
|
343 |
+
output = activation(output) if activation is not None else output
|
344 |
+
|
345 |
+
return output, h"""
|
346 |
+
|
347 |
+
"""class Discriminator3(nn.Module):
|
348 |
+
|
349 |
+
def __init__(self,in_ch):
|
350 |
+
super(Discriminator3, self).__init__()
|
351 |
+
self.dim = in_ch
|
352 |
+
|
353 |
+
|
354 |
+
self.TraConv_layer = TransformerConv(in_channels = self.dim,out_channels = self.dim//4,edge_dim = self.dim)
|
355 |
+
self.mlp = torch.nn.Sequential(torch.nn.Tanh(), torch.nn.Linear(self.dim//4,1))
|
356 |
+
def forward(self, x, edge_index, edge_attr, batch, activation=None):
|
357 |
+
|
358 |
+
h = self.TraConv_layer(x, edge_index, edge_attr)
|
359 |
+
h = global_add_pool(h,batch)
|
360 |
+
h = self.mlp(h)
|
361 |
+
h = activation(h) if activation is not None else h
|
362 |
+
|
363 |
+
return h"""
|
364 |
+
|
365 |
+
|
366 |
+
"""class PNA_Net(nn.Module):
|
367 |
+
def __init__(self,deg):
|
368 |
+
super().__init__()
|
369 |
+
|
370 |
+
|
371 |
+
|
372 |
+
self.convs = nn.ModuleList()
|
373 |
+
|
374 |
+
self.lin = nn.Linear(5, 128)
|
375 |
+
for _ in range(1):
|
376 |
+
conv = DenseGCNConv(128, 128, improved=False, bias=True)
|
377 |
+
self.convs.append(conv)
|
378 |
+
|
379 |
+
self.agg_layer = GraphAggregation(128, 128, 0, dropout=0.1)
|
380 |
+
self.mlp = nn.Sequential(nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 32), nn.Tanh(),
|
381 |
+
nn.Linear(32, 1))
|
382 |
+
|
383 |
+
def forward(self, x, adj,mask=None):
|
384 |
+
x = self.lin(x)
|
385 |
+
|
386 |
+
for conv in self.convs:
|
387 |
+
x = F.relu(conv(x, adj,mask=None))
|
388 |
+
|
389 |
+
x = self.agg_layer(x,torch.tanh)
|
390 |
+
|
391 |
+
return self.mlp(x) """
|
392 |
+
|
new_dataloader.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
import os.path as osp
import re

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit import RDLogger
from torch_geometric.data import (Data, InMemoryDataset)
|
11 |
+
|
12 |
+
RDLogger.DisableLog('rdApp.*')
|
13 |
+
class DruggenDataset(InMemoryDataset):
    """PyTorch-Geometric dataset of molecules built from a SMILES file.

    Each molecule becomes a ``Data`` object with one-hot node features (atom
    types, optionally extended with per-atom descriptors) and a bond-typed
    edge list.  Atom/bond encoder and decoder dictionaries are pickled under
    ``DrugGEN/data/{encoders,decoders}/``.
    """

    def __init__(self, root, dataset_file, raw_files, max_atom, features, transform=None, pre_transform=None, pre_filter=None):
        """
        Args:
            root: dataset root directory; the processed file is loaded from here.
            dataset_file: processed file name (e.g. "chembl45_train.pt").
            raw_files: path to the raw SMILES file (one SMILES per line).
            max_atom: upper bound on atoms per molecule kept in the dataset.
            features: if truthy, append extra per-atom descriptors to node features.
        """
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(osp.join(root, dataset_file))

    @property
    def raw_file_names(self):
        """Raw input file name(s) for the PyG dataset machinery."""
        return self.raw_files

    @property
    def processed_file_names(self):
        '''
        Return the processed file names. If these names are not present, they will be automatically processed using process function of this class.
        '''
        return self.dataset_file

    def _generate_encoders_decoders(self, data):
        """Build atom/bond label<->index dictionaries from the SMILES list and pickle them."""
        self.data = data
        print('Creating atoms encoder and decoder..')

        atom_labels = set()
        self.max_atom_size_in_data = 0

        for smile in data:
            mol = Chem.MolFromSmiles(smile)
            if mol is None:
                # Skip unparsable SMILES instead of crashing on mol.GetAtoms().
                continue
            atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
            self.max_atom_size_in_data = max(self.max_atom_size_in_data, mol.GetNumAtoms())

        atom_labels.add(0)  # PAD symbol (atomic number 0) for unknown/padding atoms
        atom_labels = sorted(atom_labels)

        self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
        self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
        self.atom_num_types = len(atom_labels)
        print(f'Created atoms encoder and decoder with {self.atom_num_types - 1} atom types and 1 PAD symbol!')
        print("atom_labels", atom_labels)

        print('Creating bonds encoder and decoder..')
        # Fixed, ordered bond vocabulary (index 0 = "no bond" / PAD).
        bond_labels = [
            Chem.rdchem.BondType.ZERO,
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
        ]
        print("bond labels", bond_labels)
        self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
        self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
        self.bond_num_types = len(bond_labels)
        print(f'Created bonds encoder and decoder with {self.bond_num_types - 1} bond types and 1 PAD symbol!')

        # Persist the dictionaries so decoding works across runs/processes.
        with open("DrugGEN/data/encoders/" + "atom_" + self.dataset_name + ".pkl", "wb") as atom_encoders:
            pickle.dump(self.atom_encoder_m, atom_encoders)
        with open("DrugGEN/data/decoders/" + "atom_" + self.dataset_name + ".pkl", "wb") as atom_decoders:
            pickle.dump(self.atom_decoder_m, atom_decoders)
        with open("DrugGEN/data/encoders/" + "bond_" + self.dataset_name + ".pkl", "wb") as bond_encoders:
            pickle.dump(self.bond_encoder_m, bond_encoders)
        with open("DrugGEN/data/decoders/" + "bond_" + self.dataset_name + ".pkl", "wb") as bond_decoders:
            pickle.dump(self.bond_decoder_m, bond_decoders)

    def generate_adjacency_matrix(self, mol, connected=True, max_length=None):
        """
        Generates the adjacency matrix for a molecule.

        Args:
            mol (Molecule): The molecule object.
            connected (bool): Whether to check for connectivity in the molecule. Defaults to True.
            max_length (int): The maximum length of the adjacency matrix. Defaults to the number of atoms in the molecule.

        Returns:
            numpy.ndarray or None: The adjacency matrix if connected and all atoms have a degree greater than 0,
            otherwise None.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        A = np.zeros(shape=(max_length, max_length))

        begin, end = [b.GetBeginAtomIdx() for b in mol.GetBonds()], [b.GetEndAtomIdx() for b in mol.GetBonds()]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]

        # Symmetric assignment: bond type index in both (i, j) and (j, i).
        A[begin, end] = bond_type
        A[end, begin] = bond_type

        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)

        return A if connected and (degree > 0).all() else None

    def generate_node_features(self, mol, max_length=None):
        """
        Generates the node features for a molecule.

        Args:
            mol (Molecule): The molecule object.
            max_length (int): The maximum length of the node features. Defaults to the number of atoms in the molecule.

        Returns:
            numpy.ndarray: The node features matrix (atom-type indices, 0-padded).
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
            max_length - mol.GetNumAtoms()))

    def generate_additional_features(self, mol, max_length=None):
        """
        Generates additional per-atom descriptor features for a molecule.

        Args:
            mol (Molecule): The molecule object.
            max_length (int): The maximum length of the additional features. Defaults to the number of atoms in the molecule.

        Returns:
            numpy.ndarray: The additional features matrix (zero-padded to max_length rows).
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]] for a in mol.GetAtoms()], dtype=np.int32)

        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name):
        """Load this dataset's pickled decoder dict ("atom" or "bond")."""
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def drugs_decoder_load(self, dictionary_name):
        """Load a shared (dataset-independent) pickled decoder dict."""
        with open("DrugGEN/data/decoders/" + dictionary_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol(self, node_labels, edge_labels, strict=True):
        """Reconstruct an RDKit Mol from node/edge label matrices.

        Returns None when ``strict`` is True and sanitization fails.
        """
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        atom_decoders = self.decoder_load("atom")
        bond_decoders = self.decoder_load("bond")

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        # edge_labels is symmetric; add each bond only once (lower triangle).
        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
        mol = self.correct_mol(mol)
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Narrowed from a bare `except:`; a failed sanitization yields None.
                mol = None

        return mol

    def drug_decoder_load(self, dictionary_name):
        ''' Loading the atom and bond decoders for the drug (akt_train) vocabulary '''
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + "akt_train" + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol_drugs(self, node_labels, edge_labels, strict=True):
        """Reconstruct an RDKit Mol using the drug (akt_train) decoders.

        Returns None when ``strict`` is True and sanitization fails.
        """
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        atom_decoders = self.drug_decoder_load("atom")
        bond_decoders = self.drug_decoder_load("bond")

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
        mol = self.correct_mol(mol)
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Narrowed from a bare `except:`; a failed sanitization yields None.
                mol = None

        return mol

    def check_valency(self, mol):
        """
        Checks that no atoms in the mol have exceeded their possible
        valency
        :return: (True, None) if no valency issues, (False, [atom_idx, valence]) otherwise
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            # RDKit's message embeds "#<atom_idx> ... valence <v>"; parse both ints.
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, x):
        """Repeatedly drop the highest-order bond at any over-valent atom until the molecule passes a valency check."""
        mol = x
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            assert len(atomid_valence) == 2
            idx = atomid_valence[0]
            queue = []
            for b in mol.GetAtomWithIdx(idx).GetBonds():
                queue.append(
                    (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                )
            # Highest bond order first, so the most valence-expensive bond goes.
            queue.sort(key=lambda tup: tup[1], reverse=True)
            if len(queue) > 0:
                start = queue[0][2]
                end = queue[0][3]
                mol.RemoveBond(start, end)

        return mol

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors with ``dim`` classes."""
        out = torch.zeros(list(labels.size()) + [dim])
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out.float()

    def process(self, size=None):
        '''
        Process the dataset. This function will be only run if processed_file_names does not exist in the data folder already.
        Molecules that fail to parse, exceed the atom limit, or are disconnected are skipped.
        '''
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        self._generate_encoders_decoders(smiles_list)

        data_list = []
        max_length = min(self.max_atom_size_in_data, self.max_atom)
        self.m_dim = len(self.atom_decoder_m)

        # BUG FIX: the loop previously shadowed the SMILES list with its own
        # loop variable and then referenced an undefined name `smile`,
        # raising NameError on the first iteration.
        for smile in tqdm(smiles_list, desc='Processing chembl dataset', total=len(smiles_list)):
            mol = Chem.MolFromSmiles(smile)
            if mol is None:
                continue  # unparsable SMILES

            # filter by max atom size
            if mol.GetNumAtoms() > max_length:
                continue

            A = self.generate_adjacency_matrix(mol, connected=True, max_length=max_length)
            if A is None:
                continue  # disconnected molecule

            x = torch.from_numpy(self.generate_node_features(mol, max_length=max_length)).to(torch.long).view(1, -1)
            x = self.label2onehot(x, self.m_dim).squeeze()
            if self.features:
                f = torch.from_numpy(self.generate_additional_features(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                x = torch.concat((x, f), dim=-1)

            adjacency = torch.from_numpy(A)

            edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
            edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
|
346 |
+
|
347 |
+
if __name__ == '__main__':
    # Quick manual smoke test of dataset construction.
    # NOTE(review): this call is missing the required dataset_file, raw_files,
    # max_atom and features arguments and will raise TypeError as written —
    # supply them before using this entry point.
    data = DruggenDataset("DrugGEN/data")
|
349 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
rdkit-pypi
|
3 |
+
tqdm
|
4 |
+
numpy
|
5 |
+
seaborn
|
6 |
+
matplotlib
|
7 |
+
pandas
|
8 |
+
torch_geometric
|
trainer.py
ADDED
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch.nn
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from utils import *
|
7 |
+
from models import Generator, Generator2, simple_disc
|
8 |
+
import torch_geometric.utils as geoutils
|
9 |
+
#import #wandb
|
10 |
+
import re
|
11 |
+
from torch_geometric.loader import DataLoader
|
12 |
+
from new_dataloader import DruggenDataset
|
13 |
+
import torch.utils.data
|
14 |
+
from rdkit import RDLogger
|
15 |
+
import pickle
|
16 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
17 |
+
torch.set_num_threads(5)
|
18 |
+
RDLogger.DisableLog('rdApp.*')
|
19 |
+
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
|
20 |
+
from training_data import load_data
|
21 |
+
import random
|
22 |
+
|
23 |
+
|
24 |
+
class Trainer(object):
|
25 |
+
|
26 |
+
"""Trainer for training and testing DrugGEN."""
|
27 |
+
|
28 |
+
def __init__(self, config):
    """Read all hyper-parameters from ``config``, build both PyG datasets and
    loaders (ligand + drug), derive graph dimensions from the data, and
    finally construct the networks via ``build_model``.
    """
    # Use GPU when available; all models/tensors are moved to this device.
    self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    """Initialize configurations."""
    self.submodel = config.submodel                # DrugGEN variant: CrossLoss / Ligand / Prot / RL / NoTarget.
    self.inference_model = config.inference_model  # Checkpoint directory used by inference().

    # ----- Data loader configuration -----------------------------------
    self.raw_file = config.raw_file                # SMILES text file for the first dataset (full path).
    self.drug_raw_file = config.drug_raw_file      # SMILES text file for the second (drug) dataset.
    self.dataset_file = config.dataset_file        # Processed dataset file name for the first GAN.
    self.drugs_dataset_file = config.drug_dataset_file  # Processed drug dataset file name (e.g. AKT1 inhibitors).
    self.inf_raw_file = config.inf_raw_file        # Inference-time counterparts of the four entries above.
    self.inf_drug_raw_file = config.inf_drug_raw_file
    self.inf_dataset_file = config.inf_dataset_file
    self.inf_drugs_dataset_file = config.inf_drug_dataset_file
    self.mol_data_dir = config.mol_data_dir        # Directory of the processed ligand dataset files.
    self.drug_data_dir = config.drug_data_dir      # Directory of the processed drug dataset files.
    # Dataset base names are reused as decoder-pickle suffixes (see decoder_load).
    self.dataset_name = self.dataset_file.split(".")[0]
    self.drugs_name = self.drugs_dataset_file.split(".")[0]

    self.max_atom = config.max_atom  # One-shot generation: maximum atom count per molecule.
    self.features = config.features  # False -> atom types only; True -> additional node features.

    self.batch_size = config.batch_size

    # Custom PyG dataset built from SMILES; sparse graph representation.
    self.dataset = DruggenDataset(self.mol_data_dir,
                                  self.dataset_file,
                                  self.raw_file,
                                  self.max_atom,
                                  self.features)
    self.loader = DataLoader(self.dataset,
                             shuffle=True,
                             batch_size=self.batch_size,
                             drop_last=True)
    # Same construction for the drug-only dataset driving the second GAN.
    self.drugs = DruggenDataset(self.drug_data_dir,
                                self.drugs_dataset_file,
                                self.drug_raw_file,
                                self.max_atom,
                                self.features)
    self.drugs_loader = DataLoader(self.drugs,
                                   shuffle=True,
                                   batch_size=self.batch_size,
                                   drop_last=True)

    # ----- Graph dimensions derived from the data -----------------------
    self.atom_decoders = self.decoder_load("atom")  # e.g. 0:0, 1:6 (C), 2:7 (N), 3:8 (O), 4:9 (F)
    self.bond_decoders = self.decoder_load("bond")  # e.g. 0: no-bond, 1: single, 2: double, 3: triple, 4: aromatic
    self.m_dim = len(self.atom_decoders) if not self.features else int(self.loader.dataset[0].x.shape[1])  # Atom-type dimension.
    self.b_dim = len(self.bond_decoders)            # Bond-type dimension.
    self.vertexes = int(self.loader.dataset[0].x.shape[0])  # Number of nodes per graph.
    self.drugs_atom_decoders = self.drug_decoder_load("atom")
    self.drugs_bond_decoders = self.drug_decoder_load("bond")
    self.drugs_m_dim = len(self.drugs_atom_decoders) if not self.features else int(self.drugs_loader.dataset[0].x.shape[1])
    self.drugs_b_dim = len(self.drugs_bond_decoders)
    self.drug_vertexes = int(self.drugs_loader.dataset[0].x.shape[0])

    # ----- Transformer / model hyper-parameters -------------------------
    self.act = config.act
    self.z_dim = config.z_dim
    self.lambda_gp = config.lambda_gp  # Gradient-penalty weight (WGAN-GP style).
    self.dim = config.dim              # Encoder hidden dimension.
    self.depth = config.depth
    self.heads = config.heads
    self.mlp_ratio = config.mlp_ratio
    self.dec_depth = config.dec_depth  # Decoder (second GAN) transformer settings.
    self.dec_heads = config.dec_heads
    self.dec_dim = config.dec_dim
    self.dis_select = config.dis_select
    # Dead configuration kept for reference (GCN / PNA discriminators):
    """self.la = config.la
    self.la2 = config.la2
    self.gcn_depth = config.gcn_depth
    self.g_conv_dim = config.g_conv_dim
    self.d_conv_dim = config.d_conv_dim"""
    """# PNA config

    self.agg = config.aggregators
    self.sca = config.scalers
    self.pna_in_ch = config.pna_in_ch
    self.pna_out_ch = config.pna_out_ch
    self.edge_dim = config.edge_dim
    self.towers = config.towers
    self.pre_lay = config.pre_lay
    self.post_lay = config.post_lay
    self.pna_layer_num = config.pna_layer_num
    self.graph_add = config.graph_add"""

    # ----- Training configuration ----------------------------------------
    self.epoch = config.epoch
    self.g_lr = config.g_lr
    self.d_lr = config.d_lr
    self.g2_lr = config.g2_lr
    self.d2_lr = config.d2_lr
    self.dropout = config.dropout
    self.dec_dropout = config.dec_dropout
    self.n_critic = config.n_critic
    self.beta1 = config.beta1
    self.beta2 = config.beta2
    self.resume_iters = config.resume_iters
    self.warm_up_steps = config.warm_up_steps

    # ----- Test configuration --------------------------------------------
    self.num_test_epoch = config.num_test_epoch
    self.test_iters = config.test_iters
    self.inference_sample_num = config.inference_sample_num

    # ----- Directories and logging ----------------------------------------
    self.log_dir = config.log_dir
    self.sample_dir = config.sample_dir
    self.model_save_dir = config.model_save_dir
    self.result_dir = config.result_dir
    self.log_step = config.log_sample_step
    self.clipping_value = config.clipping_value
    self.mode = config.mode

    # NOTE(review): these Parameters live on the Trainer (not an nn.Module),
    # so no optimizer or module registration ever sees them — presumably
    # learnable noise scales that are currently frozen at zero; confirm intent.
    self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
    self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
    self.noise_strength_2 = torch.nn.Parameter(torch.zeros([]))
    self.noise_strength_3 = torch.nn.Parameter(torch.zeros([]))

    self.init_type = config.init_type
    self.build_model()
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
def build_model(self):
    """Create generators, discriminators, reward networks and optimizers.

    GAN1 (``G``/``D`` plus reward net ``V``) operates on the ligand dataset;
    GAN2 (``G2``/``D2`` plus ``V2``) operates on the drug dataset. All
    attribute names, constructor arguments and their ordering are unchanged
    from the original; only dead commented-out code was removed.

    Generator (transformer encoder) arguments:
        z_dim: latent input dimension; vertexes: max molecule length;
        b_dim / m_dim: bond / atom type dimensions; dim, depth, heads,
        mlp_ratio: transformer hyper-parameters; submodel: DrugGEN variant.
    """
    self.G = Generator(self.z_dim,
                       self.act,
                       self.vertexes,
                       self.b_dim,
                       self.m_dim,
                       self.dropout,
                       dim=self.dim,
                       depth=self.depth,
                       heads=self.heads,
                       mlp_ratio=self.mlp_ratio,
                       submodel=self.submodel)

    # Generator 2: transformer decoder that refines GAN1 output toward drugs.
    self.G2 = Generator2(self.dim,
                         self.dec_dim,
                         self.dec_depth,
                         self.dec_heads,
                         self.mlp_ratio,
                         self.dec_dropout,
                         self.drugs_m_dim,
                         self.drugs_b_dim,
                         self.submodel)

    # Simple MLP discriminators; V/V2 share the architecture and act as
    # reward networks for the RL submodel.
    self.D2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
    self.D = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
    self.V = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
    self.V2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)

    # AdamW optimizers; GAN1 and GAN2 have independent learning rates but
    # share the (beta1, beta2) momentum settings.
    self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.g2_optimizer = torch.optim.AdamW(self.G2.parameters(), self.g2_lr, [self.beta1, self.beta2])

    self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
    self.d2_optimizer = torch.optim.AdamW(self.D2.parameters(), self.d2_lr, [self.beta1, self.beta2])

    # Reward networks reuse the discriminator learning rates.
    self.v_optimizer = torch.optim.AdamW(self.V.parameters(), self.d_lr, [self.beta1, self.beta2])
    self.v2_optimizer = torch.optim.AdamW(self.V2.parameters(), self.d2_lr, [self.beta1, self.beta2])

    self.print_network(self.G, 'G')
    self.print_network(self.D, 'D')

    self.print_network(self.G2, 'G2')
    self.print_network(self.D2, 'D2')

    self.G.to(self.device)
    self.D.to(self.device)

    self.V.to(self.device)
    self.V2.to(self.device)
    self.G2.to(self.device)
    self.D2.to(self.device)
|
395 |
+
|
396 |
+
|
397 |
+
def decoder_load(self, dictionary_name):
|
398 |
+
|
399 |
+
''' Loading the atom and bond decoders'''
|
400 |
+
|
401 |
+
with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
402 |
+
|
403 |
+
return pickle.load(f)
|
404 |
+
|
405 |
+
def drug_decoder_load(self, dictionary_name):
|
406 |
+
|
407 |
+
''' Loading the atom and bond decoders'''
|
408 |
+
|
409 |
+
with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
410 |
+
|
411 |
+
return pickle.load(f)
|
412 |
+
|
413 |
+
def print_network(self, model, name):
    """Print the model architecture, its label, and total parameter count."""
    total_parameters = sum(parameter.numel() for parameter in model.parameters())
    print(model)
    print(name)
    print("The number of parameters: {}".format(total_parameters))
|
423 |
+
|
424 |
+
|
425 |
+
def restore_model(self, epoch, iteration, model_directory):
|
426 |
+
|
427 |
+
"""Restore the trained generator and discriminator."""
|
428 |
+
|
429 |
+
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
430 |
+
|
431 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
432 |
+
#D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
|
433 |
+
|
434 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
435 |
+
#self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
436 |
+
|
437 |
+
|
438 |
+
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
|
439 |
+
#D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
|
440 |
+
|
441 |
+
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
442 |
+
#self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
|
443 |
+
|
444 |
+
|
445 |
+
def save_model(self, model_directory, idx,i):
|
446 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
|
447 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
|
448 |
+
torch.save(self.G.state_dict(), G_path)
|
449 |
+
torch.save(self.D.state_dict(), D_path)
|
450 |
+
|
451 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
452 |
+
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(idx+1,i+1))
|
453 |
+
D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(idx+1,i+1))
|
454 |
+
|
455 |
+
torch.save(self.G2.state_dict(), G2_path)
|
456 |
+
torch.save(self.D2.state_dict(), D2_path)
|
457 |
+
|
458 |
+
def reset_grad(self):
|
459 |
+
|
460 |
+
"""Reset the gradient buffers."""
|
461 |
+
|
462 |
+
self.g_optimizer.zero_grad()
|
463 |
+
self.v_optimizer.zero_grad()
|
464 |
+
self.g2_optimizer.zero_grad()
|
465 |
+
self.v2_optimizer.zero_grad()
|
466 |
+
|
467 |
+
self.d_optimizer.zero_grad()
|
468 |
+
self.d2_optimizer.zero_grad()
|
469 |
+
|
470 |
+
def gradient_penalty(self, y, x):
|
471 |
+
|
472 |
+
"""Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
|
473 |
+
|
474 |
+
weight = torch.ones(y.size(),requires_grad=False).to(self.device)
|
475 |
+
dydx = torch.autograd.grad(outputs=y,
|
476 |
+
inputs=x,
|
477 |
+
grad_outputs=weight,
|
478 |
+
retain_graph=True,
|
479 |
+
create_graph=True,
|
480 |
+
only_inputs=True)[0]
|
481 |
+
|
482 |
+
dydx = dydx.view(dydx.size(0), -1)
|
483 |
+
gradient_penalty = ((dydx.norm(2, dim=1) - 1) ** 2).mean()
|
484 |
+
|
485 |
+
return gradient_penalty
|
486 |
+
|
487 |
+
def train(self):
    """Main training loop for GAN1 and, for most submodels, GAN2.

    Per iteration: build dense batches from both loaders, route generator /
    discriminator inputs according to ``self.submodel``, take one
    discriminator step and one generator step, then periodically log,
    sample molecules, and checkpoint.
    """
    #wandb.config = {'beta2': 0.999}
    #wandb.init(project="DrugGEN2", entity="atabeyunlu")

    # Build a unique run tag from the hyper-parameters; it names the
    # checkpoint directory, the sample directory and the log file.
    self.arguments = "{}_glr{}_dlr{}_g2lr{}_d2lr{}_dim{}_depth{}_heads{}_decdepth{}_decheads{}_ncritic{}_batch{}_epoch{}_warmup{}_dataset{}_dropout{}".format(self.submodel,self.g_lr,self.d_lr,self.g2_lr,self.d2_lr,self.dim,self.depth,self.heads,self.dec_depth,self.dec_heads,self.n_critic,self.batch_size,self.epoch,self.warm_up_steps,self.dataset_name,self.dropout)

    self.model_directory= os.path.join(self.model_save_dir,self.arguments)
    self.sample_directory=os.path.join(self.sample_dir,self.arguments)
    self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
    if not os.path.exists(self.model_directory):
        os.makedirs(self.model_directory)
    if not os.path.exists(self.sample_directory):
        os.makedirs(self.sample_directory)

    # Reference SMILES for logging metrics.
    # NOTE(review): paths are hard-coded (they ignore self.raw_file /
    # self.drug_raw_file) and the file handles are never closed — confirm
    # whether these should use the configured paths and a `with` block.
    full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
    drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]

    # Murcko-scaffold RDKit fingerprints of the drugs; passed to the loss
    # functions as similarity-reward references.
    drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
    drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
    fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]

    # Protein (AKT1) graph tensors, used as GAN2 inputs in the "Prot" submodel.
    akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
    akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()

    print('Start training...')
    self.start_time = time.time()
    for idx in range(self.epoch):

        # The drug loader is smaller than the ligand loader; iterate it
        # manually and restart it whenever it runs out.
        dataloader_iterator = iter(self.drugs_loader)

        for i, data in enumerate(self.loader):
            try:
                drugs = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(self.drugs_loader)
                drugs = next(dataloader_iterator)

            # Densify both sparse batches and draw the latent noise tensors.
            bulk_data = load_data(data,
                                  drugs,
                                  self.batch_size,
                                  self.device,
                                  self.b_dim,
                                  self.m_dim,
                                  self.drugs_b_dim,
                                  self.drugs_m_dim,
                                  self.z_dim,
                                  self.vertexes)

            drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data

            # Route (edge, node) inputs to generator/discriminator per submodel.
            if self.submodel == "CrossLoss":
                GAN1_input_e = drugs_a_tensor
                GAN1_input_x = drugs_x_tensor
                GAN1_disc_e = a_tensor
                GAN1_disc_x = x_tensor
            elif self.submodel == "Ligand":
                GAN1_input_e = a_tensor
                GAN1_input_x = x_tensor
                GAN1_disc_e = a_tensor
                GAN1_disc_x = x_tensor
                GAN2_input_e = drugs_a_tensor
                GAN2_input_x = drugs_x_tensor
                GAN2_disc_e = drugs_a_tensor
                GAN2_disc_x = drugs_x_tensor
            elif self.submodel == "Prot":
                GAN1_input_e = a_tensor
                GAN1_input_x = x_tensor
                GAN1_disc_e = a_tensor
                GAN1_disc_x = x_tensor
                GAN2_input_e = akt1_human_adj
                GAN2_input_x = akt1_human_annot
                GAN2_disc_e = drugs_a_tensor
                GAN2_disc_x = drugs_x_tensor
            elif self.submodel == "RL":
                GAN1_input_e = z_edge
                GAN1_input_x = z_node
                GAN1_disc_e = a_tensor
                GAN1_disc_x = x_tensor
                GAN2_input_e = drugs_a_tensor
                GAN2_input_x = drugs_x_tensor
                GAN2_disc_e = drugs_a_tensor
                GAN2_disc_x = drugs_x_tensor
            elif self.submodel == "NoTarget":
                GAN1_input_e = z_edge
                GAN1_input_x = z_node
                GAN1_disc_e = a_tensor
                GAN1_disc_x = x_tensor

            # ------------------- discriminator step -------------------
            loss = {}
            self.reset_grad()

            node, edge, d_loss = discriminator_loss(self.G,
                                                    self.D,
                                                    real_graphs,
                                                    GAN1_disc_e,
                                                    GAN1_disc_x,
                                                    self.batch_size,
                                                    self.device,
                                                    self.gradient_penalty,
                                                    self.lambda_gp,
                                                    GAN1_input_e,
                                                    GAN1_input_x)

            d_total = d_loss
            # Submodels without a second GAN skip D2 entirely.
            if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                d2_loss = discriminator2_loss(self.G2,
                                              self.D2,
                                              drug_graphs,
                                              edge,
                                              node,
                                              self.batch_size,
                                              self.device,
                                              self.gradient_penalty,
                                              self.lambda_gp,
                                              GAN2_input_e,
                                              GAN2_input_x)
                d_total = d_loss + d2_loss

            loss["d_total"] = d_total.item()
            d_total.backward()
            self.d_optimizer.step()
            if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                self.d2_optimizer.step()

            # --------------------- generator step ---------------------
            self.reset_grad()
            generator_output = generator_loss(self.G,
                                              self.D,
                                              self.V,
                                              GAN1_input_e,
                                              GAN1_input_x,
                                              self.batch_size,
                                              sim_reward,
                                              self.dataset.matrices2mol_drugs,
                                              fps_r,
                                              self.submodel)

            g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output

            self.reset_grad()
            g_total = g_loss
            if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                # GAN2 refines GAN1's (edge, node) output toward drug space.
                output = generator2_loss(self.G2,
                                         self.D2,
                                         self.V2,
                                         edge,
                                         node,
                                         self.batch_size,
                                         sim_reward,
                                         self.dataset.matrices2mol_drugs,
                                         fps_r,
                                         GAN2_input_e,
                                         GAN2_input_x,
                                         self.submodel)

                g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output

                g_total = g_loss + g2_loss

            loss["g_total"] = g_total.item()
            g_total.backward()
            self.g_optimizer.step()
            if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                self.g2_optimizer.step()

            # Reward networks are only trained in the RL submodel.
            if self.submodel == "RL":
                self.v_optimizer.step()
                self.v2_optimizer.step()

            # Periodic logging and molecule sampling.
            if (i+1) % self.log_step == 0:

                logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1,self.sample_directory)
                mol_sample(self.sample_directory,"GAN1",fake_mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), idx, i)
                if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
                    logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2,self.sample_directory)
                    mol_sample(self.sample_directory,"GAN2",fake_mol_g, dr_g_edges_hat_sample.detach(), dr_g_nodes_hat_sample.detach(), idx, i)

        # Checkpoint every 10 epochs (saved at the last iteration index i).
        if (idx+1) % 10 == 0:
            self.save_model(self.model_directory,idx,i)
            print("model saved at epoch {} and iteration {}".format(idx,i))
|
691 |
+
|
692 |
+
|
693 |
+
|
694 |
+
def inference(self):
|
695 |
+
|
696 |
+
# Load the trained generator.
|
697 |
+
self.G.to(self.device)
|
698 |
+
#self.D.to(self.device)
|
699 |
+
self.G2.to(self.device)
|
700 |
+
#self.D2.to(self.device)
|
701 |
+
|
702 |
+
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
|
703 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
704 |
+
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
|
705 |
+
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
706 |
+
|
707 |
+
|
708 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
|
709 |
+
|
710 |
+
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
711 |
+
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
712 |
+
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
713 |
+
|
714 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
715 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
716 |
+
|
717 |
+
self.G.eval()
|
718 |
+
#self.D.eval()
|
719 |
+
self.G2.eval()
|
720 |
+
#self.D2.eval()
|
721 |
+
|
722 |
+
self.inf_batch_size =256
|
723 |
+
self.inf_dataset = DruggenDataset(self.mol_data_dir,
|
724 |
+
self.inf_dataset_file,
|
725 |
+
self.inf_raw_file,
|
726 |
+
self.max_atom,
|
727 |
+
self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
|
728 |
+
# Can create any molecular graph dataset given smiles string.
|
729 |
+
# Nonisomeric SMILES are suggested but not necessary.
|
730 |
+
# Uses sparse matrix representation for graphs,
|
731 |
+
# For computational and speed efficiency.
|
732 |
+
|
733 |
+
self.inf_loader = DataLoader(self.inf_dataset,
|
734 |
+
shuffle=True,
|
735 |
+
batch_size=self.inf_batch_size,
|
736 |
+
drop_last=True) # PyG dataloader for the first GAN.
|
737 |
+
|
738 |
+
self.inf_drugs = DruggenDataset(self.drug_data_dir,
|
739 |
+
self.inf_drugs_dataset_file,
|
740 |
+
self.inf_drug_raw_file,
|
741 |
+
self.max_atom,
|
742 |
+
self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
|
743 |
+
# Can create any molecular graph dataset given smiles string.
|
744 |
+
# Nonisomeric SMILES are suggested but not necessary.
|
745 |
+
# Uses sparse matrix representation for graphs,
|
746 |
+
# For computational and speed efficiency.
|
747 |
+
|
748 |
+
self.inf_drugs_loader = DataLoader(self.inf_drugs,
|
749 |
+
shuffle=True,
|
750 |
+
batch_size=self.inf_batch_size,
|
751 |
+
drop_last=True) # PyG dataloader for the second GAN.
|
752 |
+
start_time = time.time()
|
753 |
+
#metric_calc_mol = []
|
754 |
+
metric_calc_dr = []
|
755 |
+
date = time.time()
|
756 |
+
if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
|
757 |
+
os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
|
758 |
+
with torch.inference_mode():
|
759 |
+
|
760 |
+
dataloader_iterator = iter(self.drugs_loader)
|
761 |
+
|
762 |
+
for i, data in enumerate(self.loader):
|
763 |
+
try:
|
764 |
+
drugs = next(dataloader_iterator)
|
765 |
+
except StopIteration:
|
766 |
+
dataloader_iterator = iter(self.drugs_loader)
|
767 |
+
drugs = next(dataloader_iterator)
|
768 |
+
|
769 |
+
# Preprocess both dataset
|
770 |
+
|
771 |
+
bulk_data = load_data(data,
|
772 |
+
drugs,
|
773 |
+
self.batch_size,
|
774 |
+
self.device,
|
775 |
+
self.b_dim,
|
776 |
+
self.m_dim,
|
777 |
+
self.drugs_b_dim,
|
778 |
+
self.drugs_m_dim,
|
779 |
+
self.z_dim,
|
780 |
+
self.vertexes)
|
781 |
+
|
782 |
+
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
783 |
+
|
784 |
+
if self.submodel == "CrossLoss":
|
785 |
+
GAN1_input_e = a_tensor
|
786 |
+
GAN1_input_x = x_tensor
|
787 |
+
GAN1_disc_e = drugs_a_tensor
|
788 |
+
GAN1_disc_x = drugs_x_tensor
|
789 |
+
GAN2_input_e = drugs_a_tensor
|
790 |
+
GAN2_input_x = drugs_x_tensor
|
791 |
+
GAN2_disc_e = a_tensor
|
792 |
+
GAN2_disc_x = x_tensor
|
793 |
+
elif self.submodel == "Ligand":
|
794 |
+
GAN1_input_e = a_tensor
|
795 |
+
GAN1_input_x = x_tensor
|
796 |
+
GAN1_disc_e = a_tensor
|
797 |
+
GAN1_disc_x = x_tensor
|
798 |
+
GAN2_input_e = drugs_a_tensor
|
799 |
+
GAN2_input_x = drugs_x_tensor
|
800 |
+
GAN2_disc_e = drugs_a_tensor
|
801 |
+
GAN2_disc_x = drugs_x_tensor
|
802 |
+
elif self.submodel == "Prot":
|
803 |
+
GAN1_input_e = a_tensor
|
804 |
+
GAN1_input_x = x_tensor
|
805 |
+
GAN1_disc_e = a_tensor
|
806 |
+
GAN1_disc_x = x_tensor
|
807 |
+
GAN2_input_e = akt1_human_adj
|
808 |
+
GAN2_input_x = akt1_human_annot
|
809 |
+
GAN2_disc_e = drugs_a_tensor
|
810 |
+
GAN2_disc_x = drugs_x_tensor
|
811 |
+
elif self.submodel == "RL":
|
812 |
+
GAN1_input_e = z_edge
|
813 |
+
GAN1_input_x = z_node
|
814 |
+
GAN1_disc_e = a_tensor
|
815 |
+
GAN1_disc_x = x_tensor
|
816 |
+
GAN2_input_e = drugs_a_tensor
|
817 |
+
GAN2_input_x = drugs_x_tensor
|
818 |
+
GAN2_disc_e = drugs_a_tensor
|
819 |
+
GAN2_disc_x = drugs_x_tensor
|
820 |
+
elif self.submodel == "NoTarget":
|
821 |
+
GAN1_input_e = z_edge
|
822 |
+
GAN1_input_x = z_node
|
823 |
+
GAN1_disc_e = a_tensor
|
824 |
+
GAN1_disc_x = x_tensor
|
825 |
+
# =================================================================================== #
|
826 |
+
# 2. GAN1 Inference #
|
827 |
+
# =================================================================================== #
|
828 |
+
generator_output = generator_loss(self.G,
|
829 |
+
self.D,
|
830 |
+
self.V,
|
831 |
+
GAN1_input_e,
|
832 |
+
GAN1_input_x,
|
833 |
+
self.batch_size,
|
834 |
+
sim_reward,
|
835 |
+
self.dataset.matrices2mol_drugs,
|
836 |
+
fps_r,
|
837 |
+
self.submodel)
|
838 |
+
|
839 |
+
_, fake_mol, _, _, node, edge = generator_output
|
840 |
+
|
841 |
+
# =================================================================================== #
|
842 |
+
# 3. GAN2 Inference #
|
843 |
+
# =================================================================================== #
|
844 |
+
|
845 |
+
output = generator2_loss(self.G2,
|
846 |
+
self.D2,
|
847 |
+
self.V2,
|
848 |
+
edge,
|
849 |
+
node,
|
850 |
+
self.batch_size,
|
851 |
+
sim_reward,
|
852 |
+
self.dataset.matrices2mol_drugs,
|
853 |
+
fps_r,
|
854 |
+
GAN2_input_e,
|
855 |
+
GAN2_input_x,
|
856 |
+
self.submodel)
|
857 |
+
|
858 |
+
_, fake_mol_g, _, _ = output
|
859 |
+
|
860 |
+
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
861 |
+
|
862 |
+
|
863 |
+
|
864 |
+
#inference_smiles = [Chem.MolToSmiles(line) for line in fake_mol]
|
865 |
+
|
866 |
+
|
867 |
+
|
868 |
+
print("molecule batch {} inferred".format(i))
|
869 |
+
|
870 |
+
with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
871 |
+
for molecules in inference_drugs:
|
872 |
+
|
873 |
+
f.write(molecules)
|
874 |
+
f.write("\n")
|
875 |
+
metric_calc_dr.append(molecules)
|
876 |
+
|
877 |
+
|
878 |
+
|
879 |
+
if i == 120:
|
880 |
+
break
|
881 |
+
|
882 |
+
et = time.time() - start_time
|
883 |
+
|
884 |
+
print("Inference mode is lasted for {:.2f} seconds".format(et))
|
885 |
+
|
886 |
+
print("Metrics calculation started using MOSES.")
|
887 |
+
|
888 |
+
print("Validity: ", fraction_valid(inference_drugs), "\n")
|
889 |
+
print("Uniqueness: ", fraction_unique(inference_drugs), "\n")
|
890 |
+
print("Validity: ", novelty(inference_drugs, drug_smiles), "\n")
|
891 |
+
|
892 |
+
print("Metrics are calculated.")
|
training_data.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch_geometric.utils as geoutils
|
3 |
+
from utils import *
|
4 |
+
|
5 |
+
def load_data(data, drugs, batch_size, device, b_dim, m_dim, drugs_b_dim, drugs_m_dim, z_dim, vertexes):
    """Prepare one step's worth of dense tensors and noise for both GANs.

    data / drugs are PyG batch objects for the small-molecule and drug
    datasets respectively. Returns flattened graph vectors, dense one-hot
    adjacency and node tensors for both datasets, plus three Gaussian noise
    tensors (flat, per-edge, per-node), all on *device* with grad enabled
    on the noise.
    """
    data = data.to(device)
    drugs = drugs.to(device)

    # Gaussian noise for the generators: a flat latent vector plus
    # edge-logit and node-logit shaped noise.
    z = torch.from_numpy(sample_z(batch_size, z_dim)).to(device).float().requires_grad_(True)
    z_edge = torch.from_numpy(sample_z_edge(batch_size, vertexes, b_dim)).to(device).float().requires_grad_(True)
    z_node = torch.from_numpy(sample_z_node(batch_size, vertexes, m_dim)).to(device).float().requires_grad_(True)

    # Densify the sparse molecular graphs; every graph in the batch is
    # padded to the same node count.
    nodes_per_graph = int(data.batch.shape[0] / batch_size)
    dense_adj = geoutils.to_dense_adj(edge_index=data.edge_index,
                                      batch=data.batch,
                                      edge_attr=data.edge_attr,
                                      max_num_nodes=nodes_per_graph)
    a_tensor = label2onehot(dense_adj, b_dim, device)
    x_tensor = data.x.view(batch_size, nodes_per_graph, -1)

    # Same densification for the drug batch; bond labels are cast to long
    # before one-hot encoding.
    drug_nodes_per_graph = int(drugs.batch.shape[0] / batch_size)
    drugs_dense_adj = geoutils.to_dense_adj(edge_index=drugs.edge_index,
                                            batch=drugs.batch,
                                            edge_attr=drugs.edge_attr,
                                            max_num_nodes=drug_nodes_per_graph).to(device).long()
    drugs_a_tensor = label2onehot(drugs_dense_adj, drugs_b_dim, device).float()
    drugs_x_tensor = drugs.x.view(batch_size, drug_nodes_per_graph, -1).to(device)

    # Flattened (node ++ adjacency) vectors per graph, for vector-input models.
    real_graphs = torch.concat((x_tensor.reshape(batch_size, -1),
                                a_tensor.reshape(batch_size, -1)), dim=-1)
    drug_graphs = torch.concat((drugs_x_tensor.reshape(batch_size, -1),
                                drugs_a_tensor.reshape(batch_size, -1)), dim=-1)

    return drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node
|
utils.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from statistics import mean
|
2 |
+
from rdkit import DataStructs
|
3 |
+
from rdkit import Chem
|
4 |
+
from rdkit.Chem import AllChem
|
5 |
+
from rdkit.Chem import Draw
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
import seaborn as sns
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from matplotlib.lines import Line2D
|
11 |
+
from rdkit import RDLogger
|
12 |
+
import torch
|
13 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
14 |
+
import math
|
15 |
+
import time
|
16 |
+
import datetime
|
17 |
+
import re
|
18 |
+
RDLogger.DisableLog('rdApp.*')
|
19 |
+
import warnings
|
20 |
+
from multiprocessing import Pool
|
21 |
+
class Metrics(object):
    """Static helpers for evaluating generated molecules."""

    @staticmethod
    def valid(x):
        """True if *x* is a molecule whose SMILES is non-empty."""
        return x is not None and Chem.MolToSmiles(x) != ''

    @staticmethod
    def tanimoto_sim_1v2(data1, data2):
        """Mean pairwise Tanimoto similarity between two fingerprint arrays.

        Compares data1[i] with data2[i] up to the length of the shorter input.
        """
        # BUGFIX: the original picked the *larger* size (and on the else
        # branch assigned the data2 object itself, not its size), and then
        # averaged only the last similarity (`mean(sim)`) instead of the
        # collected list.
        min_len = data1.size if data1.size < data2.size else data2.size
        sims = []
        for i in range(min_len):
            sims.append(DataStructs.FingerprintSimilarity(data1[i], data2[i]))
        return mean(sims)

    @staticmethod
    def mol_length(x):
        """Number of letters in the largest fragment of x's SMILES; 0 for None."""
        if x is not None:
            # BUGFIX: select the *longest* '.'-separated fragment (key=len),
            # consistent with the `logging` helper; the original `max`
            # compared fragments lexicographically.
            return len([char for char in max(Chem.MolToSmiles(x).split("."), key=len).upper() if char.isalpha()])
        else:
            return 0

    @staticmethod
    def max_component(data, max_len):
        """Mean of mol_length over *data*, normalized by *max_len*."""
        return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32) / max_len).mean()
|
48 |
+
|
49 |
+
def sim_reward(mol_gen, fps_r):
    """Scaffold-similarity reward for a batch of generated molecules.

    mol_gen: list of RDKit mols (may contain None for invalid generations).
    fps_r:   fingerprints of reference-drug Murcko scaffolds.
    Returns a scalar reward; 1 when no scaffold could be extracted or the
    aggregated similarity is NaN.
    """
    gen_scaf = []
    for x in mol_gen:
        if x is None:
            continue
        try:
            gen_scaf.append(MurckoScaffold.GetScaffoldForMol(x))
        except Exception:
            # BUGFIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit. Unscaffoldable molecules are skipped.
            pass

    if len(gen_scaf) == 0:
        rew = 1
    else:
        fps = np.array([Chem.RDKFingerprint(x) for x in gen_scaf])
        fps_r = np.array(fps_r)

        # BUGFIX: average_agg_tanimoto returns a 0-d numpy scalar; the
        # original indexed it with [0], which raises IndexError.
        rew = average_agg_tanimoto(fps_r, fps)
        if math.isnan(rew):
            rew = 1

    return rew  ## change this to penalty
|
76 |
+
|
77 |
+
##########################################
|
78 |
+
##########################################
|
79 |
+
##########################################
|
80 |
+
|
81 |
+
def mols2grid_image(mols, path):
    """Render each valid molecule in *mols* to `<path>/<index>.png` (1200x1200)."""
    # Replace None entries with empty editable mols so indices stay aligned.
    safe_mols = [Chem.RWMol() if m is None else m for m in mols]

    for pos, mol in enumerate(safe_mols):
        if not Metrics.valid(mol):
            continue
        AllChem.Compute2DCoords(mol)
        Draw.MolToFile(mol, os.path.join(path, "{}.png".format(pos + 1)), size=(1200, 1200))
|
91 |
+
|
92 |
+
def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
    """Dump each valid molecule's edge/node matrices plus its SMILES to `<path>/<index>.txt`."""
    # Replace None entries with empty editable mols so indices stay aligned.
    safe_mols = [Chem.RWMol() if m is None else m for m in mols]

    for pos, mol in enumerate(safe_mols):
        if not Metrics.valid(mol):
            continue
        #m0= all_scores_for_print(mols[i], data_source, norm=False)
        out_file = os.path.join(path, "{}.txt".format(pos + 1))
        with open(out_file, "a") as handle:
            np.savetxt(handle, edges_hard[pos].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
            handle.write("\n")
            np.savetxt(handle, nodes_hard[pos].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
            handle.write("\n")
            handle.write("\n")

        print(Chem.MolToSmiles(mol), file=open(out_file, "a"))
|
112 |
+
|
113 |
+
##########################################
|
114 |
+
##########################################
|
115 |
+
##########################################
|
116 |
+
|
117 |
+
def dense_to_sparse_with_attr(adj):
    """Convert a dense (optionally batched) square adjacency tensor to sparse form.

    adj: (N, N) or (B, N, N) tensor. Returns (index, edge_attr) where
    edge_attr holds the nonzero values; for a batched input the row/col
    indices are offset by graph * N so all graphs share one index space.
    """
    assert 2 <= adj.dim() <= 3
    assert adj.size(-1) == adj.size(-2)

    nz = adj.nonzero(as_tuple=True)
    values = adj[nz]

    if len(nz) == 3:
        # Flatten the batch dimension into the node indices.
        offset = nz[0] * adj.size(-1)
        nz = (offset + nz[1], offset + nz[2])
    return nz, values
|
130 |
+
|
131 |
+
|
132 |
+
def label2onehot(labels, dim, device):
    """Convert integer label indices to one-hot vectors with *dim* classes appended as the last axis."""
    shape = list(labels.size()) + [dim]
    onehot = torch.zeros(shape).to(device)
    # Scatter a 1 along the new trailing class axis at each label index.
    onehot.scatter_(labels.dim(), labels.unsqueeze(-1), 1.)
    return onehot.float()
|
140 |
+
|
141 |
+
|
142 |
+
def sample_z_node(batch_size, vertexes, nodes):
    """Standard-normal noise for node logits, shape (batch_size, vertexes, nodes)."""
    return np.random.normal(loc=0, scale=1, size=(batch_size, vertexes, nodes))
|
147 |
+
|
148 |
+
|
149 |
+
def sample_z_edge(batch_size, vertexes, edges):
    """Standard-normal noise for edge logits, shape (batch_size, vertexes, vertexes, edges)."""
    return np.random.normal(loc=0, scale=1, size=(batch_size, vertexes, vertexes, edges))
|
154 |
+
|
155 |
+
def sample_z(batch_size, z_dim):
    """Flat standard-normal latent noise, shape (batch_size, z_dim)."""
    return np.random.normal(loc=0, scale=1, size=(batch_size, z_dim))
|
160 |
+
|
161 |
+
|
162 |
+
def mol_sample(sample_directory, model_name, mol, edges, nodes, idx, i):
    """Save grid images plus matrix/SMILES dumps for one sampling step of one model."""
    out_dir = os.path.join(sample_directory,
                           "{}-{}_{}-epoch_iteration".format(model_name, idx + 1, i + 1))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    mols2grid_image(mol, out_dir)
    save_smiles_matrices(mol, edges.detach(), nodes.detach(), out_dir)

    # Nothing valid in this batch: remove the empty directory again.
    if len(os.listdir(out_dir)) == 0:
        os.rmdir(out_dir)

    print("Valid molecules are saved.")
    print("Valid matrices and smiles are saved")
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path):
    """Log one reporting step for GAN<model_num>.

    Converts generated mols to SMILES, appends the largest fragment of each to
    ``samples-GAN{model_num}.txt`` under *save_path*, computes validity /
    unique@k / novelty, merges them into the shared ``loss`` dict (mutated in
    place), and appends a formatted line to *log_path*.

    NOTE(review): this function shadows the stdlib ``logging`` module name.
    """
    # Mol -> SMILES; None placeholders are kept so list positions stay aligned.
    gen_smiles = []
    for line in mols:
        if line is not None:
            gen_smiles.append(Chem.MolToSmiles(line))
        elif line is None:
            gen_smiles.append(None)

    #gen_smiles_saves = [None if x is None else re.sub('\*', '', x) for x in gen_smiles]
    #gen_smiles_saves = [None if x is None else re.sub('\.', '', x) for x in gen_smiles_saves]
    # Keep only the longest '.'-separated fragment of each SMILES.
    gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]

    sample_save_dir = os.path.join(save_path, "samples-GAN{}.txt".format(model_num))
    with open(sample_save_dir, "a") as f:
        for idxs in range(len(gen_smiles_saves)):
            if gen_smiles_saves[idxs] is not None:

                f.write(gen_smiles_saves[idxs])
                f.write("\n")

    # k = number of distinct non-None SMILES in this batch; used for unique@k.
    k = len(set(gen_smiles_saves) - {None})


    et = time.time() - start_time
    et = str(datetime.timedelta(seconds=et))[:-7]  # drop microseconds for readability
    log = "Elapsed [{}], Epoch/Iteration [{}/{}] for GAN{}".format(et, idx, i+1, model_num)

    # Log update
    #m0 = get_all_metrics(gen = gen_smiles, train = train_smiles, batch_size=batch_size, k = valid_mol_num, device=self.device)
    valid = fraction_valid(gen_smiles_saves)
    unique = fraction_unique(gen_smiles_saves, k, check_validity=False)
    novel = novelty(gen_smiles_saves, train_smiles)

    #qed = [QED(mol) for mol in mols if mol is not None]
    #sa = [SA(mol) for mol in mols if mol is not None]
    #logp = [logP(mol) for mol in mols if mol is not None]

    #IntDiv = internal_diversity(gen_smiles)
    #m0= all_scores_val(fake_mol, mols, full_mols, full_smiles, vert, norm=True) # 'mols' is output of Fake Reward
    #m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
    #m0.update(m1)

    #maxlen = MolecularMetrics.max_component(mols, 45)

    #m0 = {k: np.array(v).mean() for k, v in m0.items()}
    #loss.update(m0)
    # Merge the batch metrics into the shared loss dict (mutated in place).
    loss.update({'Valid': valid})
    loss.update({'Unique@{}'.format(k): unique})
    loss.update({'Novel': novel})
    #loss.update({'QED': statistics.mean(qed)})
    #loss.update({'SA': statistics.mean(sa)})
    #loss.update({'LogP': statistics.mean(logp)})
    #loss.update({'IntDiv': IntDiv})

    #wandb.log({"maxlen": maxlen})

    # Append every metric to the human-readable log line, then persist it.
    for tag, value in loss.items():

        log += ", {}: {:.4f}".format(tag, value)
    with open(log_path, "a") as f:
        f.write(log)
        f.write("\n")
    print(log)
    print("\n")
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
    """Save a heatmap grid of the first sample's attention weights, one panel per head.

    attn_w: attention-weight tensor; only attn_w[0] (the first sample) is plotted.
    NOTE(review): the output directory is hard-coded to /home/atabey/attn/second
    and must already exist, otherwise plt.savefig raises.
    """
    # Fixed 4-column grid; assumes `heads` is a multiple of 4.
    cols = 4
    rows = int(heads/cols)

    fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
    axes = axes.flat
    attentions_pos = attn_w[0]
    attentions_pos = attentions_pos.cpu().detach().numpy()
    for i,att in enumerate(attentions_pos):

        #im = axes[i].imshow(att, cmap='gray')
        sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
        axes[i].set_title(f'head - {i} ')
        axes[i].set_ylabel('layers')
    pltsavedir = "/home/atabey/attn/second"
    plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
267 |
+
|
268 |
+
|
269 |
+
def plot_grad_flow(named_parameters, model, iter, epoch):

    # Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function in Trainer class after loss.backwards() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    ave_grads = []
    max_grads= []
    layers = []
    # Collect per-layer mean/max absolute gradients for weight (non-bias) params.
    for n, p in named_parameters:
        if(p.requires_grad) and ("bias" not in n):
            print(p.grad,n)  # debug print — left in place intentionally
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().cpu())
            max_grads.append(p.grad.abs().max().cpu())
    # Overlay max (cyan) and mean (blue) gradient magnitudes per layer.
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=1) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    # NOTE(review): save directory is hard-coded and must already exist.
    pltsavedir = "/home/atabey/gradients/tryout"
    plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
301 |
+
|
302 |
+
"""
|
303 |
+
def _genDegree():
|
304 |
+
|
305 |
+
''' Generates the Degree distribution tensor for PNA, should be used everytime a different
|
306 |
+
dataset is used.
|
307 |
+
Can be called without arguments and saves the tensor for later use. If tensor was created
|
308 |
+
before, it just loads the degree tensor.
|
309 |
+
'''
|
310 |
+
|
311 |
+
degree_path = os.path.join(self.degree_dir, self.dataset_name + '-degree.pt')
|
312 |
+
if not os.path.exists(degree_path):
|
313 |
+
|
314 |
+
|
315 |
+
max_degree = -1
|
316 |
+
for data in self.dataset:
|
317 |
+
d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
|
318 |
+
max_degree = max(max_degree, int(d.max()))
|
319 |
+
|
320 |
+
# Compute the in-degree histogram tensor
|
321 |
+
deg = torch.zeros(max_degree + 1, dtype=torch.long)
|
322 |
+
for data in self.dataset:
|
323 |
+
d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
|
324 |
+
deg += torch.bincount(d, minlength=deg.numel())
|
325 |
+
torch.save(deg, 'DrugGEN/data/' + self.dataset_name + '-degree.pt')
|
326 |
+
else:
|
327 |
+
deg = torch.load(degree_path, map_location=lambda storage, loc: storage)
|
328 |
+
|
329 |
+
return deg
|
330 |
+
"""
|
331 |
+
def get_mol(smiles_or_mol):
    """Load a SMILES string (or pass through an existing Mol) as a sanitized RDKit Mol.

    Returns None for empty strings, unparsable SMILES, or sanitization failures.
    """
    if not isinstance(smiles_or_mol, str):
        # Already a Mol (or None): hand it back unchanged.
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    mol = Chem.MolFromSmiles(smiles_or_mol)
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return None
    return mol
|
347 |
+
|
348 |
+
def mapper(n_jobs):
    """Return a map-like callable.

    n_jobs == 1           -> eager builtin map (returns a list)
    n_jobs int > 1        -> multiprocessing Pool map (pool terminated after one call)
    pool-like object      -> its own .map method
    """
    if n_jobs == 1:
        def _serial_map(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _serial_map

    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        def _pool_map(*args, **kwargs):
            try:
                result = pool.map(*args, **kwargs)
            finally:
                pool.terminate()
            return result

        return _pool_map

    return n_jobs.map
|
372 |
+
def remove_invalid(gen, canonize=True, n_jobs=1):
    """Drop entries of *gen* that RDKit cannot parse; canonicalize survivors when *canonize*."""
    run = mapper(n_jobs)
    if not canonize:
        # Keep the original strings for every parsable molecule.
        mols = run(get_mol, gen)
        return [smi for smi, mol in zip(gen, mols) if mol is not None]
    # Canonical form (None entries, i.e. invalid molecules, are filtered out).
    return [smi for smi in run(canonic_smiles, gen) if smi is not None]
|
381 |
+
def fraction_valid(gen, n_jobs=1):
    """Fraction of entries in *gen* that parse into valid molecules.

    Parameters:
        gen: list of SMILES
        n_jobs: number of threads for calculation
    """
    mols = mapper(n_jobs)(get_mol, gen)
    invalid = mols.count(None)
    return 1 - invalid / len(mols)
|
390 |
+
def canonic_smiles(smiles_or_mol):
    """Canonical SMILES for the input, or None if it cannot be parsed/sanitized."""
    mol = get_mol(smiles_or_mol)
    return None if mol is None else Chem.MolToSmiles(mol)
|
395 |
+
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """Fraction of unique canonical molecules in *gen* (unique@k when *k* is given).

    Parameters:
        gen: list of SMILES
        k: compute unique@k over the first k entries
        n_jobs: number of threads for calculation
        check_validity: raises ValueError if invalid molecules are present
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}.".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    unique_canon = set(mapper(n_jobs)(canonic_smiles, gen))
    if check_validity and None in unique_canon:
        raise ValueError("Invalid molecule passed to unique@k")
    if len(gen) == 0:
        return 0
    return len(unique_canon) / len(gen)
|
415 |
+
|
416 |
+
def novelty(gen, train, n_jobs=1):
    """Fraction of generated canonical SMILES that do not appear in *train*."""
    canon = set(mapper(n_jobs)(canonic_smiles, gen)) - {None}
    if not canon:
        return 0
    return len(canon - set(train)) / len(canon)
|
421 |
+
|
422 |
+
|
423 |
+
|
424 |
+
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    For each molecule in gen_vecs finds closest molecule in stock_vecs.
    Returns average tanimoto score for between these molecules

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        agg: max or mean
        p: power for averaging: (mean x^p)^(1/p)

    Returns a 0-dimensional numpy scalar (np.mean of the per-molecule
    aggregated similarities) — note callers must NOT index it.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    # Per-generated-molecule aggregated similarity; `total` counts the number
    # of stock comparisons folded in (only used for agg == 'mean').
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    # Double blockwise loop keeps the (stock x gen) similarity matrix small.
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):

            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            # For binary fingerprints: tp = |A ∩ B|, and the denominator is
            # |A| + |B| - |A ∩ B|, i.e. the Tanimoto (Jaccard) coefficient.
            tp = torch.mm(x_stock, y_gen)
            jac = (tp / (x_stock.sum(1, keepdim=True) +
                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            # 0/0 (two all-zero fingerprints) yields NaN; treat as similarity 1.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac**p
            if agg == 'max':
                # Running max over all stock blocks seen so far.
                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
            elif agg == 'mean':
                # Accumulate sums now; divide by counts after the loops.
                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
                total[i:i + y_gen.shape[1]] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        # Undo the power applied per-entry: (mean x^p)^(1/p).
        agg_tanimoto = (agg_tanimoto)**(1/p)
    return np.mean(agg_tanimoto)
|