import argparse
import lightning as L

from config.core import config
from training.model import Pix2Pix
from training.callbacks import MyCustomSavingCallback
from data.dataloader import FaceToComicDataModule


# CLI arguments that override the config defaults
parser = argparse.ArgumentParser(description="Train the Pix2Pix face-to-comic model.")
parser.add_argument("--load_checkpoint", action="store_true", dest="load_checkpoint",
                    help="Resume training from the checkpoint at --ckpt_path.")
parser.add_argument("--no_load_checkpoint", action="store_false", dest="load_checkpoint",
                    help="Ignore any checkpoint and start training from scratch.")
parser.set_defaults(load_checkpoint=config.LOAD_CHECKPOINT)

parser.add_argument("--ckpt_path", type=str, default=config.CKPT_PATH, help="Path to checkpoint file. If load_checkpoint is set, this path will be used to load the checkpoint.")
parser.add_argument("--learning_rate", type=float, default=config.LEARNING_RATE, help="Learning rate for Adam optimizer.")
parser.add_argument("--l1_lambda", type=int, default=config.L1_LAMBDA, help="Scale factor for L1 loss.")
parser.add_argument("--features_discriminator", type=int, nargs='+', default=config.FEATURE_DISCRIMINATOR, help="List of feature sizes for the discriminator network.")
parser.add_argument("--features_generator", type=int, default=config.FEATURE_GENERATOR, help="Feature size for the generator network.")
parser.add_argument("--display_step", type=int, default=config.DISPLAY_STEP, help="Interval of epochs to display loss and save examples.")
parser.add_argument("--num_epoch", type=int, default=config.NUM_EPOCH, help="Number of epochs to train for.")
parser.add_argument("--path_face", type=str, default=config.PATH_FACE, help="Path to folder containing face images.")
parser.add_argument("--path_comic", type=str, default=config.PATH_COMIC, help="Path to folder containing comic images.")
parser.add_argument("--image_size", type=int, default=config.IMAGE_SIZE, help="Size of input images.")
parser.add_argument("--batch_size", type=int, default=config.BATCH_SIZE, help="Batch size for training.")
parser.add_argument("--max_samples", type=int, default=config.MAX_SAMPLES, help="Maximum number of samples to use for training. If set to None, all samples will be used.")

args = parser.parse_args()

# Propagate the parsed CLI values back into the shared config object
config.LOAD_CHECKPOINT = args.load_checkpoint
config.CKPT_PATH = args.ckpt_path
config.LEARNING_RATE = args.learning_rate
config.L1_LAMBDA = args.l1_lambda
config.FEATURE_DISCRIMINATOR = args.features_discriminator
config.FEATURE_GENERATOR = args.features_generator
config.DISPLAY_STEP = args.display_step
config.NUM_EPOCH = args.num_epoch
config.PATH_FACE = args.path_face
config.PATH_COMIC = args.path_comic
config.IMAGE_SIZE = args.image_size
config.BATCH_SIZE = args.batch_size
config.MAX_SAMPLES = args.max_samples

# Initialize the Pix2Pix LightningModule
model = Pix2Pix(
    in_channels=3,
    learning_rate=config.LEARNING_RATE,
    l1_lambda=config.L1_LAMBDA,
    features_discriminator=config.FEATURE_DISCRIMINATOR,
    features_generator=config.FEATURE_GENERATOR,
    display_step=config.DISPLAY_STEP,
)
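
# Note on l1_lambda: in Pix2Pix (Isola et al., 2017) the generator objective
# combines an adversarial term with an L1 reconstruction term,
#   L_G = L_cGAN(G, D) + lambda * L1(G(x), y),
# and the paper uses lambda = 100 so the reconstruction term dominates.
# The l1_lambda passed above presumably plays the role of lambda here.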

# Set up the Trainer
trainer = L.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    log_every_n_steps=None,  # None falls back to Lightning's default interval
    max_epochs=config.NUM_EPOCH,
    callbacks=[MyCustomSavingCallback()],
    default_root_dir="/kaggle/working/",
    precision="16-mixed",  # mixed precision to reduce GPU memory usage
    # fast_dev_run=True  # uncomment for a quick one-batch sanity check
)

# Lightning DataModule pairing face photos with their comic counterparts
dm = FaceToComicDataModule(
    face_path=config.PATH_FACE,
    comic_path=config.PATH_COMIC,
    image_size=(config.IMAGE_SIZE, config.IMAGE_SIZE),
    batch_size=config.BATCH_SIZE,
    max_samples=config.MAX_SAMPLES,  # None uses every available sample
)

# Start training, resuming from the checkpoint when requested
if config.LOAD_CHECKPOINT:
    trainer.fit(model, datamodule=dm, ckpt_path=config.CKPT_PATH)
else:
    trainer.fit(model, datamodule=dm)
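
# After training, the saved checkpoint can be restored for inference. A minimal
# sketch, assuming Pix2Pix calls self.save_hyperparameters() in __init__
# (otherwise the constructor arguments must be passed explicitly):
#
#   trained = Pix2Pix.load_from_checkpoint(config.CKPT_PATH)
#   trained.eval()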