# Muhammad Naufal Rizqullah - first commit (ae0af75)
import torch
import torch.nn as nn
import lightning as L
import torch.optim as optim
from models.generator import Generator
from models.discriminator import Discriminator
from utility.helper import save_some_examples
class Pix2Pix(L.LightningModule):
    """Generator/discriminator pair trained adversarially with manual optimization.

    The generator maps an input batch ``X`` to a prediction ``y_fake``; the
    discriminator scores ``(X, y)`` and ``(X, y_fake)`` pairs.  Both networks
    are updated inside a single :meth:`training_step`, discriminator first,
    which is why ``automatic_optimization`` is switched off.

    Besides the two networks the module tracks the per-step loss history and
    a running step counter; all three are persisted in checkpoints.
    """

    def __init__(self, in_channels, learning_rate, l1_lambda, features_generator, features_discriminator, display_step):
        """Build the networks, the loss criteria and the bookkeeping state.

        Args:
            in_channels: Forwarded as ``in_channels`` to both ``Generator``
                and ``Discriminator``.
            learning_rate: Learning rate used for both Adam optimizers.
            l1_lambda: Multiplier applied to the L1 term of the generator loss.
            features_generator: Forwarded as ``features`` to ``Generator``.
            features_discriminator: Forwarded as ``features`` to
                ``Discriminator``.
            display_step: Example outputs are saved every ``display_step``
                training steps (never at step 0).  A non-positive value
                disables saving.
        """
        super().__init__()

        # Two optimizers are stepped by hand inside ``training_step``.
        self.automatic_optimization = False

        self.gen = Generator(
            in_channels=in_channels,
            features=features_generator,
        )
        self.disc = Discriminator(
            in_channels=in_channels,
            features=features_discriminator,
        )

        # ``loss_fn`` scores the discriminator, ``bce`` the generator's
        # adversarial term.  They are the same criterion; both attribute
        # names are kept so existing external references keep working.
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.l1_loss = nn.L1Loss()

        # Per-step loss history and step counter (saved with checkpoints).
        self.discriminator_losses = []
        self.generator_losses = []
        self.curr_step = 0

        self.save_hyperparameters()

    def configure_optimizers(self):
        """Return ``(generator_optimizer, discriminator_optimizer)``.

        Both are Adam with ``lr=hparams.learning_rate`` and
        ``betas=(0.5, 0.999)``.  The order matters: ``training_step``
        unpacks ``self.optimizers()`` in this same order.
        """
        optimizer_G = optim.Adam(
            self.gen.parameters(),
            lr=self.hparams.learning_rate,
            betas=(0.5, 0.999),
        )
        optimizer_D = optim.Adam(
            self.disc.parameters(),
            lr=self.hparams.learning_rate,
            betas=(0.5, 0.999),
        )
        return optimizer_G, optimizer_D

    def on_load_checkpoint(self, checkpoint):
        """Restore the loss history and step counter from ``checkpoint``.

        Keys that are absent (e.g. a checkpoint written by other code) are
        skipped, leaving the values set in ``__init__`` untouched.
        """
        keys_to_load = ('discriminator_losses', 'generator_losses', 'curr_step')
        for key in keys_to_load:
            if key in checkpoint:
                setattr(self, key, checkpoint[key])

    def on_save_checkpoint(self, checkpoint):
        """Add the loss history and step counter to ``checkpoint`` in place."""
        checkpoint['discriminator_losses'] = self.discriminator_losses
        checkpoint['generator_losses'] = self.generator_losses
        checkpoint['curr_step'] = self.curr_step

    def _discriminator_step(self, X, y, y_fake, opt_discriminator):
        """Run one discriminator update and return its loss as a float."""
        D_real = self.disc(X, y)
        # Detach so this backward pass does not reach the generator.
        D_fake = self.disc(X, y_fake.detach())

        D_real_loss = self.loss_fn(D_real, torch.ones_like(D_real))
        D_fake_loss = self.loss_fn(D_fake, torch.zeros_like(D_fake))
        D_loss = (D_real_loss + D_fake_loss) / 2

        opt_discriminator.zero_grad()
        self.manual_backward(D_loss)
        opt_discriminator.step()

        # Read the scalar once; every ``.item()`` call is a separate sync.
        return D_loss.item()

    def _generator_step(self, X, y, y_fake, opt_generator):
        """Run one generator update and return its loss as a float."""
        # Score the fakes again, this time with the generator graph attached
        # and with the discriminator weights that were just updated.
        D_fake = self.disc(X, y_fake)
        G_fake_loss = self.bce(D_fake, torch.ones_like(D_fake))
        L1 = self.l1_loss(y_fake, y) * self.hparams.l1_lambda
        G_loss = G_fake_loss + L1

        opt_generator.zero_grad()
        self.manual_backward(G_loss)
        opt_generator.step()

        return G_loss.item()

    def training_step(self, batch, batch_idx):
        """Update the discriminator, then the generator, on one batch.

        Args:
            batch: An ``(X, y)`` pair; ``X`` is fed to the generator and
                ``y`` is the target its output is compared against.
            batch_idx: Index of the batch within the epoch (unused).

        Logs ``D_loss``, ``G_loss`` and ``Current_Step`` as epoch-level
        metrics, appends both losses to the history lists, optionally saves
        example outputs, and increments ``curr_step``.  Returns ``None``.
        """
        # Same order as returned by ``configure_optimizers``.
        opt_generator, opt_discriminator = self.optimizers()

        X, y = batch
        y_fake = self.gen(X)

        # Train Discriminator
        d_loss_value = self._discriminator_step(X, y, y_fake, opt_discriminator)
        self.log("D_loss", d_loss_value, on_step=False, on_epoch=True, prog_bar=True)
        self.discriminator_losses.append(d_loss_value)

        # Train Generator
        g_loss_value = self._generator_step(X, y, y_fake, opt_generator)
        self.log("G_loss", g_loss_value, on_step=False, on_epoch=True, prog_bar=True)
        self.generator_losses.append(g_loss_value)

        # Epoch-reduced metrics must be floating point; casting here avoids
        # Lightning's implicit-conversion warning for an int counter.
        self.log("Current_Step", float(self.curr_step), on_step=False, on_epoch=True, prog_bar=True)

        # Visualize: every ``display_step`` steps, skipping step 0.  A
        # non-positive ``display_step`` disables this instead of raising
        # ZeroDivisionError on the modulo.
        display_step = self.hparams.display_step
        if display_step > 0 and self.curr_step > 0 and self.curr_step % display_step == 0:
            save_some_examples(self.gen, batch, self.current_epoch)

        self.curr_step += 1