import torch import torch.optim as optim import lightning as L from .base import Discriminator, Generator class ConditionalWGAN_GP(L.LightningModule): """Conditional WGAN-GP implementation using PyTorch Lightning. Attributes: image_size: Size of the generated images. critic_repeats: Number of critic iterations per generator iteration. c_lambda: Gradient penalty lambda hyperparameter. generator: The generator model. critic: The discriminator (critic) model. critic_losses: List to store critic loss values. generator_losses: List to store generator loss values. curr_step: The current training step. fixed_latent_space: Fixed latent vectors for generating consistent images. fixed_label: Fixed labels corresponding to the latent vectors. """ def __init__(self, image_size, learning_rate, z_dim, embed_size, num_classes, critic_repeats, feature_gen, feature_critic, c_lambda, beta_1, beta_2, display_step): """Initializes the Conditional WGAN-GP model. Args: image_size: Size of the generated images. learning_rate: Learning rate for the optimizers. z_dim: Dimension of the latent space. embed_size: Size of the embedding for the labels. num_classes: Number of classes for the conditional generation. critic_repeats: Number of critic iterations per generator iteration. feature_gen: Number of features for the generator. feature_critic: Number of features for the critic. c_lambda: Gradient penalty lambda hyperparameter. beta_1: Beta1 parameter for the Adam optimizer. beta_2: Beta2 parameter for the Adam optimizer. display_step: Step interval for displaying generated images. """ super().__init__() self.automatic_optimization = False self.image_size = image_size self.critic_repeats = critic_repeats self.c_lambda = c_lambda self.generator = Generator( embed_size=embed_size, num_classes=num_classes, image_size=image_size, features_generator=feature_gen, input_dim=z_dim, ) self.critic = Discriminator( num_classes=num_classes, image_size=image_size, features_discriminator=feature_critic, ) self.critic_losses = [] self.generator_losses = [] self.curr_step = 0 self.fixed_latent_space = torch.randn(25, z_dim, 1, 1) self.fixed_label = torch.tensor([i % num_classes for i in range(25)]) self.save_hyperparameters() def configure_optimizers(self): """Configures the optimizers for the generator and critic. Returns: A tuple of two Adam optimizers, one for the generator and one for the critic. """ optimizer_g = optim.Adam( self.generator.parameters(), lr=self.hparams.learning_rate, betas=(self.hparams.beta_1, self.hparams.beta_2), ) optimizer_c = optim.Adam( self.critic.parameters(), lr=self.hparams.learning_rate, betas=(self.hparams.beta_1, self.hparams.beta_2), ) return optimizer_g, optimizer_c def on_load_checkpoint(self, checkpoint): """Loads necessary variables from a checkpoint. Args: checkpoint: The checkpoint dictionary. """ if self.current_epoch != 0: self.critic_losses = checkpoint['critic_losses'] self.generator_losses = checkpoint['generator_losses'] self.curr_step = checkpoint['curr_step'] self.fixed_latent_space = checkpoint['fixed_latent_space'] self.fixed_label = checkpoint['fixed_label'] def on_save_checkpoint(self, checkpoint): """Saves necessary variables to a checkpoint. Args: checkpoint: The checkpoint dictionary. """ checkpoint['critic_losses'] = self.critic_losses checkpoint['generator_losses'] = self.generator_losses checkpoint['curr_step'] = self.curr_step checkpoint['fixed_latent_space'] = self.fixed_latent_space checkpoint['fixed_label'] = self.fixed_label def forward(self, noise, labels): """Generates an image given noise and labels. Args: noise: Latent noise vector. labels: Class labels for conditional generation. Returns: Generated image tensor. """ return self.generator(noise, labels) def predict_step(self, noise, labels): """Predicts an image given noise and labels. Args: noise: Latent noise vector. labels: Class labels for conditional generation. Returns: Generated image tensor. """ return self.generator(noise, labels)