osbm committed
Commit 9f5b1d1 · 1 Parent(s): 24a14ee

Update loss.py

Files changed (1):
  1. loss.py +4 -4
loss.py CHANGED
@@ -34,7 +34,7 @@ def discriminator_loss(generator, discriminator, mol_graph, adj, annot, batch_si
     return node, edge,d_loss


-def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel):
+def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,submodel, dataset_name):

     # Compute loss with fake molecules.

@@ -53,7 +53,7 @@ def generator_loss(generator, discriminator, v, adj, annot, batch_size, penalty,
     g_edges_hat_sample = torch.max(edge_sample, -1)[1]
     g_nodes_hat_sample = torch.max(node_sample , -1)[1]

-    fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
+    fake_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
                 for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
     g_loss = prediction_fake
     # Compute penalty loss.
@@ -116,7 +116,7 @@ def discriminator2_loss(generator, discriminator, mol_graph, adj, annot, batch_s

     return d2_loss

-def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel):
+def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty, matrices2mol, fps_r,ak1_adj,akt1_annot, submodel, drugs_name):

     # Generate molecules.

@@ -140,7 +140,7 @@ def generator2_loss(generator, discriminator, v, adj, annot, batch_size, penalty
     g2_loss_fake = - torch.mean(g_tra_logits_fake2)

     # Reward
-    fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
+    fake_mol_g = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=drugs_name)
                   for e_, n_ in zip(dr_g_edges_hat_sample, dr_g_nodes_hat_sample)]
     g2_loss = g2_loss_fake
     if submodel == "RL":
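Both generator losses rebuild RDKit molecules from the sampled node/edge index matrices, and the updated calls now forward a dataset identifier to matrices2mol through the new file_name keyword. The converter itself is not part of this diff; the sketch below is a minimal stand-in under the assumption that file_name only selects dataset-specific atom/bond decoders, so the actual helper in this repository may look different.

from rdkit import Chem

def matrices2mol_sketch(node_labels, edge_labels, strict=True, file_name=None):
    # Stand-in decoders; in the real project these would presumably be loaded
    # from the processed dataset identified by file_name (assumption).
    atom_decoder = {1: 6, 2: 7, 3: 8, 4: 9}   # label -> atomic number; label 0 = padding
    bond_decoder = {1: Chem.BondType.SINGLE, 2: Chem.BondType.DOUBLE, 3: Chem.BondType.TRIPLE}

    mol = Chem.RWMol()
    idx = {}
    for i, label in enumerate(node_labels):
        if int(label) in atom_decoder:        # skip padding nodes
            idx[i] = mol.AddAtom(Chem.Atom(atom_decoder[int(label)]))

    for i in idx:
        for j in idx:
            if i < j and int(edge_labels[i, j]) in bond_decoder:
                mol.AddBond(idx[i], idx[j], bond_decoder[int(edge_labels[i, j])])

    if strict:
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            return None                       # invalid molecule under strict mode
    return mol

Keying the decoder lookup on file_name instead of hard-coding a single dataset is what would let the same loss code serve both the general training set (dataset_name) and the drug-specific one (drugs_name).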
 
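On the caller side the new identifiers are simply appended as the last argument of each signature. A hypothetical call site, assuming the trainer already holds the tensors and config values referenced below (none of them appear in this diff):

from loss import generator_loss, generator2_loss

def generator_step(ctx, cfg, matrices2mol):
    # ctx and cfg are assumed trainer-side containers; only the trailing
    # dataset/drug identifiers are new in this commit.
    g_out = generator_loss(
        ctx.generator, ctx.discriminator, ctx.v, ctx.adj, ctx.annot,
        cfg.batch_size, cfg.penalty, matrices2mol, ctx.fps_r,
        cfg.submodel, cfg.dataset_name,        # new: reaches matrices2mol as file_name
    )
    g2_out = generator2_loss(
        ctx.generator2, ctx.discriminator2, ctx.v, ctx.adj, ctx.annot,
        cfg.batch_size, cfg.penalty, matrices2mol, ctx.fps_r,
        ctx.akt1_adj, ctx.akt1_annot, cfg.submodel, cfg.drugs_name,  # new: reaches matrices2mol as file_name
    )
    return g_out, g2_out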