57894-Pix2Pix / train.py
Muhammad Naufal Rizqullah
first commit
ae0af75
import argparse
import lightning as L
from config.core import config
from training.model import Pix2Pix
from training.callbacks import MyCustomSavingCallback
from data.dataloader import FaceToComicDataModule
# Add argparser for config params
parser = argparse.ArgumentParser()
parser.add_argument("--load_checkpoint", action='store_true', help="Load checkpoint if this flag is set. If not set, start training from scratch.")
parser.add_argument("--no_load_checkpoint", action='store_false', dest='load_checkpoint', help="Do not load checkpoint. If set, start training from scratch.")
parser.add_argument("--ckpt_path", type=str, default=config.CKPT_PATH, help="Path to checkpoint file. If load_checkpoint is set, this path will be used to load the checkpoint.")
parser.add_argument("--learning_rate", type=float, default=config.LEARNING_RATE, help="Learning rate for Adam optimizer.")
parser.add_argument("--l1_lambda", type=int, default=config.L1_LAMBDA, help="Scale factor for L1 loss.")
parser.add_argument("--features_discriminator", type=int, nargs='+', default=config.FEATURE_DISCRIMINATOR, help="List of feature sizes for the discriminator network.")
parser.add_argument("--features_generator", type=int, default=config.FEATURE_GENERATOR, help="Feature size for the generator network.")
parser.add_argument("--display_step", type=int, default=config.DISPLAY_STEP, help="Interval of epochs to display loss and save examples.")
parser.add_argument("--num_epoch", type=int, default=config.NUM_EPOCH, help="Number of epochs to train for.")
parser.add_argument("--path_face", type=str, default=config.PATH_FACE, help="Path to folder containing face images.")
parser.add_argument("--path_comic", type=str, default=config.PATH_COMIC, help="Path to folder containing comic images.")
parser.add_argument("--image_size", type=int, default=config.IMAGE_SIZE, help="Size of input images.")
parser.add_argument("--batch_size", type=int, default=config.BATCH_SIZE, help="Batch size for training.")
parser.add_argument("--max_samples", type=int, default=config.MAX_SAMPLES, help="Maximum number of samples to use for training. If set to None, all samples will be used.")
args = parser.parse_args()
config.LOAD_CHECKPOINT = args.load_checkpoint if args.load_checkpoint is not None else config.LOAD_CHECKPOINT
config.CKPT_PATH = args.ckpt_path
config.LEARNING_RATE = args.learning_rate
config.L1_LAMBDA = args.l1_lambda
config.FEATURE_DISCRIMINATOR = args.features_discriminator
config.FEATURE_GENERATOR = args.features_generator
config.DISPLAY_STEP = args.display_step
config.NUM_EPOCH = args.num_epoch
config.PATH_FACE = args.path_face
config.PATH_COMIC = args.path_comic
config.IMAGE_SIZE = args.image_size
config.BATCH_SIZE = args.batch_size
config.MAX_SAMPLES = args.max_samples
# Initialize the Model Lightning
model = Pix2Pix(
in_channels=3,
learning_rate=config.LEARNING_RATE,
l1_lambda=config.L1_LAMBDA,
features_discriminator=config.FEATURE_DISCRIMINATOR,
features_generator=config.FEATURE_GENERATOR,
display_step=config.DISPLAY_STEP,
)
# Setup Trainer
n_log = None
trainer = L.Trainer(
accelerator="auto",
devices="auto",
strategy="auto",
log_every_n_steps=n_log,
max_epochs=config.NUM_EPOCH,
callbacks=[MyCustomSavingCallback()],
default_root_dir="/kaggle/working/",
precision="16-mixed",
# fast_dev_run=True
)
# Lightning DataModule
dm = FaceToComicDataModule(
face_path=config.PATH_FACE,
comic_path=config.PATH_COMIC,
image_size=(config.IMAGE_SIZE, config.IMAGE_SIZE),
batch_size=config.BATCH_SIZE,
max_samples=None
)
# Training set
if config.LOAD_CHECKPOINT:
trainer.fit(model, datamodule=dm, ckpt_path=config.CKPT_PATH)
else:
trainer.fit(model, datamodule=dm)