import torch
import lightning as L
import torch.optim as optim

from models.generator import Generator
from models.discriminator import Discriminator
from utility.helper import initialize_weights, plot_images_from_tensor
from utility.wgan_gp import gradient_penalty, calculate_generator_loss, calculate_critic_loss


class ConditionalWGAN_GP(L.LightningModule):
    def __init__(
        self,
        image_channel,
        label_channel,
        image_size,
        learning_rate,
        z_dim,
        embed_size,
        num_classes,
        critic_repeats,
        feature_gen,
        feature_critic,
        c_lambda,
        beta_1,
        beta_2,
        display_step,
    ):
        super().__init__()
        self.automatic_optimization = False

        self.image_size = image_size
        self.critic_repeats = critic_repeats
        self.c_lambda = c_lambda

        self.generator = Generator(
            embed_size=embed_size,
            num_classes=num_classes,
            image_size=image_size,
            features_generator=feature_gen,
            input_dim=z_dim,
        )
        self.critic = Discriminator(
            num_classes=num_classes,
            embed_size=embed_size,
            image_size=image_size,
            features_discriminator=feature_critic,
            image_channel=image_channel,
            label_channel=label_channel,
        )

        self.critic_losses = []
        self.generator_losses = []
        self.curr_step = 0

        # Fixed noise and labels so generated samples are comparable across steps.
        self.fixed_latent_space = torch.randn(25, z_dim, 1, 1)
        self.fixed_label = torch.tensor([i % num_classes for i in range(25)])

        self.save_hyperparameters()

    def configure_optimizers(self):
        # READ: https://lightning.ai/docs/pytorch/stable/common/optimization.html#use-multiple-optimizers-like-gans
        # READ: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
        # READ: https://lightning.ai/docs/pytorch/stable/model/build_model_advanced.html
        # READ: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.backward
        # READ: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#manual-backward
        optimizer_G = optim.Adam(
            self.generator.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )
        optimizer_C = optim.Adam(
            self.critic.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )
        return optimizer_G, optimizer_C

    def on_load_checkpoint(self, checkpoint):
        # Keys we expect to restore from the checkpoint.
        keys_to_load = ['critic_losses', 'generator_losses', 'curr_step', 'fixed_latent_space', 'fixed_label']

        # Restore each key if it exists in the checkpoint.
        for key in keys_to_load:
            if key in checkpoint:
                setattr(self, key, checkpoint[key])

    def on_save_checkpoint(self, checkpoint):
        # Persist bookkeeping state so training can resume where it left off.
        checkpoint['critic_losses'] = self.critic_losses
        checkpoint['generator_losses'] = self.generator_losses
        checkpoint['curr_step'] = self.curr_step
        checkpoint['fixed_latent_space'] = self.fixed_latent_space
        checkpoint['fixed_label'] = self.fixed_label

    def on_train_start(self):
        if self.current_epoch == 0:
            self.generator.apply(initialize_weights)
            self.critic.apply(initialize_weights)

    def training_step(self, batch, batch_idx):
        # Get the optimizers (manual optimization mode).
        opt_generator, opt_critic = self.optimizers()

        # Get data and labels.
        X, labels = batch

        # Get the current batch size.
        batch_size = X.shape[0]

        ##############################
        # Train Critic ###############
        ##############################
        mean_critic_loss_for_this_iteration = 0
        for _ in range(self.critic_repeats):
            # Clear the critic's gradients.
            opt_critic.zero_grad()

            # Sample noise, shaped (N, z_dim, 1, 1) to match self.fixed_latent_space.
            noise = torch.randn(batch_size, self.hparams.z_dim, 1, 1, device=self.device)

            # Generate fake images.
            fake = self.generator(noise, labels)

            # Get the critic's predictions on the reals and the fakes.
            critic_fake_pred = self.critic(fake.detach(), labels)
            critic_real_pred = self.critic(X, labels)

            # Sample epsilon for the interpolated images used by the gradient penalty.
            epsilon = torch.rand(batch_size, 1, 1, 1, device=self.device, requires_grad=True)

            # Compute the gradient penalty for the critic.
            gp = gradient_penalty(self.critic, labels, X, fake.detach(), epsilon)

            # Compute the full WGAN-GP loss for the critic.
            critic_loss = calculate_critic_loss(
                critic_fake_pred, critic_real_pred, gp, self.c_lambda
            )

            # Keep track of the average critic loss over this batch.
            mean_critic_loss_for_this_iteration += critic_loss.item() / self.critic_repeats

            # Backpropagate the critic loss. retain_graph=True is not needed
            # here because the fake images are detached, which cuts the critic's
            # backward pass off from the generator's graph; it would be
            # required if detach() were not used:
            # self.manual_backward(critic_loss, retain_graph=True)
            self.manual_backward(critic_loss)

            # Update the critic's parameters.
            opt_critic.step()

        ##############################
        # Train Generator ############
        ##############################
        # Clear the generator's gradients.
        opt_generator.zero_grad()

        # Sample fresh noise.
        noise = torch.randn(batch_size, self.hparams.z_dim, 1, 1, device=self.device)

        # Generate fake images.
        fake = self.generator(noise, labels)

        # Get the critic's predictions on the generator's fakes.
        generator_fake_predictions = self.critic(fake, labels)

        # Calculate the generator loss.
        generator_loss = calculate_generator_loss(generator_fake_predictions)

        # Backpropagate the generator loss.
        self.manual_backward(generator_loss)

        # Update the generator's parameters.
        opt_generator.step()

        ##############################
        # Visualization ##############
        ##############################
        if self.curr_step % self.hparams.display_step == 0 and self.curr_step > 0:
            with torch.no_grad():
                fake_images_fixed = self.generator(
                    self.fixed_latent_space.to(self.device),
                    self.fixed_label.to(self.device),
                )
            path_save = f"/kaggle/working/generates/generated-{self.curr_step}-step.png"
            plot_images_from_tensor(
                fake_images_fixed,
                size=(self.hparams.image_channel, self.image_size, self.image_size),
                show=False,
                save_path=path_save,
            )
            plot_images_from_tensor(
                X,
                size=(self.hparams.image_channel, self.image_size, self.image_size),
                show=False,
            )
            print(f" ==== Critic Loss: {mean_critic_loss_for_this_iteration} ==== ")
            print(f" ==== Generator Loss: {generator_loss.item()} ==== ")

        self.curr_step += 1

        ##############################
        # Logging ####################
        ##############################
        # Log the critic and generator losses.
        self.log("critic_loss", mean_critic_loss_for_this_iteration, on_step=False, on_epoch=True, prog_bar=True)
        self.log("generator_loss", generator_loss.item(), on_step=False, on_epoch=True, prog_bar=True)

        # Also keep them in lists so they can be used later for visualization.
        self.critic_losses.append(mean_critic_loss_for_this_iteration)
        self.generator_losses.append(generator_loss.item())

    def forward(self, noise, labels):
        return self.generator(noise, labels)

    def predict_step(self, batch, batch_idx):
        # Lightning passes (batch, batch_idx) to predict_step, so unpack the
        # noise and labels from the batch rather than taking them as arguments.
        noise, labels = batch
        return self(noise, labels)
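
# ---------------------------------------------------------------------------
# For reference, a sketch of the standard WGAN-GP quantities the imported
# helpers are expected to compute (this mirrors Gulrajani et al., 2017; the
# actual implementations live in utility/wgan_gp.py and may differ in detail):
#
#   x_hat          = epsilon * real + (1 - epsilon) * fake
#   gp             = mean((||grad_{x_hat} critic(x_hat)||_2 - 1) ** 2)
#   critic_loss    = mean(critic(fake)) - mean(critic(real)) + c_lambda * gp
#   generator_loss = -mean(critic(fake))
# ---------------------------------------------------------------------------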
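
# ---------------------------------------------------------------------------
# Usage sketch. The hyperparameter values below are illustrative assumptions
# (a 64x64 RGB, 10-class setup with the Adam betas from the WGAN-GP paper),
# not values prescribed by this repo, and the random TensorDataset is a
# stand-in for a real labelled dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    model = ConditionalWGAN_GP(
        image_channel=3, label_channel=1, image_size=64, learning_rate=2e-4,
        z_dim=100, embed_size=100, num_classes=10, critic_repeats=5,
        feature_gen=64, feature_critic=64, c_lambda=10,
        beta_1=0.0, beta_2=0.9, display_step=500,
    )

    # Stand-in data: 128 random 64x64 RGB images with random class labels.
    images = torch.randn(128, 3, 64, 64)
    targets = torch.randint(0, 10, (128,))
    loader = DataLoader(TensorDataset(images, targets), batch_size=32)

    trainer = L.Trainer(max_epochs=1, accelerator="auto")
    trainer.fit(model, loader)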