|
import fire |
|
import random |
|
from retry.api import retry_call |
|
from tqdm import tqdm |
|
from datetime import datetime |
|
from pathlib import Path |
|
from lightweight_gan import Trainer, NanException |
|
|
|
import torch |
|
import torch.multiprocessing as mp |
|
|
|
import numpy as np |
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than None."""
    return not (val is None)
|
|
|
def default(val, d):
    """Return ``val`` unless it is None, in which case fall back to ``d``."""
    if val is None:
        return d
    return val
|
|
|
def cast_list(el):
    """Pass lists through unchanged; wrap any other value in a one-element list."""
    if isinstance(el, list):
        return el
    return [el]
|
|
|
def timestamped_filename(prefix = 'generated-'):
    """Build a filename of ``prefix`` plus the current local time as MM-DD-YYYY_HH-MM-SS."""
    stamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
    return prefix + stamp
|
|
|
def set_seed(seed):
    """Seed every RNG in play (stdlib, numpy, torch) and pin cuDNN to
    deterministic algorithms so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # disabling benchmark mode prevents cuDNN from autotuning to
    # potentially non-deterministic kernels
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
|
def run_training(model_args, data, load_from, new, num_train_steps, name, seed):
    """Run the training loop for a ``Trainer`` built from ``model_args``.

    Resumes from checkpoint ``load_from`` unless ``new`` is set (which wipes
    prior state), trains until ``model.steps`` reaches ``num_train_steps``,
    and saves a final checkpoint. ``seed`` of None leaves RNG state untouched.
    """
    if seed is not None:
        set_seed(seed)

    model = Trainer(**model_args)

    if new:
        model.clear()
    else:
        model.load(load_from)

    pbar = tqdm(
        initial = model.steps,
        total = num_train_steps,
        mininterval = 10.,
        desc = f'{name}<{data}>'
    )
    G, D, D_aug = model.init_accelerator()

    while model.steps < num_train_steps:
        model.train(G, D, D_aug)
        # tqdm tracks model.steps directly rather than incrementing,
        # since a single train() call may advance multiple steps
        pbar.n = model.steps
        pbar.refresh()
        # only the main process logs, every 50 steps
        if model.accelerator.is_local_main_process and model.steps % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)
|
|
|
def train_from_folder(
    dataset_name = 'huggan/CelebA-faces',
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 256,
    optimizer = 'adam',
    fmap_max = 512,
    transparent = False,
    greyscale = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 10000,
    evaluate_every = 1000,
    generate = False,
    generate_types = ['default', 'ema'],
    generate_interpolation = False,
    aug_test = False,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    attn_res_layers = [32],
    freq_chan_attn = False,
    disc_output_size = 1,
    dual_contrast_loss = False,
    antialias = False,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = None,
    calculate_fid_every = None,
    calculate_fid_num_images = 12800,
    clear_fid_cache = False,
    seed = 42,
    cpu = False,
    mixed_precision = "no",
    show_progress = False,
    wandb = False,
    push_to_hub = False,
    organization_name = None,
):
    """CLI entry point (exposed through python-fire in ``main``).

    Builds a ``model_args`` dict from the CLI flags and then dispatches to
    exactly one mode:

    - ``generate``: load a checkpoint and write sample images, then return.
    - ``generate_interpolation``: load a checkpoint and write an
      interpolation sequence, then return.
    - ``show_progress``: render progress images for a trained model, then
      return.
    - otherwise: hand off to ``run_training`` for a full training run.

    Raises:
        RuntimeError: if ``push_to_hub`` is set while ``name`` is still the
            'default' placeholder.

    NOTE(review): several defaults are mutable lists (``generate_types``,
    ``aug_types``, ``attn_res_layers``). Nothing in this function mutates
    them, and fire replaces them with parsed CLI values, so this is benign
    here — but keep it in mind if the function is called programmatically.
    NOTE(review): ``aug_test``, ``antialias`` (and ``data`` beyond its use
    in the tqdm description) are accepted but not forwarded into
    ``model_args`` — presumably handled (or ignored) upstream; confirm
    against the Trainer signature.
    """
    # Refuse to push a model to the hub under the placeholder name.
    if push_to_hub:
        if name == 'default':
            raise RuntimeError(
                "You've chosen to push to hub, but have left the --name flag as 'default'."
                " You should name your model something other than 'default'!"
            )

    # Fewer, larger tiles for big images; 8x8 grid otherwise.
    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

    # Keyword arguments forwarded verbatim to Trainer(**model_args).
    model_args = dict(
        dataset_name = dataset_name,
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        attn_res_layers = cast_list(attn_res_layers),
        freq_chan_attn = freq_chan_attn,
        disc_output_size = disc_output_size,
        dual_contrast_loss = dual_contrast_loss,
        antialias = antialias,
        image_size = image_size,
        num_image_tiles = num_image_tiles,
        optimizer = optimizer,
        fmap_max = fmap_max,
        transparent = transparent,
        greyscale = greyscale,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        calculate_fid_num_images = calculate_fid_num_images,
        clear_fid_cache = clear_fid_cache,
        cpu = cpu,
        mixed_precision = mixed_precision,
        wandb = wandb,
        push_to_hub = push_to_hub,
        organization_name = organization_name
    )

    # Sample-generation mode: no training, just load and emit images.
    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        checkpoint = model.checkpoint_num
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
        print(f'sample images generated at {dir_result}')
        return

    # Latent-interpolation mode: no training, emit an interpolation sequence.
    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    # Progress-visualization mode: render images across training checkpoints.
    if show_progress:
        model = Trainer(**model_args)
        model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    # Default mode: run (or resume) training.
    run_training(model_args, data, load_from, new, num_train_steps, name, seed)
|
|
|
def main():
    """Console entry point: expose ``train_from_folder`` as a fire CLI."""
    fire.Fire(train_from_folder)
|
|
|
# Run the CLI only when executed as a script, keeping the module import-safe.
if __name__ == "__main__":
    main()