Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import lightning as L | |
import torch.optim as optim | |
from models.generator import Generator | |
from models.discriminator import Discriminator | |
from utility.helper import save_some_examples | |
class Pix2Pix(L.LightningModule): | |
def __init__(self, in_channels, learning_rate, l1_lambda, features_generator, features_discriminator, display_step): | |
super().__init__() | |
self.automatic_optimization = False | |
self.gen = Generator( | |
in_channels=in_channels, | |
features=features_generator | |
) | |
self.disc = Discriminator( | |
in_channels=in_channels, | |
features=features_discriminator | |
) | |
self.loss_fn = nn.BCEWithLogitsLoss() | |
self.discriminator_losses = [] | |
self.generator_losses = [] | |
self.curr_step = 0 | |
self.bce = nn.BCEWithLogitsLoss() | |
self.l1_loss = nn.L1Loss() | |
self.save_hyperparameters() | |
def configure_optimizers(self): | |
optimizer_G = optim.Adam(self.gen.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.999)) | |
optimizer_D = optim.Adam(self.disc.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.999)) | |
return optimizer_G, optimizer_D | |
def on_load_checkpoint(self, checkpoint): | |
# List of keys that you expect to load from the checkpoint | |
keys_to_load = ['discriminator_losses', 'generator_losses', 'curr_step'] | |
# Iterate over the keys and load them if they exist in the checkpoint | |
for key in keys_to_load: | |
if key in checkpoint: | |
setattr(self, key, checkpoint[key]) | |
def on_save_checkpoint(self, checkpoint): | |
# Save the current state of the model | |
checkpoint['discriminator_losses'] = self.discriminator_losses | |
checkpoint['generator_losses'] = self.generator_losses | |
checkpoint['curr_step'] = self.curr_step | |
def training_step(self, batch, batch_idx): | |
# Get the Optimizers | |
opt_generator, opt_discriminator = self.optimizers() | |
X, y = batch | |
# Train Discriminator | |
y_fake = self.gen(X) | |
D_real = self.disc(X, y) | |
D_fake = self.disc(X, y_fake.detach()) | |
D_real_loss = self.loss_fn(D_real, torch.ones_like(D_real)) | |
D_fake_loss = self.loss_fn(D_fake, torch.zeros_like(D_fake)) | |
D_loss = (D_real_loss + D_fake_loss) / 2 | |
opt_discriminator.zero_grad() | |
self.manual_backward(D_loss) | |
opt_discriminator.step() | |
self.log("D_loss", D_loss.item(), on_step=False, on_epoch=True, prog_bar=True) | |
self.discriminator_losses.append(D_loss.item()) | |
# Train Generator | |
D_fake = self.disc(X, y_fake) | |
G_fake_loss = self.bce(D_fake, torch.ones_like(D_fake)) | |
L1 = self.l1_loss(y_fake, y) * self.hparams.l1_lambda | |
G_loss = G_fake_loss + L1 | |
opt_generator.zero_grad() | |
self.manual_backward(G_loss) | |
opt_generator.step() | |
self.log("G_loss", G_loss.item(), on_step=False, on_epoch=True, prog_bar=True) | |
self.generator_losses.append(G_loss.item()) | |
self.log("Current_Step", self.curr_step, on_step=False, on_epoch=True, prog_bar=True) | |
# Visualize | |
if self.curr_step % self.hparams.display_step == 0 and self.curr_step > 0: | |
save_some_examples(self.gen, batch, self.current_epoch) | |
self.curr_step += 1 | |