File size: 2,992 Bytes
4c9e6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch


def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """
    Calculate gradient penalty for WGAN-GP.
    
    Args:
        discriminator: The discriminator model
        real_node: Real node features
        real_edge: Real edge features
        fake_node: Generated node features
        fake_edge: Generated edge features
        batch_size: Batch size
        device: Device to compute on
        
    Returns:
        Gradient penalty term
    """
    # Generate random interpolation factors
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
    eps_node = torch.rand(batch_size, 1, 1, device=device)
    
    # Create interpolated samples
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)

    logits_interpolated = discriminator(int_edge, int_node)

    # Calculate gradients for both node and edge inputs
    weight = torch.ones(logits_interpolated.size(), requires_grad=False).to(device)
    gradients = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=[int_node, int_edge],
        grad_outputs=weight,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )

    # Combine gradients from both inputs
    gradients_node = gradients[0].view(batch_size, -1)
    gradients_edge = gradients[1].view(batch_size, -1)
    gradients = torch.cat([gradients_node, gradients_edge], dim=1)

    # Calculate gradient penalty
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    
    return gradient_penalty


def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    # Compute loss for drugs
    logits_real_disc = discriminator(drug_adj, drug_annot)

    # Use mean reduction for more stable training
    prediction_real = -torch.mean(logits_real_disc)

    # Compute loss for generated molecules
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    logits_fake_disc = discriminator(edge_sample.detach(), node_sample.detach())

    prediction_fake = torch.mean(logits_fake_disc)

    # Compute gradient penalty using the new function
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, node_sample.detach(), edge_sample.detach(), batch_size, device)

    # Calculate total discriminator loss
    d_loss = prediction_fake + prediction_real + lambda_gp * gp

    return node, edge, d_loss


def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    # Generate fake molecules
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # Compute logits for fake molecules
    logits_fake_disc = discriminator(edge_sample, node_sample)

    prediction_fake = -torch.mean(logits_fake_disc)
    g_loss = prediction_fake

    return g_loss, node, edge, node_sample, edge_sample