import torch
import torch.optim as optim
import lightning as L
from .base import Discriminator, Generator

class ConditionalWGAN_GP(L.LightningModule):
    """Conditional WGAN-GP implementation using PyTorch Lightning.

    Attributes:
        image_size: Size of the generated images.
        critic_repeats: Number of critic iterations per generator iteration.
        c_lambda: Gradient penalty lambda hyperparameter.
        generator: The generator model.
        critic: The discriminator (critic) model.
        critic_losses: List to store critic loss values.
        generator_losses: List to store generator loss values.
        curr_step: The current training step.
        fixed_latent_space: Fixed latent vectors for generating consistent images.
        fixed_label: Fixed labels corresponding to the latent vectors.
    """

    def __init__(self, image_size, learning_rate, z_dim, embed_size, num_classes,
                 critic_repeats, feature_gen, feature_critic, c_lambda, beta_1,
                 beta_2, display_step):
        """Initializes the Conditional WGAN-GP model.

        Args:
            image_size: Size of the generated images.
            learning_rate: Learning rate for the optimizers.
            z_dim: Dimension of the latent space.
            embed_size: Size of the embedding for the labels.
            num_classes: Number of classes for the conditional generation.
            critic_repeats: Number of critic iterations per generator iteration.
            feature_gen: Number of features for the generator.
            feature_critic: Number of features for the critic.
            c_lambda: Gradient penalty lambda hyperparameter.
            beta_1: Beta1 parameter for the Adam optimizer.
            beta_2: Beta2 parameter for the Adam optimizer.
            display_step: Step interval for displaying generated images.
        """
        super().__init__()

        self.automatic_optimization = False

        self.image_size = image_size
        self.critic_repeats = critic_repeats
        self.c_lambda = c_lambda

        self.generator = Generator(
            embed_size=embed_size,
            num_classes=num_classes,
            image_size=image_size,
            features_generator=feature_gen,
            input_dim=z_dim,
        )

        self.critic = Discriminator(
            num_classes=num_classes,
            image_size=image_size,
            features_discriminator=feature_critic,
        )

        self.critic_losses = []
        self.generator_losses = []
        self.curr_step = 0

        self.fixed_latent_space = torch.randn(25, z_dim, 1, 1)
        self.fixed_label = torch.tensor([i % num_classes for i in range(25)])

        self.save_hyperparameters()

    def configure_optimizers(self):
        """Configures the optimizers for the generator and critic.

        Returns:
            A tuple of two Adam optimizers, one for the generator and one for the critic.
        """
        optimizer_g = optim.Adam(
            self.generator.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )
        optimizer_c = optim.Adam(
            self.critic.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )

        return optimizer_g, optimizer_c

    def on_load_checkpoint(self, checkpoint):
        """Loads necessary variables from a checkpoint.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        # `current_epoch` may not be restored yet when this hook runs, so
        # guard on the checkpoint contents instead of the epoch counter.
        if 'critic_losses' in checkpoint:
            self.critic_losses = checkpoint['critic_losses']
            self.generator_losses = checkpoint['generator_losses']
            self.curr_step = checkpoint['curr_step']
            self.fixed_latent_space = checkpoint['fixed_latent_space']
            self.fixed_label = checkpoint['fixed_label']

    def on_save_checkpoint(self, checkpoint):
        """Saves necessary variables to a checkpoint.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        checkpoint['critic_losses'] = self.critic_losses
        checkpoint['generator_losses'] = self.generator_losses
        checkpoint['curr_step'] = self.curr_step
        checkpoint['fixed_latent_space'] = self.fixed_latent_space
        checkpoint['fixed_label'] = self.fixed_label

    def forward(self, noise, labels):
        """Generates an image given noise and labels.

        Args:
            noise: Latent noise vector.
            labels: Class labels for conditional generation.

        Returns:
            Generated image tensor.
        """
        return self.generator(noise, labels)

    def predict_step(self, noise, labels):
        """Predicts an image given noise and labels.

        Args:
            noise: Latent noise vector.
            labels: Class labels for conditional generation.

        Returns:
            Generated image tensor.
        """
        return self.generator(noise, labels)
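

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the class stores a
# gradient-penalty weight (`c_lambda`) and a critic-repeat count, but this
# file does not show how the penalty itself would be computed. The helper
# below sketches the standard WGAN-GP penalty (Gulrajani et al., 2017):
# interpolate between real and fake images, score the interpolates with the
# critic, and penalise the deviation of the per-sample gradient norm from 1.
# The `critic(images, labels)` call signature is an assumption based on the
# conditional Discriminator used above.
# ---------------------------------------------------------------------------
def gradient_penalty_sketch(critic, real, fake, labels):
    """Compute the WGAN-GP gradient penalty for one batch (illustrative only)."""
    batch_size = real.size(0)

    # One interpolation coefficient per sample, broadcast over C, H, W.
    epsilon = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolates = epsilon * real.detach() + (1 - epsilon) * fake.detach()
    interpolates.requires_grad_(True)

    # Critic scores on the interpolated images, conditioned on the labels.
    mixed_scores = critic(interpolates, labels)

    # Gradients of the scores with respect to the interpolated images.
    gradients = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalise the squared distance of each sample's gradient norm from 1;
    # this is the term that `c_lambda` would scale in the critic loss.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()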
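

# ---------------------------------------------------------------------------
# Minimal inference sketch, also not part of the original module. Given an
# already-constructed (and ideally trained) `ConditionalWGAN_GP`, this shows
# how `forward` is typically driven: noise of shape (batch, z_dim, 1, 1) plus
# one class id per sample. `z_dim` and `num_classes` are read back from the
# hyperparameters saved in `__init__`; the batch size default is illustrative.
# ---------------------------------------------------------------------------
def sample_images_sketch(model, num_samples=16):
    """Generate a batch of class-conditional images (illustrative only)."""
    noise = torch.randn(num_samples, model.hparams.z_dim, 1, 1, device=model.device)
    labels = torch.randint(0, model.hparams.num_classes, (num_samples,), device=model.device)
    with torch.no_grad():
        # ConditionalWGAN_GP.forward delegates to the generator.
        return model(noise, labels)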