# stylegan2-flax-tpu / training.py
# Uploaded to Hugging Face by akhaliq (HF staff), commit 81170fd ("add files"), 18.8 kB.
import jax
import jax.numpy as jnp
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
import optax
import numpy as np
import functools
import wandb
import time
import stylegan2
import data_pipeline
import checkpoint
import training_utils
import training_steps
from fid import FID
import logging
# Module-level logger; train_and_evaluate reports progress, timings and FID through it.
logger = logging.getLogger(__name__)
def tree_shape(item):
    """Return a pytree with the structure of ``item`` whose leaves are the leaf shapes.

    Args:
        item: Any JAX pytree (nested dicts/lists/tuples, FrozenDicts, ...) whose
            leaves expose a ``.shape`` attribute, e.g. jax or numpy arrays.

    Returns:
        A pytree of the same structure with every leaf replaced by ``leaf.shape``.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level alias
    # jax.tree_map is deprecated and has been removed in recent JAX releases.
    return jax.tree_util.tree_map(lambda leaf: leaf.shape, item)
def _block_until_ready(tree):
    """Wait until every device array in ``tree`` has been computed; return ``tree``.

    JAX dispatches work asynchronously, so wall-clock time taken around a
    pmapped call only measures dispatch unless its outputs are waited on.
    Leaves without a ``block_until_ready`` method (e.g. ``None``, Python
    scalars, numpy arrays) are skipped.
    """
    for leaf in jax.tree_util.tree_leaves(tree):
        if hasattr(leaf, 'block_until_ready'):
            leaf.block_until_ready()
    return tree


def train_and_evaluate(config):
    """Train a StyleGAN2 generator/discriminator pair (plus an EMA generator) with pmap.

    The function runs these stages in order:

    1. Build the training data pipeline.
    2. Pick the compute dtype.
    3. Initialize the mapping, synthesis, discriminator and EMA-generator
       networks, then their Adam optimizers and train states.
    4. Optionally restore state from a checkpoint.
    5. Replicate the states across local devices.
    6. Run the training loop for ``config.num_epochs`` epochs.

    Inside the loop it:

    - logs losses and timings every step (to Weights & Biases when
      ``config.wandb`` is set, on process 0 only);
    - logs to the console every ``config.log_every`` steps;
    - either saves a checkpoint every ``config.save_every`` steps (when
      ``config.disable_fid``), or evaluates FID every ``config.eval_fid_every``
      steps on process 0 and checkpoints on a new best score.

    Args:
        config: Run configuration. Attributes read here include ``data_dir``,
            ``resolution``, ``img_channels``, ``c_dim``, ``z_dim``, ``w_dim``,
            ``fmap_base``, ``mbstd_group_size``, ``batch_size`` (per device),
            ``learning_rate``, ``random_seed``, ``mixed_precision``, ``bf16``,
            ``resume_run_id``, ``load_from_pkl``, ``ckpt_dir``, ``num_epochs``,
            ``num_prefetch``, ``mixing_prob``, ``D_reg_interval``,
            ``G_reg_interval``, ``save_every``, ``eval_fid_every``,
            ``disable_fid``, ``log_every``, ``generate_samples_every`` and
            ``wandb``. When resuming (``resume_run_id`` set), this object is
            replaced by the config stored in the checkpoint.
    """
    num_devices = jax.device_count()              # devices across all hosts
    num_local_devices = jax.local_device_count()  # devices attached to this host
    num_workers = jax.process_count()

    # --------------------------------------
    # Data
    # --------------------------------------
    ds_train, dataset_info = data_pipeline.get_data(data_dir=config.data_dir,
                                                    img_size=config.resolution,
                                                    img_channels=config.img_channels,
                                                    num_classes=config.c_dim,
                                                    num_local_devices=num_local_devices,
                                                    batch_size=config.batch_size)

    # --------------------------------------
    # Seeding and Precision
    # --------------------------------------
    rng = jax.random.PRNGKey(config.random_seed)

    if config.mixed_precision:
        dtype = jnp.float16
    elif config.bf16:
        dtype = jnp.bfloat16
    else:
        dtype = jnp.float32
    logger.info(f'Running on dtype {dtype}')

    platform = jax.local_devices()[0].platform
    if config.mixed_precision and platform == 'gpu':
        # fp16 on GPU: one dynamic loss scaler per update kind. clip_conv and
        # num_fp16_res are forwarded to the networks (presumably the conv clamp
        # value and the number of fp16 resolutions, as in StyleGAN2-ADA).
        dynamic_scale_G_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_G_reg = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_reg = dynamic_scale_lib.DynamicScale()
        clip_conv = 256
        num_fp16_res = 4
    else:
        dynamic_scale_G_main = None
        dynamic_scale_D_main = None
        dynamic_scale_G_reg = None
        dynamic_scale_D_reg = None
        clip_conv = None
        num_fp16_res = 0

    # --------------------------------------
    # Initialize Models
    # --------------------------------------
    logger.info('Initialize models...')
    num_ws = int(np.log2(config.resolution)) * 2 - 3

    # Generator initialization for training. Every network below gets its own
    # freshly split init key: reusing one key across modules could hand
    # identically named submodules identical initial weights.
    start_mn = time.time()
    logger.info("Creating MappingNetwork...")
    mapping_net = stylegan2.MappingNetwork(z_dim=config.z_dim,
                                           c_dim=config.c_dim,
                                           w_dim=config.w_dim,
                                           num_ws=num_ws,
                                           num_layers=8,
                                           dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    mapping_net_vars = mapping_net.init(init_rng,
                                        jnp.ones((1, config.z_dim)),
                                        jnp.ones((1, config.c_dim)))
    mapping_net_params, moving_stats = mapping_net_vars['params'], mapping_net_vars['moving_stats']
    logger.info(f"MappingNetwork took {time.time() - start_mn:.2f}s")

    logger.info("Creating SynthesisNetwork...")
    start_sn = time.time()
    synthesis_net = stylegan2.SynthesisNetwork(resolution=config.resolution,
                                               num_channels=config.img_channels,
                                               w_dim=config.w_dim,
                                               fmap_base=config.fmap_base,
                                               num_fp16_res=num_fp16_res,
                                               clip_conv=clip_conv,
                                               dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    synthesis_net_vars = synthesis_net.init(init_rng,
                                            jnp.ones((1, mapping_net.num_ws, config.w_dim)))
    synthesis_net_params, noise_consts = synthesis_net_vars['params'], synthesis_net_vars['noise_consts']
    logger.info(f"SynthesisNetwork took {time.time() - start_sn:.2f}s")

    params_G = frozen_dict.FrozenDict(
        {'mapping': mapping_net_params,
         'synthesis': synthesis_net_params}
    )

    # Discriminator initialization for training
    logger.info("Creating Discriminator...")
    start_d = time.time()
    discriminator = stylegan2.Discriminator(resolution=config.resolution,
                                            num_channels=config.img_channels,
                                            c_dim=config.c_dim,
                                            mbstd_group_size=config.mbstd_group_size,
                                            num_fp16_res=num_fp16_res,
                                            clip_conv=clip_conv,
                                            dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    params_D = discriminator.init(init_rng,
                                  jnp.ones((1, config.resolution, config.resolution, config.img_channels)),
                                  jnp.ones((1, config.c_dim)))
    logger.info(f"Discriminator took {time.time() - start_d:.2f}s")

    # Exponential average Generator initialization
    logger.info("Creating Generator EMA...")
    start_g = time.time()
    generator_ema = stylegan2.Generator(resolution=config.resolution,
                                        num_channels=config.img_channels,
                                        z_dim=config.z_dim,
                                        c_dim=config.c_dim,
                                        w_dim=config.w_dim,
                                        num_ws=num_ws,
                                        num_mapping_layers=8,
                                        fmap_base=config.fmap_base,
                                        num_fp16_res=num_fp16_res,
                                        clip_conv=clip_conv,
                                        dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    params_ema_G = generator_ema.init(init_rng,
                                      jnp.ones((1, config.z_dim)),
                                      jnp.ones((1, config.c_dim)))
    logger.info(f"Took {time.time() - start_g:.2f}s")

    # --------------------------------------
    # Initialize States and Optimizers
    # --------------------------------------
    logger.info('Initialize states...')
    tx_G = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
    tx_D = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)

    state_G = training_utils.TrainStateG.create(apply_fn=None,
                                                apply_mapping=mapping_net.apply,
                                                apply_synthesis=synthesis_net.apply,
                                                params=params_G,
                                                moving_stats=moving_stats,
                                                noise_consts=noise_consts,
                                                tx=tx_G,
                                                dynamic_scale_main=dynamic_scale_G_main,
                                                dynamic_scale_reg=dynamic_scale_G_reg,
                                                epoch=0)

    state_D = training_utils.TrainStateD.create(apply_fn=discriminator.apply,
                                                params=params_D,
                                                tx=tx_D,
                                                dynamic_scale_main=dynamic_scale_D_main,
                                                dynamic_scale_reg=dynamic_scale_D_reg,
                                                epoch=0)

    # Copy over the parameters from the training generator to the ema generator
    params_ema_G = training_utils.update_generator_ema(state_G, params_ema_G, config, ema_beta=0)

    # Running mean of path length for path length regularization
    pl_mean = jnp.zeros((), dtype=dtype)

    step = 0
    epoch_offset = 0
    best_fid_score = np.inf
    ckpt_path = None

    if config.resume_run_id is not None:
        # Resume training from existing checkpoint (restores counters, states
        # and the config the run was started with).
        ckpt_path = checkpoint.get_latest_checkpoint(config.ckpt_dir)
        logger.info(f'Resume training from checkpoint: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        step = ckpt['step']
        epoch_offset = ckpt['epoch']
        best_fid_score = ckpt['fid_score']
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']
        config = ckpt['config']
    elif config.load_from_pkl is not None:
        # Load model state only and start a new run (counters stay at zero).
        ckpt_path = config.load_from_pkl
        logger.info(f'Load model state from from : {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']

    # Replicate states across devices
    pl_mean = flax.jax_utils.replicate(pl_mean)
    state_G = flax.jax_utils.replicate(state_G)
    state_D = flax.jax_utils.replicate(state_D)

    # --------------------------------------
    # Precompile train and eval steps
    # --------------------------------------
    logger.info('Precompile training steps...')
    p_main_step_G = jax.pmap(training_steps.main_step_G, axis_name='batch')
    p_regul_step_G = jax.pmap(functools.partial(training_steps.regul_step_G, config=config), axis_name='batch')
    p_main_step_D = jax.pmap(training_steps.main_step_D, axis_name='batch')
    p_regul_step_D = jax.pmap(functools.partial(training_steps.regul_step_D, config=config), axis_name='batch')
    # Cross-device mean of pl_mean. Built once here: creating a new lambda
    # inside the epoch loop would trigger a fresh compilation every epoch.
    p_sync_pl_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name='batch'), axis_name='batch')

    # --------------------------------------
    # Training
    # --------------------------------------
    logger.info('Start training...')
    fid_metric = FID(generator_ema, ds_train, config)

    # Dict to collect training statistics / losses (threaded through the steps)
    metrics = {}
    num_imgs_processed = 0
    effective_batch_size = config.batch_size * num_devices
    num_steps_per_epoch = dataset_info['num_examples'] // effective_batch_size
    # Loop-invariant: replicate the style-mixing probability once, not per batch.
    mixing_prob = flax.jax_utils.replicate(config.mixing_prob)

    if config.wandb and jax.process_index() == 0:
        # do some more logging
        wandb.config.effective_batch_size = effective_batch_size
        wandb.config.num_steps_per_epoch = num_steps_per_epoch
        wandb.config.num_workers = num_workers
        wandb.config.device_count = num_devices
        wandb.config.num_examples = dataset_info['num_examples']
        wandb.config.vm_name = training_utils.get_vm_name()

    for epoch in range(epoch_offset, config.num_epochs):
        if config.wandb and jax.process_index() == 0:
            wandb.log({'training/epochs': epoch}, step=step)

        for batch in data_pipeline.prefetch(ds_train, config.num_prefetch):
            assert batch['image'].shape[1] == config.batch_size, f"Mismatched batch (batch size: {config.batch_size}, this batch: {batch['image'].shape[1]})"

            iteration_start_time = time.time()

            if config.c_dim == 0:
                # No labels in the dataset
                batch['label'] = None

            # Create two latent noise vectors and combine them for the style mixing regularization
            rng, key = jax.random.split(rng)
            z_latent1 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
            rng, key = jax.random.split(rng)
            z_latent2 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)

            # Per-device PRNG keys, split from a fresh key so they are
            # independent of z_latent2 (a JAX key must never be reused).
            rng, key = jax.random.split(rng)
            rkey = jax.random.split(key, num=num_local_devices)

            # --------------------------------------
            # Update Discriminator
            # --------------------------------------
            time_d_start = time.time()
            state_D, metrics = p_main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            # Wait for the device; with async dispatch the timer would otherwise
            # only measure how long it took to enqueue the computation.
            _block_until_ready(state_D)
            time_d_end = time.time()
            if step % config.D_reg_interval == 0:
                state_D, metrics = p_regul_step_D(state_D, batch, metrics)
                # Finish D regularization here so it is not counted in the G timing.
                _block_until_ready(state_D)

            # --------------------------------------
            # Update Generator
            # --------------------------------------
            time_g_start = time.time()
            state_G, metrics = p_main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            if step % config.G_reg_interval == 0:
                H, W = batch['image'].shape[-3], batch['image'].shape[-2]
                rng, key = jax.random.split(rng)
                # Image-shaped Gaussian noise scaled by 1/sqrt(H*W) for path length regularization.
                pl_noise = jax.random.normal(key, batch['image'].shape, dtype=dtype) / np.sqrt(H * W)
                # NOTE(review): reuses the per-device keys already passed to the main
                # steps (like z_latent1); confirm this sharing is intended.
                state_G, metrics, pl_mean = p_regul_step_G(state_G, batch, z_latent1, pl_noise, pl_mean, metrics,
                                                           rng=rkey)

            params_ema_G = training_utils.update_generator_ema(flax.jax_utils.unreplicate(state_G),
                                                               params_ema_G,
                                                               config)
            _block_until_ready((state_G, params_ema_G))
            time_g_end = time.time()

            # --------------------------------------
            # Logging and Checkpointing
            # --------------------------------------
            if step % config.save_every == 0 and config.disable_fid:
                # If FID evaluation is disabled, a checkpoint will be saved every 'save_every' steps.
                if jax.process_index() == 0:
                    logger.info('Saving checkpoint...')
                    checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step,
                                               epoch)

            num_imgs_processed += effective_batch_size

            if step % config.eval_fid_every == 0 and not config.disable_fid:
                # If FID evaluation is enabled, only save a checkpoint if FID score is better.
                if jax.process_index() == 0:
                    logger.info('Computing FID...')
                    fid_score = fid_metric.compute_fid(params_ema_G).item()
                    if config.wandb:
                        wandb.log({'training/gen/fid': fid_score}, step=step)
                    logger.info(f'Computed FID: {fid_score:.2f}')
                    if fid_score < best_fid_score:
                        best_fid_score = fid_score
                        logger.info(f'New best FID score ({best_fid_score:.3f}). Saving checkpoint...')
                        ts = time.time()
                        checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=fid_score)
                        te = time.time()
                        logger.info(f'... successfully saved checkpoint in {(te-ts)/60:.1f}min')

            sec_per_kimg = (time.time() - iteration_start_time) / (effective_batch_size / 1000.0)
            time_taken_g = time_g_end - time_g_start
            time_taken_d = time_d_end - time_d_start
            time_taken_per_step = time.time() - iteration_start_time
            g_loss = jnp.mean(metrics['G_loss']).item()
            d_loss = jnp.mean(metrics['D_loss']).item()

            if config.wandb and jax.process_index() == 0:
                # wandb logging - happens every step (one call, committed at `step`)
                wandb.log({
                    'training/gen/loss': g_loss,
                    'training/dis/loss': d_loss,
                    'training/dis/fake_logits': jnp.mean(metrics['fake_logits']).item(),
                    'training/dis/real_logits': jnp.mean(metrics['real_logits']).item(),
                    'training/time_taken_g': time_taken_g,
                    'training/time_taken_d': time_taken_d,
                    'training/time_taken_per_step': time_taken_per_step,
                    'training/num_imgs_trained': num_imgs_processed,
                    'training/sec_per_kimg': sec_per_kimg,
                }, step=step)

            if step % config.log_every == 0:
                # console logging - happens every log_every steps. Guard against
                # num_steps_per_epoch == 0 (dataset smaller than one global batch).
                step_in_epoch = step % num_steps_per_epoch if num_steps_per_epoch > 0 else step
                logger.info(f'Total steps: {step:>6,} - epoch {epoch:>3,}/{config.num_epochs} @ {step_in_epoch:>6,}/{num_steps_per_epoch:,} - G loss: {g_loss:.5f} - D loss: {d_loss:.5f} - sec/kimg: {sec_per_kimg:.2f}s - time per step: {time_taken_per_step:.3f}s')

            if step % config.generate_samples_every == 0 and config.wandb and jax.process_index() == 0:
                # Generate training images
                train_snapshot = training_utils.get_training_snapshot(
                    image_real=flax.jax_utils.unreplicate(batch['image']),
                    image_gen=flax.jax_utils.unreplicate(metrics['image_gen']),
                    max_num=10
                )
                wandb.log({'training/snapshot': wandb.Image(train_snapshot)}, commit=False, step=step)

                # Generate evaluation images (first device's latents/labels)
                labels = None if config.c_dim == 0 else batch['label'][0]
                image_gen_eval = training_steps.eval_step_G(
                    generator_ema, params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=1
                )
                image_gen_eval_trunc = training_steps.eval_step_G(
                    generator_ema,
                    params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=0.5
                )
                eval_snapshot = training_utils.get_eval_snapshot(image=image_gen_eval, max_num=10)
                eval_snapshot_trunc = training_utils.get_eval_snapshot(image=image_gen_eval_trunc, max_num=10)
                wandb.log({'eval/snapshot': wandb.Image(eval_snapshot)}, commit=False, step=step)
                wandb.log({'eval/snapshot_trunc': wandb.Image(eval_snapshot_trunc)}, step=step)

            step += 1

        # Sync moving stats across devices
        state_G = training_utils.sync_moving_stats(state_G)

        # Sync moving average of path length mean (Generator regularization)
        pl_mean = p_sync_pl_mean(pl_mean)