File size: 2,833 Bytes
e61c431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch

# Gradient Penalty Calculation - Calculate Gradient of Critic Score
def gradient_penalty(critic, labels, real_images, fake_images, epsilon):
    """
    Calculate the gradient penalty term of the WGAN-GP loss function.

    The critic is evaluated on a weighted mix of the real and fake images, the
    gradient of its scores with respect to that mix is taken, and the penalty
    is the mean squared distance of the per-sample gradient L2 norms from 1.

    Parameters:
    critic (nn.Module): The critic model, called as critic(images, labels)
    labels (torch.tensor): The labels for the images, passed unchanged to the critic
    real_images (torch.tensor): The real images
    fake_images (torch.tensor): The fake images
    epsilon (torch.tensor): The interpolation parameter; it must broadcast
        against the images (it weights the real images, 1 - epsilon the fakes)

    Returns:
    gradient_penalty (torch.tensor): The gradient penalty for the critic model
        (a scalar that stays differentiable with respect to the critic)
    """

    # Create the interpolated images as a weighted combination of real and fake images
    interpolated_images = real_images * epsilon + fake_images * (1 - epsilon)

    # torch.autograd.grad can only differentiate with respect to a tensor that
    # is tracked by autograd. When none of the inputs requires grad (e.g. the
    # fake images were detached and epsilon is a plain tensor), the mix is an
    # untracked leaf, so opt it in here instead of failing further down.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    mixed_scores = critic(interpolated_images, labels)

    # create_graph=True keeps the gradient itself differentiable, so the
    # penalty can be backpropagated into the critic's parameters.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten each sample's gradient into a 1D row; reshape (unlike view) also
    # works when the gradient tensor is not contiguous in memory.
    gradient = gradient.reshape(gradient.size(0), -1)

    # Calculate the L2 norm of the gradients, one value per sample
    gradient_norm = gradient.norm(2, dim=1)

    # Calculate the penalty as the mean squared distance of the norms from 1
    return torch.mean((gradient_norm - 1) ** 2)

# Critic Loss Calculation
def calculate_critic_loss(critic_fake_prediction, critic_real_prediction, gradient_penalty, critic_lambda):
    """
    Calculates the critic loss: the negated gap between the mean real score and
    the mean fake score, plus the gradient penalty scaled by critic_lambda.

    Parameters:
    critic_fake_prediction (torch.tensor): The critic predictions for the fake images
    critic_real_prediction (torch.tensor): The critic predictions for the real images
    gradient_penalty (torch.tensor): The gradient penalty for the critic model
    critic_lambda (float): The coefficient for the gradient penalty

    Returns:
    critic_loss (torch.tensor): The critic loss
    """
    # Average critic score over each set of predictions
    real_mean = torch.mean(critic_real_prediction)
    fake_mean = torch.mean(critic_fake_prediction)

    # How much higher the critic scores real images than fake ones
    score_gap = real_mean - fake_mean

    # Weighted regularisation term
    weighted_penalty = critic_lambda * gradient_penalty

    return -score_gap + weighted_penalty

# Generator Loss Calculation
def calculate_generator_loss(critic_fake_prediction):
    """
    Calculates the generator loss: the average critic prediction on the fake
    images, multiplied by -1.

    Parameters:
    critic_fake_prediction (torch.tensor): The critic predictions for the fake images

    Returns:
    generator_loss (torch.tensor): The generator loss
    """
    # Average critic score over the fake predictions
    mean_fake_score = torch.mean(critic_fake_prediction)

    # Flip the sign so that a higher critic score gives a lower loss
    return mean_fake_score * -1.0