# Optional developer hook: if the local ``localutils.debugger`` module can be
# imported, call its ``enable_debug()``; otherwise continue silently.
# NOTE: an ImportError raised from inside ``enable_debug()`` itself is also
# swallowed by this handler.
try:
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass
import flax.linen as nn |
import jax.numpy as jnp |
from absl import app, flags |
from functools import partial |
import numpy as np |
import tqdm |
import jax |
import jax.numpy as jnp |
import flax |
import optax |
import wandb |
from ml_collections import config_flags |
import ml_collections |
import tensorflow_datasets as tfds |
import tensorflow as tf |
# Hide all GPUs and TPUs from TensorFlow. TF is only used here for the tfds
# input pipeline; presumably this keeps it from claiming the accelerators
# that JAX trains on.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt |
from typing import Any |
import os |
from utils.wandb import setup_wandb, default_wandb_config |
from utils.train_state import TrainState, target_update |
from utils.checkpoint import Checkpoint |
from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model |
from utils.fid import get_fid_network, fid_from_stats |
from models.vqvae import VQVAE |
from models.discriminator import Discriminator |
FLAGS = flags.FLAGS

# Data and checkpoint locations.
# ``get_dataset`` in ``main`` accepts only names containing imagenet256 or imagenet128.
flags.DEFINE_string('dataset_name', 'imagenet256', 'Dataset name (imagenet256 or imagenet128).')
# NOTE(review): the defaults are machine-specific absolute paths and are never
# None, so saving and loading are always attempted unless overridden.
flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
flags.DEFINE_string('load_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint.tmp", 'Load dir (if not None, load params from here).')

# Run control (intervals are measured in training steps).
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
flags.DEFINE_integer('max_steps', 1_000_000, 'Number of training steps.')
# Hyper-parameters for the VQVAE, its quantizer, the discriminator and the
# optimiser. Registered below as the ``model`` config flag (unlocked, so main
# can add image_channels / image_size at runtime).
model_config = ml_collections.ConfigDict(dict(
    # Adam optimiser (lr, beta1, beta2 are read when building optax.adam).
    lr=0.0001,
    beta1=0.0,
    beta2=0.99,
    # NOTE(review): the next two are not read by the optimiser built in main;
    # presumably unused or consumed elsewhere -- confirm.
    lr_warmup_steps=2000,
    lr_decay_steps=500_000,
    # Architecture; presumably consumed by the VQVAE / Discriminator modules.
    filters=128,
    num_res_blocks=2,
    channel_multipliers=(1, 1, 2, 2, 4, 4),
    embedding_dim=4,
    norm_type='GN',
    # NOTE(review): not applied by the plain optax.adam used in main -- confirm.
    weight_decay=0.05,
    clip_gradient=1.0,
    # Reconstruction-loss weight and the decay used for the vqvae_eps copy
    # (target_update is called with rate 1 - eps_update_rate).
    l2_loss_weight=1.0,
    eps_update_rate=0.9999,
    # Quantizer: the training step special-cases 'kl', 'kl_two' and 'fsq'.
    quantizer_type='vq',
    quantizer_loss_ratio=1,
    codebook_size=1024,
    entropy_loss_ratio=0.1,
    entropy_loss_type='softmax',
    entropy_temperature=0.01,
    commitment_cost=0.25,
    fsq_levels=5,
    kl_weight=0.000001,
    # GAN / perceptual terms; the adversarial parts switch on after
    # gan_warmup_steps VQVAE steps.
    g_adversarial_loss_weight=0.5,
    g_grad_penalty_cost=10,
    perceptual_loss_weight=0.5,
    gan_warmup_steps=25000,
))

# Weights & Biases run settings, layered over the project defaults.
# NOTE(review): '{dataset_name}' is left as a literal placeholder here;
# presumably setup_wandb formats it -- confirm.
wandb_config = default_wandb_config()
wandb_config.update(dict(
    project='vqvae',
    name='vqvae_{dataset_name}',
))

config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
@jax.vmap
def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    """Numerically stable, element-wise sigmoid cross-entropy.

    Adapted from
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py

    For logits ``x`` and labels ``z`` this evaluates
    ``max(x, 0) - x * z + log(1 + exp(-|x|))``. That equals
    ``-z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x))`` but never
    exponentiates a large positive number.

    Because of the ``jax.vmap`` decorator, both keyword arguments are mapped
    over their leading axis. They must therefore be arrays of rank >= 1 with
    matching leading dimensions. No reduction is applied; callers take the mean.
    """
    zero = jnp.zeros_like(logits, dtype=logits.dtype)
    is_non_negative = logits >= zero
    # where(x >= 0, x, 0) == max(x, 0)
    positive_part = jnp.where(is_non_negative, logits, zero)
    # where(x >= 0, -x, x) == -|x|
    minus_abs_logits = jnp.where(is_non_negative, -logits, logits)
    log_term = jnp.log1p(jnp.exp(minus_abs_logits))
    return positive_part - logits * labels + log_term
class VQGANModel(flax.struct.PyTreeNode):
    """VQGAN training state: the VQVAE generator, a slow-moving copy of it, and a discriminator.

    As a ``flax.struct.PyTreeNode`` the instance is a pytree, so it can be
    replicated across local devices and passed through ``jax.pmap``.
    ``config`` is a static (non-pytree) field holding the hyper-parameters.
    """
    # PRNG key; split once per ``update`` call.
    rng: Any
    # Hyper-parameter mapping (static field, not traced).
    config: dict = flax.struct.field(pytree_node=False)
    # Generator trained by ``update``.
    vqvae: TrainState
    # Copy tracked with ``target_update`` at rate ``1 - eps_update_rate``;
    # used by ``reconstruction`` unless ``sampling`` is set.
    vqvae_eps: TrainState
    # Discriminator trained by ``update``.
    discriminator: TrainState

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def update(self, images, pmap_axis='data'):
        """Run one synchronous training step for the VQVAE and the discriminator.

        Executed per device under ``jax.pmap``. Gradients and metrics are
        averaged across devices with ``jax.lax.pmean`` over ``pmap_axis``.

        Args:
          images: this device's shard of the batch. Must have the same shape as
            the VQVAE reconstruction (checked at trace time).
          pmap_axis: name of the pmap axis used for the cross-device means.

        Returns:
          ``(new_model, info)``: the updated model and a dict of metrics.
          ``update_norm``/``param_norm`` describe the VQVAE;
          ``update_norm_d``/``param_norm_d`` describe the discriminator.
        """
        new_rng, curr_key = jax.random.split(self.rng, 2)
        # Pretrained ResNet-50 used for the perceptual loss (loaded at trace time).
        resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')
        # 0.0 while vqvae.step < gan_warmup_steps, 1.0 afterwards. It gates the
        # adversarial term of the VQVAE loss and the discriminator gradients.
        is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)

        def loss_fn(params_vqvae, params_disc):
            reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
            assert reconstructed_images.shape == images.shape, (
                f"Reconstructed images shape {reconstructed_images.shape} "
                f"!= input images shape {images.shape}")
            discriminator_fn = lambda x: self.discriminator(x, params=params_disc)

            # Discriminator logits on real images, plus the squared norm of their
            # gradient w.r.t. the input images (gradient penalty on real data).
            real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
            gradient = vjp_fn(jnp.ones_like(real_logit))[0]
            gradient = gradient.reshape((images.shape[0], -1))
            gradient = jnp.asarray(gradient, jnp.float32)
            penalty = jnp.mean(jnp.sum(jnp.square(gradient), axis=-1))

            fake_logit = discriminator_fn(reconstructed_images)
            d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
            d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
            loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])

            # Non-saturating generator loss (fakes labelled real), zero during warm-up.
            d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
            d_loss_for_vae = d_loss_for_vae * is_gan_training

            # Perceptual loss: MSE between ResNet embeddings of real and reconstructed images.
            real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
            fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
            perceptual_loss = jnp.mean((real_pools - fake_pools) ** 2)

            l2_loss = jnp.mean((reconstructed_images - images) ** 2)
            quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
            if self.config['quantizer_type'] in ('kl', 'kl_two'):
                quantizer_loss = quantizer_loss * self.config['kl_weight']

            # Read the weights from self.config (the same mapping used for every
            # other hyper-parameter in this step) rather than the global FLAGS.
            loss_vae = (l2_loss * self.config['l2_loss_weight']) \
                + (quantizer_loss * self.config['quantizer_loss_ratio']) \
                + (d_loss_for_vae * self.config['g_adversarial_loss_weight']) \
                + (perceptual_loss * self.config['perceptual_loss_weight'])
            codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
            return (loss_vae, loss_d), {
                'loss_vae': loss_vae,
                'loss_d': loss_d,
                'l2_loss': l2_loss,
                'd_loss_for_vae': d_loss_for_vae,
                'perceptual_loss': perceptual_loss,
                'quantizer_loss': quantizer_loss,
                'codebook_usage': codebook_usage,
            }

        _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
        # Pull back each loss separately: the VQVAE is updated from loss_vae
        # only, and the discriminator from loss_d only.
        vae_grads, _ = grad_fn((1., 0.))
        _, d_grads = grad_fn((0., 1.))
        vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
        d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
        d_grads = jax.tree_util.tree_map(lambda x: x * is_gan_training, d_grads)
        info = jax.lax.pmean(info, axis_name=pmap_axis)
        if self.config['quantizer_type'] == 'fsq':
            # Fraction of codebook entries with non-zero usage.
            info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]

        # Distinct names per model so the metrics below refer to the right one.
        vae_updates, vae_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
        vae_params = optax.apply_updates(self.vqvae.params, vae_updates)
        new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=vae_params, opt_state=vae_opt_state)

        d_updates, d_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
        d_params = optax.apply_updates(self.discriminator.params, d_updates)
        new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=d_params, opt_state=d_opt_state)

        info['grad_norm_vae'] = optax.global_norm(vae_grads)
        info['grad_norm_d'] = optax.global_norm(d_grads)
        info['update_norm'] = optax.global_norm(vae_updates)
        info['param_norm'] = optax.global_norm(vae_params)
        info['update_norm_d'] = optax.global_norm(d_updates)
        info['param_norm_d'] = optax.global_norm(d_params)
        info['is_gan_training'] = is_gan_training

        new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1 - self.config['eps_update_rate'])
        new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
        return new_model, info

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction(self, images, pmap_axis='data', sampling=False):
        """Reconstruct ``images`` on each device and clip the result to [0, 1].

        With the default ``sampling=False`` the ``vqvae_eps`` copy is applied
        without an RNG. ``sampling=True`` runs ``vqvae`` with a ``'noise'`` key
        derived from ``self.rng``; the key is not advanced between calls.

        NOTE(review): ``in_axes=(0, 0)`` maps only ``self`` and ``images``, so
        passing ``sampling`` through the pmapped call probably fails -- confirm
        before relying on it.
        """
        if not sampling:
            reconstructed_images, _ = self.vqvae_eps(images)
        else:
            _, curr_key = jax.random.split(self.rng, 2)
            reconstructed_images, _ = self.vqvae(images, rngs={'noise': curr_key})
        return jnp.clip(reconstructed_images, 0, 1)
