Spaces:
Running
Running
Update loss.py
Browse files
loss.py
CHANGED
@@ -34,7 +34,7 @@ def discriminator_loss(generator, discriminator, mol_graph, adj, annot, batch_si
|
|
34 |
return node, edge,d_loss
|
35 |
|
36 |
|
37 |
-
def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel):
|
38 |
|
39 |
# Compute loss with fake molecules.
|
40 |
|
@@ -53,7 +53,7 @@ def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty,
|
|
53 |
g_edges_hat_sample = torch.max(edge_sample, -1)[1]
|
54 |
g_nodes_hat_sample = torch.max(node_sample , -1)[1]
|
55 |
|
56 |
-
fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
|
57 |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
|
58 |
g_loss = prediction_fake
|
59 |
# Compute penalty loss.
|
@@ -116,7 +116,7 @@ def discriminator2_loss(generator, discriminator, mol_graph, adj, annot, batch_s
|
|
116 |
|
117 |
return d2_loss
|
118 |
|
119 |
-
def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel):
|
120 |
|
121 |
# Generate molecules.
|
122 |
|
@@ -140,7 +140,7 @@ def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty
|
|
140 |
g2_loss_fake = - torch.mean(g_tra_logits_fake2)
|
141 |
|
142 |
# Reward
|
143 |
-
fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
|
144 |
for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]
|
145 |
g2_loss = g2_loss_fake
|
146 |
if submodel == "RL":
|
|
|
34 |
return node, edge,d_loss
|
35 |
|
36 |
|
37 |
+
def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel, dataset_name):
|
38 |
|
39 |
# Compute loss with fake molecules.
|
40 |
|
|
|
53 |
g_edges_hat_sample = torch.max(edge_sample, -1)[1]
|
54 |
g_nodes_hat_sample = torch.max(node_sample , -1)[1]
|
55 |
|
56 |
+
fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
|
57 |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
|
58 |
g_loss = prediction_fake
|
59 |
# Compute penalty loss.
|
|
|
116 |
|
117 |
return d2_loss
|
118 |
|
119 |
+
def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel, drugs_name):
|
120 |
|
121 |
# Generate molecules.
|
122 |
|
|
|
140 |
g2_loss_fake = - torch.mean(g_tra_logits_fake2)
|
141 |
|
142 |
# Reward
|
143 |
+
fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=drugs_name)
|
144 |
for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]
|
145 |
g2_loss = g2_loss_fake
|
146 |
if submodel == "RL":
|