Spaces:
Sleeping
Sleeping
File size: 5,027 Bytes
21a662b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import torch
import torch.optim as optim
import lightning as L
from .base import Discriminator, Generator
class ConditionalWGAN_GP(L.LightningModule):
    """Conditional WGAN-GP implementation using PyTorch Lightning.

    Lightning's automatic optimization is disabled
    (``automatic_optimization = False``), so the generator and critic
    optimizers returned by ``configure_optimizers`` must be stepped manually
    by the training logic.

    Attributes:
        image_size: Size of the generated images.
        critic_repeats: Number of critic iterations per generator iteration.
        c_lambda: Gradient penalty lambda hyperparameter.
        generator: The generator model.
        critic: The discriminator (critic) model.
        critic_losses: List to store critic loss values.
        generator_losses: List to store generator loss values.
        curr_step: The current training step.
        fixed_latent_space: Fixed latent vectors of shape (25, z_dim, 1, 1)
            for generating consistent images. Registered as a non-persistent
            buffer so it moves with the module across devices.
        fixed_label: Fixed labels (0, 1, ..., num_classes - 1, 0, ...) for the
            25 latent vectors; also a non-persistent buffer.
    """

    def __init__(self, image_size, learning_rate, z_dim, embed_size, num_classes,
                 critic_repeats, feature_gen, feature_critic, c_lambda, beta_1,
                 beta_2, display_step):
        """Initializes the Conditional WGAN-GP model.

        Args:
            image_size: Size of the generated images.
            learning_rate: Learning rate for the optimizers.
            z_dim: Dimension of the latent space.
            embed_size: Size of the embedding for the labels.
            num_classes: Number of classes for the conditional generation.
            critic_repeats: Number of critic iterations per generator iteration.
            feature_gen: Number of features for the generator.
            feature_critic: Number of features for the critic.
            c_lambda: Gradient penalty lambda hyperparameter.
            beta_1: Beta1 parameter for the Adam optimizer.
            beta_2: Beta2 parameter for the Adam optimizer.
            display_step: Step interval for displaying generated images.
        """
        super().__init__()
        # Two optimizers are stepped by hand, so Lightning must not run its own
        # backward/step cycle.
        self.automatic_optimization = False
        self.image_size = image_size
        self.critic_repeats = critic_repeats
        self.c_lambda = c_lambda
        self.generator = Generator(
            embed_size=embed_size,
            num_classes=num_classes,
            image_size=image_size,
            features_generator=feature_gen,
            input_dim=z_dim,
        )
        self.critic = Discriminator(
            num_classes=num_classes,
            image_size=image_size,
            features_discriminator=feature_critic,
        )
        self.critic_losses = []
        self.generator_losses = []
        self.curr_step = 0
        # Non-persistent buffers follow the module through .to()/Trainer device
        # placement, yet stay out of state_dict so existing checkpoints still
        # load strictly. Their values are checkpointed via on_save_checkpoint.
        self.register_buffer(
            "fixed_latent_space", torch.randn(25, z_dim, 1, 1), persistent=False
        )
        self.register_buffer(
            "fixed_label",
            torch.tensor([i % num_classes for i in range(25)]),
            persistent=False,
        )
        self.save_hyperparameters()

    def configure_optimizers(self):
        """Configures the optimizers for the generator and critic.

        Returns:
            A tuple of two Adam optimizers, one for the generator and one for
            the critic, in that order.
        """
        optimizer_g = optim.Adam(
            self.generator.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )
        optimizer_c = optim.Adam(
            self.critic.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )
        return optimizer_g, optimizer_c

    def on_load_checkpoint(self, checkpoint):
        """Restores the extra training state written by ``on_save_checkpoint``.

        Each entry is restored only when present in ``checkpoint``; missing
        entries keep their freshly initialized values.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        # Decide by key presence, not by ``self.current_epoch``: the latter is
        # 0 whenever no trainer is attached (e.g. ``load_from_checkpoint``), so
        # it cannot tell whether the checkpoint carries this state.
        self.critic_losses = checkpoint.get('critic_losses', self.critic_losses)
        self.generator_losses = checkpoint.get(
            'generator_losses', self.generator_losses
        )
        self.curr_step = checkpoint.get('curr_step', self.curr_step)
        # Keep the restored tensors on whatever device the buffers live on now.
        if 'fixed_latent_space' in checkpoint:
            self.fixed_latent_space = checkpoint['fixed_latent_space'].to(
                self.fixed_latent_space.device
            )
        if 'fixed_label' in checkpoint:
            self.fixed_label = checkpoint['fixed_label'].to(
                self.fixed_label.device
            )

    def on_save_checkpoint(self, checkpoint):
        """Saves the extra training state into the checkpoint.

        Args:
            checkpoint: The checkpoint dictionary; modified in place.
        """
        checkpoint['critic_losses'] = self.critic_losses
        checkpoint['generator_losses'] = self.generator_losses
        checkpoint['curr_step'] = self.curr_step
        checkpoint['fixed_latent_space'] = self.fixed_latent_space
        checkpoint['fixed_label'] = self.fixed_label

    def forward(self, noise, labels):
        """Generates an image given noise and labels.

        Args:
            noise: Latent noise vector.
            labels: Class labels for conditional generation.

        Returns:
            Generated image tensor.
        """
        return self.generator(noise, labels)

    def predict_step(self, noise, labels=None, dataloader_idx=0):
        """Generates images for a prediction batch.

        Two call styles are supported:

        * Direct: ``predict_step(noise, labels)`` with tensors.
        * ``Trainer.predict``: Lightning calls
          ``predict_step(batch, batch_idx[, dataloader_idx])``. In that case
          ``batch`` is expected to be a ``(noise, labels)`` sequence and the
          batch index is ignored.

        Args:
            noise: Latent noise tensor, or a ``(noise, labels)`` batch.
            labels: Class labels, or the batch index when called by Lightning.
            dataloader_idx: Index of the predict dataloader (unused).

        Returns:
            Generated image tensor.
        """
        if isinstance(noise, (tuple, list)):
            # Lightning passes the whole batch first and batch_idx second.
            noise, labels = noise[0], noise[1]
        return self(noise, labels)
|