import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from pathlib import Path
from lightweight_gan import Trainer, NanException

import torch
import torch.multiprocessing as mp

import numpy as np

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_list(el):
    return el if isinstance(el, list) else [el]

def timestamped_filename(prefix = 'generated-'):
    now = datetime.now()
    timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
    return f'{prefix}{timestamp}'
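# e.g. timestamped_filename() -> 'generated-04-23-2022_18-30-05' (month-day-year_hour-minute-second)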
def set_seed(seed):
    torch.manual_seed(seed)
    # trade cuDNN speed for run-to-run determinism
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
def run_training(model_args, data, load_from, new, num_train_steps, name, seed):
    if seed is not None:
        set_seed(seed)

    model = Trainer(**model_args)

    # resume from an existing checkpoint unless a fresh run was requested
    if not new:
        model.load(load_from)
    else:
        model.clear()

    progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')

    # wrap the generator, discriminator and augmented discriminator with accelerate
    G, D, D_aug = model.init_accelerator()
    # model.set_data_src(data)

    while model.steps < num_train_steps:
        # retry_call(model.train, tries=3, exceptions=NanException)
        model.train(G, D, D_aug)
        progress_bar.n = model.steps
        progress_bar.refresh()

        # only the local main process logs, once every 50 steps
        if model.accelerator.is_local_main_process and model.steps % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)
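# A minimal sketch of calling run_training directly (hypothetical values; in
# lightweight_gan, load_from = -1 conventionally resumes from the most recent
# checkpoint):
#
#   run_training(model_args, './data', load_from = -1, new = False,
#                num_train_steps = 150000, name = 'default', seed = 42)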
def train_from_folder(
    dataset_name = 'huggan/CelebA-faces',
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 256,
    optimizer = 'adam',
    fmap_max = 512,
    transparent = False,
    greyscale = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 10000,
    evaluate_every = 1000,
    generate = False,
    generate_types = ['default', 'ema'],
    generate_interpolation = False,
    aug_test = False,
    aug_prob = None,
    aug_types = ['cutout', 'translation'],
    dataset_aug_prob = 0.,
    attn_res_layers = [32],
    freq_chan_attn = False,
    disc_output_size = 1,
    dual_contrast_loss = False,
    antialias = False,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = None,
    calculate_fid_every = None,
    calculate_fid_num_images = 12800,
    clear_fid_cache = False,
    seed = 42,
    cpu = False,
    mixed_precision = "no",
    show_progress = False,
    wandb = False,
    push_to_hub = False,
    organization_name = None,
):
    if push_to_hub:
        if name == 'default':
            raise RuntimeError(
                "You've chosen to push to hub, but have left the --name flag as 'default'."
                " You should name your model something other than 'default'!"
            )

    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

    model_args = dict(
        dataset_name = dataset_name,
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        attn_res_layers = cast_list(attn_res_layers),
        freq_chan_attn = freq_chan_attn,
        disc_output_size = disc_output_size,
        dual_contrast_loss = dual_contrast_loss,
        antialias = antialias,
        image_size = image_size,
        num_image_tiles = num_image_tiles,
        optimizer = optimizer,
        fmap_max = fmap_max,
        transparent = transparent,
        greyscale = greyscale,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        calculate_fid_num_images = calculate_fid_num_images,
        clear_fid_cache = clear_fid_cache,
        cpu = cpu,
        mixed_precision = mixed_precision,
        wandb = wandb,
        push_to_hub = push_to_hub,
        organization_name = organization_name
    )

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        checkpoint = model.checkpoint_num
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
        print(f'sample images generated at {dir_result}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    if show_progress:
        model = Trainer(**model_args)
        model.show_progress(num_images = num_image_tiles, types = generate_types)
        return

    run_training(model_args, data, load_from, new, num_train_steps, name, seed)
def main():
    fire.Fire(train_from_folder)

if __name__ == "__main__":
    main()
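# Example CLI usage (fire exposes every keyword argument of train_from_folder
# as a command-line flag; the script name and flag values below are
# illustrative, assuming this file is saved as cli.py):
#
#   # start a fresh run on the default huggan/CelebA-faces dataset
#   python cli.py --name my-model --new --image_size 256 --batch_size 10
#
#   # resume the latest checkpoint and generate sample images
#   python cli.py --name my-model --generate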