def _reconstruction_figure(originals, reconstructions, num_images=8):
    """Build a 2-row figure with inputs on the top row and reconstructions below.

    ``originals`` and ``reconstructions`` are device-sharded batches of shape
    ``(devices, per_device, H, W, C)``, as built in ``main``. They are flattened
    over the first two axes before plotting, so the figure works for any local
    device count. At most ``num_images`` columns are filled.
    """
    originals = np.asarray(originals)
    reconstructions = np.asarray(reconstructions)
    originals = originals.reshape((-1, *originals.shape[2:]))
    reconstructions = reconstructions.reshape((-1, *reconstructions.shape[2:]))
    # squeeze=False keeps ``axs`` 2-D even when num_images == 1.
    fig, axs = plt.subplots(2, num_images, figsize=(30, 15), squeeze=False)
    for col in range(min(num_images, originals.shape[0])):
        axs[0, col].imshow(originals[col], vmin=0, vmax=1)
        axs[1, col].imshow(reconstructions[col], vmin=0, vmax=1)
    return fig


def main(_):
    """Train the VQGAN, with periodic logging, evaluation/FID and checkpointing."""
    np.random.seed(FLAGS.seed)
    print("Using devices", jax.local_devices())
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch: the global batch split across processes
    # (assumes every process has the same number of local devices).
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
    print("Device count", device_count)
    print("Global device count", global_device_count)
    print("Global Batch: ", FLAGS.batch_size)
    print("Node Batch: ", local_batch_size)
    print("Device Batch:", local_batch_size // device_count)
    # Batches are reshaped to (devices, per_device, ...) below, so fail early
    # with a clear message instead of a reshape error inside the loop.
    if local_batch_size % device_count != 0:
        raise ValueError(
            f"Node batch size {local_batch_size} is not divisible by the "
            f"{device_count} local devices.")

    if jax.process_index() == 0:
        setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)

    def get_dataset(is_train):
        """Return an endless numpy iterator of per-process image batches in [0, 1]."""
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square, then resize to the dataset resolution.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            print(split)
            dataset = tfds.load('imagenet2012', split=split, data_dir="/dev/shm")
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.repeat()
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    # Reference FID statistics. The existence check guards the same file that
    # is loaded (previously it checked a different path).
    fid_stats_path = "./base_stats.npz"
    if not os.path.exists(fid_stats_path):
        raise ValueError(f"Please download the FID stats file ({fid_stats_path})! See the README.")
    truth_fid_stats = np.load(fid_stats_path)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    # memory_stats() may return None on some backends (e.g. CPU).
    memory_stats = jax.local_devices()[0].memory_stats()
    if memory_stats and 'bytes_limit' in memory_stats:
        print("Total Memory on device:", float(memory_stats['bytes_limit']) / 1024**3, "GB")

    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    if FLAGS.load_dir is not None:
        # Best-effort resume: fall back to the fresh initialisation on failure,
        # but report why and let KeyboardInterrupt/SystemExit propagate.
        try:
            cp = Checkpoint(FLAGS.load_dir)
            model = cp.load_model(model)
            print("Loaded model with step", model.vqvae.step)
        except Exception as e:
            print("Failed to load checkpoint:", repr(e))
            print("Random init")
    else:
        print("Random init")

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    best_fid = 100000
    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):
        batch_images = next(dataset)
        batch_images = batch_images.reshape((device_count, -1, *batch_images.shape[1:]))
        model, update_info = model.update(batch_images)

        if i % FLAGS.log_interval == 0:
            update_info = jax.tree_util.tree_map(lambda x: x.mean(), update_info)
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            if jax.process_index() == 0:
                wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            reconstructed_images = model.reconstruction(batch_images)
            valid_images = next(dataset_valid)
            valid_images = valid_images.reshape((device_count, -1, *valid_images.shape[1:]))
            valid_reconstructed_images = model.reconstruction(valid_images)
            if jax.process_index() == 0:
                wandb.log({'batch_image_mean': batch_images.mean()}, step=i)
                wandb.log({'reconstructed_images_mean': reconstructed_images.mean()}, step=i)
                wandb.log({'batch_image_std': batch_images.std()}, step=i)
                wandb.log({'reconstructed_images_std': reconstructed_images.std()}, step=i)

                fig = _reconstruction_figure(batch_images, reconstructed_images)
                wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
                plt.close(fig)

                fig = _reconstruction_figure(valid_images, valid_reconstructed_images)
                wandb.log({'reconstruction_valid': wandb.Image(fig)}, step=i)
                plt.close(fig)

            # Validation metrics from one update step on a validation batch;
            # the resulting model is discarded.
            _, valid_update_info = model.update(valid_images)
            valid_update_info = jax.tree_util.tree_map(lambda x: x.mean(), valid_update_info)
            valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
            if jax.process_index() == 0:
                wandb.log(valid_metrics, step=i)

            # FID over reconstructions of 780 validation batches. Images are
            # resized to 299x299 and mapped from [0, 1] to [-1, 1] before the
            # FID network.
            activations = []
            for _ in range(780):
                valid_images = next(dataset_valid)
                valid_images = valid_images.reshape((device_count, -1, *valid_images.shape[1:]))
                valid_reconstructed_images = model.reconstruction(valid_images)
                valid_reconstructed_images = jax.image.resize(
                    valid_reconstructed_images,
                    (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                    method='bilinear', antialias=False)
                valid_reconstructed_images = 2 * valid_reconstructed_images - 1
                activations.append(np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :])
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))
            print("doing this much FID", activations.shape)
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

            if jax.process_index() == 0:
                wandb.log({'validation/fid': fid}, step=i)
                print("validation FID at step", i, fid)
                if fid < best_fid:
                    # Guarded like the periodic save below: save_dir may be None.
                    if FLAGS.save_dir is not None:
                        model_single = flax.jax_utils.unreplicate(model)
                        cp = Checkpoint(FLAGS.save_dir + "best.tmp")
                        cp.set_model(model_single)
                        cp.save()
                    best_fid = fid

        if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
            if jax.process_index() == 0:
                model_single = flax.jax_utils.unreplicate(model)
                cp = Checkpoint(FLAGS.save_dir)
                cp.set_model(model_single)
                cp.save()
# Script entry point: absl's app.run parses the command-line flags, then calls main.
if __name__ == '__main__':
    app.run(main)