# Page-header residue from the hosting site (not part of the module): "Spaces: Running"
import torch
def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """
    Calculate gradient penalty for WGAN-GP.

    Real and generated samples are mixed with a random per-sample factor, the
    discriminator is evaluated on the mixture, and the penalty is the mean
    squared deviation of the per-sample gradient L2 norm from 1. Node and edge
    gradients are flattened and concatenated before the norm is taken, so the
    norm covers both inputs jointly.

    Args:
        discriminator: The discriminator model, called as
            ``discriminator(edge, node)`` (edge input first).
        real_node: Real node features
        real_edge: Real edge features
        fake_node: Generated node features
        fake_edge: Generated edge features
        batch_size: Batch size. Sizes the leading dimension of the random
            interpolation factors, created as ``(batch_size, 1, 1)`` for
            nodes and ``(batch_size, 1, 1, 1)`` for edges, and of the
            flattened gradients.
        device: Device on which the interpolation factors are sampled

    Returns:
        Gradient penalty term (scalar tensor). It is built with
        ``create_graph=True`` so it can be back-propagated through.
    """
    # Generate random interpolation factors (edge first, then node, to keep
    # the random-number consumption order stable).
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
    eps_node = torch.rand(batch_size, 1, 1, device=device)

    # Create interpolated samples that autograd can differentiate against.
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)

    logits_interpolated = discriminator(int_edge, int_node)

    # ones_like matches the logits' shape, dtype and device directly instead
    # of allocating on the CPU and copying to `device`.
    weight = torch.ones_like(logits_interpolated)

    # Calculate gradients for both node and edge inputs.
    grad_node, grad_edge = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=[int_node, int_edge],
        grad_outputs=weight,
        create_graph=True,
        retain_graph=True,
    )

    # Combine gradients from both inputs. reshape (not view) is used because
    # gradients returned by autograd are not guaranteed to be contiguous.
    flat_gradients = torch.cat(
        [grad_node.reshape(batch_size, -1), grad_edge.reshape(batch_size, -1)],
        dim=1,
    )

    # Calculate gradient penalty.
    penalty = ((flat_gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty
def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    """
    Compute the discriminator (critic) loss with gradient penalty.

    The loss is ``mean(D(fake)) - mean(D(real)) + lambda_gp * gp``, where the
    generated samples are detached so this loss does not back-propagate into
    the generator.

    Args:
        generator: Called as ``generator(mol_adj, mol_annot)``; must return
            ``(node, edge, node_sample, edge_sample)``.
        discriminator: Called as ``discriminator(adj_or_edge, annot_or_node)``.
        drug_adj: Real edge/adjacency input for the discriminator
        drug_annot: Real node/annotation input for the discriminator
        mol_adj: Edge/adjacency input for the generator
        mol_annot: Node/annotation input for the generator
        batch_size: Batch size, forwarded to ``gradient_penalty``
        device: Device, forwarded to ``gradient_penalty``
        lambda_gp: Weight of the gradient penalty term

    Returns:
        Tuple ``(node, edge, d_loss)``: the first two generator outputs
        (not detached) and the scalar discriminator loss.
    """
    # Compute loss for drugs (real samples); mean reduction for stability.
    logits_real_disc = discriminator(drug_adj, drug_annot)
    prediction_real = -torch.mean(logits_real_disc)

    # Compute loss for generated molecules. Detach once and reuse the same
    # tensors for both the discriminator call and the gradient penalty.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
    fake_node = node_sample.detach()
    fake_edge = edge_sample.detach()
    logits_fake_disc = discriminator(fake_edge, fake_node)
    prediction_fake = torch.mean(logits_fake_disc)

    # NOTE: gradient_penalty takes (node, edge) pairs, whereas the
    # discriminator itself takes (edge, node) — hence the swapped order here.
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, fake_node, fake_edge, batch_size, device)

    # Calculate total discriminator loss.
    d_loss = prediction_fake + prediction_real + lambda_gp * gp
    return node, edge, d_loss
def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    """
    Compute the generator loss as the negated mean discriminator logit.

    Args:
        generator: Called as ``generator(mol_adj, mol_annot)``; must return
            ``(node, edge, node_sample, edge_sample)``.
        discriminator: Called as ``discriminator(edge_sample, node_sample)``.
        mol_adj: Edge/adjacency input for the generator
        mol_annot: Node/annotation input for the generator
        batch_size: Unused here; kept so existing callers keep working.

    Returns:
        Tuple ``(g_loss, node, edge, node_sample, edge_sample)``: the scalar
        loss followed by the four generator outputs, unchanged.
    """
    # Generate fake molecules.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # Score them; samples are NOT detached so gradients reach the generator.
    logits_fake_disc = discriminator(edge_sample, node_sample)

    # The generator wants high discriminator scores, so minimise the negation.
    g_loss = -torch.mean(logits_fake_disc)
    return g_loss, node, edge, node_sample, edge_sample