"""CNN encoder/decoder architecture based on the VQ-GAN and MaskGIT papers. |
|
|
|
Adapted from https://github.com/google-research/maskgit/blob/main/maskgit/nets/vqgan_tokenizer.py. # pylint: disable=line-too-long |
|
""" |
|
|
|
import dataclasses |
|
import functools |
|
import math |
|
from typing import Any, Sequence |
|
|
|
from big_vision import utils |
|
from big_vision.models import common |
|
from big_vision.models.proj.givt import vae |
|
|
|
import einops |
|
import flax.linen as nn |
|
import flax.training.checkpoints |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
def _get_norm_layer(train, dtype, norm_type="BN"): |
|
"""Create normalization layers. |
|
|
|
Args: |
|
train: Whether to use the layer in training or inference mode. |
|
dtype: Layer output type. |
|
    norm_type: Which normalization to use: "BN", "LN", or "GN".
|
|
|
Returns: |
|
    A `functools.partial` of the requested normalization layer class.
|
""" |
|
if norm_type == "BN": |
|
return functools.partial( |
|
nn.BatchNorm, |
|
use_running_average=not train, |
|
momentum=0.9, |
|
epsilon=1e-5, |
|
axis_name=None, |
|
axis_index_groups=None, |
|
        dtype=jnp.float32,  # Keep BatchNorm in float32 regardless of dtype.
|
use_fast_variance=False) |
|
elif norm_type == "LN": |
|
return functools.partial(nn.LayerNorm, dtype=dtype, use_fast_variance=False) |
|
elif norm_type == "GN": |
|
return functools.partial(nn.GroupNorm, dtype=dtype, use_fast_variance=False) |
|
else: |
|
    raise NotImplementedError(f"Unknown norm_type: {norm_type}")
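

# Editorial sketch (not part of the original module): how the factory above is
# typically consumed. The input shape and the "GN" choice are illustrative
# assumptions.
def _example_norm_layer_usage():
  norm_fn = _get_norm_layer(train=False, dtype=jnp.float32, norm_type="GN")
  x = jnp.ones((2, 16, 16, 64))  # NHWC feature map with 64 channels.
  layer = norm_fn()  # GroupNorm with flax defaults (32 groups).
  variables = layer.init(jax.random.PRNGKey(0), x)
  return layer.apply(variables, x)  # Normalized output, same shape as x.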
|
|
|
|
|
def _tensorflow_style_avg_pooling(x, window_shape, strides, padding: str): |
|
"""Avg pooling as done by TF (Flax layer gives different results). |
|
|
|
To be specific, Flax includes padding cells when taking the average, |
|
while TF does not. |
|
|
|
Args: |
|
    x: Input tensor.
    window_shape: Shape of the pooling window; a 1-tuple gives 1D pooling and
      a 2-tuple gives 2D pooling.
    strides: Must have the same length as window_shape.
    padding: Either 'SAME' or 'VALID' to indicate the padding scheme.
|
|
|
Returns: |
|
pooled: Tensor after applying pooling. |
|
""" |
|
pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add, |
|
(1,) + window_shape + (1,), |
|
(1,) + strides + (1,), padding) |
|
pool_denom = jax.lax.reduce_window( |
|
jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,), |
|
(1,) + strides + (1,), padding) |
|
return pool_sum / pool_denom |
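

# Editorial sketch: for a constant input, the TF-style average stays exactly
# 1.0 at the padded border, whereas flax.linen.avg_pool (which by default
# counts padding cells in the denominator) would not. Shapes are illustrative
# assumptions.
def _example_tf_style_pooling():
  x = jnp.ones((1, 5, 5, 3))  # Odd spatial size forces border padding.
  pooled = _tensorflow_style_avg_pooling(
      x, (2, 2), strides=(2, 2), padding="SAME")
  return pooled  # Shape (1, 3, 3, 3), all entries 1.0.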
|
|
|
|
|
def _upsample(x, factor=2, method="nearest"): |
|
n, h, w, c = x.shape |
|
x = jax.image.resize(x, (n, h * factor, w * factor, c), method=method) |
|
return x |
|
|
|
|
|
def _dsample(x): |
|
return _tensorflow_style_avg_pooling( |
|
x, (2, 2), strides=(2, 2), padding="same") |
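

# Editorial sketch: the factor-of-2 spatial round trip performed by the two
# resampling helpers. The input shape is an arbitrary assumption.
def _example_resampling_shapes():
  x = jnp.ones((1, 8, 8, 4))
  down = _dsample(x)               # (1, 4, 4, 4)
  up = _upsample(down, factor=2)   # (1, 8, 8, 4), nearest-neighbour resize.
  return down.shape, up.shape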
|
|
|
|
|
def get_h_w_pixelshuffle(hw, pixel_shuffle_patch_size):
  """Recovers the 2D token grid (h, w) from a pixel-shuffled sequence length."""
  ph, pw = pixel_shuffle_patch_size
|
s = int(math.sqrt(hw * ph * pw)) |
|
h, w = s // ph, s // pw |
|
assert h * w == hw, f"Length {hw} incompatible with pixelshuffle ({ph}, {pw})" |
|
return h, w |
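

# Editorial sketch: a flattened sequence of 256 tokens with a (2, 2)
# pixel-shuffle patch corresponds to a 16x16 token grid (16 * 16 == 256),
# i.e. a 32x32 feature map before unshuffling.
def _example_pixelshuffle_grid():
  return get_h_w_pixelshuffle(256, (2, 2))  # -> (16, 16)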
|
|
|
|
|
class ResBlock(nn.Module): |
|
"""Basic Residual Block.""" |
|
filters: int |
|
norm_fn: Any |
|
conv_fn: Any |
|
  dtype: Any = jnp.float32
|
activation_fn: Any = nn.relu |
|
use_conv_shortcut: bool = False |
|
|
|
@nn.compact |
|
def __call__(self, x: jax.Array) -> jax.Array: |
|
input_dim = x.shape[-1] |
|
residual = x |
|
x = self.norm_fn()(x) |
|
x = self.activation_fn(x) |
|
x = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x) |
|
x = self.norm_fn()(x) |
|
x = self.activation_fn(x) |
|
x = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x) |
|
    if input_dim != self.filters:
      if self.use_conv_shortcut:
        residual = self.conv_fn(
            self.filters, kernel_size=(3, 3), use_bias=False)(x)
      else:
        residual = self.conv_fn(
            self.filters, kernel_size=(1, 1), use_bias=False)(x)
    return x + residual
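

# Editorial sketch: standalone use of ResBlock. The norm/conv/activation
# choices mirror how the blocks are wired up below; the input shape is an
# assumption.
def _example_resblock_usage():
  block = ResBlock(
      filters=64,
      norm_fn=_get_norm_layer(
          train=False, dtype=jnp.float32, norm_type="GN"),
      conv_fn=nn.Conv,
      activation_fn=nn.swish,
  )
  x = jnp.ones((1, 16, 16, 32))  # 32 input channels != 64 filters, so the
                                 # 1x1 shortcut projection is applied.
  variables = block.init(jax.random.PRNGKey(0), x)
  return block.apply(variables, x)  # Shape (1, 16, 16, 64).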
|
|
|
|
|
class Encoder(nn.Module): |
|
"""Encoder Blocks.""" |
|
|
|
filters: int |
|
num_res_blocks: int |
|
channel_multipliers: list[int] |
|
embedding_dim: int |
|
conv_downsample: bool = False |
|
norm_type: str = "GN" |
|
activation_fn_str: str = "swish" |
|
  dtype: Any = jnp.float32
|
|
|
def setup(self) -> None: |
|
if self.activation_fn_str == "relu": |
|
self.activation_fn = nn.relu |
|
elif self.activation_fn_str == "swish": |
|
self.activation_fn = nn.swish |
|
else: |
|
      raise NotImplementedError(
          f"Unknown activation: {self.activation_fn_str}")
|
|
|
@nn.compact |
|
def __call__(self, x: jax.Array, train: bool = False) -> jax.Array: |
|
conv_fn = nn.Conv |
|
norm_fn = _get_norm_layer( |
|
train=train, dtype=self.dtype, norm_type=self.norm_type) |
|
block_args = dict( |
|
norm_fn=norm_fn, |
|
conv_fn=conv_fn, |
|
dtype=self.dtype, |
|
activation_fn=self.activation_fn, |
|
use_conv_shortcut=False, |
|
) |
|
x = conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x) |
|
num_blocks = len(self.channel_multipliers) |
|
for i in range(num_blocks): |
|
filters = self.filters * self.channel_multipliers[i] |
|
for _ in range(self.num_res_blocks): |
|
x = ResBlock(filters, **block_args)(x) |
|
if i < num_blocks - 1: |
|
if self.conv_downsample: |
|
x = conv_fn(filters, kernel_size=(4, 4), strides=(2, 2))(x) |
|
else: |
|
x = _dsample(x) |
|
for _ in range(self.num_res_blocks): |
|
x = ResBlock(filters, **block_args)(x) |
|
x = norm_fn()(x) |
|
x = self.activation_fn(x) |
|
x = conv_fn(self.embedding_dim, kernel_size=(1, 1))(x) |
|
return x |
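

# Editorial sketch: the encoder halves the spatial resolution
# len(channel_multipliers) - 1 times. The hyper-parameters below are small,
# assumed values for illustration only.
def _example_encoder_shapes():
  enc = Encoder(
      filters=32,
      num_res_blocks=1,
      channel_multipliers=[1, 2, 4],
      embedding_dim=16,
  )
  x = jnp.ones((1, 32, 32, 3))
  variables = enc.init(jax.random.PRNGKey(0), x, train=False)
  return enc.apply(variables, x, train=False)  # Shape (1, 8, 8, 16).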
|
|
|
|
|
class Decoder(nn.Module): |
|
"""Decoder Blocks.""" |
|
|
|
filters: int |
|
num_res_blocks: int |
|
channel_multipliers: list[int] |
|
norm_type: str = "GN" |
|
activation_fn_str: str = "swish" |
|
output_dim: int = 3 |
|
dtype: Any = jnp.float32 |
|
|
|
def setup(self) -> None: |
|
if self.activation_fn_str == "relu": |
|
self.activation_fn = nn.relu |
|
elif self.activation_fn_str == "swish": |
|
self.activation_fn = nn.swish |
|
else: |
|
      raise NotImplementedError(
          f"Unknown activation: {self.activation_fn_str}")
|
|
|
@nn.compact |
|
def __call__(self, x: jax.Array, train: bool = False) -> jax.Array: |
|
conv_fn = nn.Conv |
|
norm_fn = _get_norm_layer( |
|
train=train, dtype=self.dtype, norm_type=self.norm_type) |
|
block_args = dict( |
|
norm_fn=norm_fn, |
|
conv_fn=conv_fn, |
|
dtype=self.dtype, |
|
activation_fn=self.activation_fn, |
|
use_conv_shortcut=False, |
|
) |
|
num_blocks = len(self.channel_multipliers) |
|
filters = self.filters * self.channel_multipliers[-1] |
|
x = conv_fn(filters, kernel_size=(3, 3), use_bias=True)(x) |
|
for _ in range(self.num_res_blocks): |
|
x = ResBlock(filters, **block_args)(x) |
|
for i in reversed(range(num_blocks)): |
|
filters = self.filters * self.channel_multipliers[i] |
|
for _ in range(self.num_res_blocks): |
|
x = ResBlock(filters, **block_args)(x) |
|
if i > 0: |
|
x = _upsample(x, 2) |
|
x = conv_fn(filters, kernel_size=(3, 3))(x) |
|
x = norm_fn()(x) |
|
x = self.activation_fn(x) |
|
x = conv_fn(self.output_dim, kernel_size=(3, 3))(x) |
|
return x |
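

# Editorial sketch: the decoder mirrors the encoder above and doubles the
# spatial resolution len(channel_multipliers) - 1 times. Hyper-parameters are
# assumed values matching the encoder sketch.
def _example_decoder_shapes():
  dec = Decoder(
      filters=32,
      num_res_blocks=1,
      channel_multipliers=[1, 2, 4],
      output_dim=3,
  )
  z = jnp.ones((1, 8, 8, 16))
  variables = dec.init(jax.random.PRNGKey(0), z, train=False)
  return dec.apply(variables, z, train=False)  # Shape (1, 32, 32, 3).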
|
|
|
|
|
class Model(vae.Model): |
|
"""CNN Model.""" |
|
|
|
filters: int = 128 |
|
num_res_blocks: int = 2 |
|
channel_multipliers: list[int] = dataclasses.field(default_factory=list) |
|
conv_downsample: bool = False |
|
activation_fn: str = "swish" |
|
norm_type: str = "GN" |
|
output_dim: int = 3 |
|
dtype: Any = jnp.float32 |
|
|
|
malib_ckpt: bool = False |
|
pixel_shuffle_patch_size: tuple[int, int] = (1, 1) |
|
|
|
def setup(self) -> None: |
|
|
|
self.encoder = Encoder( |
|
filters=self.filters, |
|
num_res_blocks=self.num_res_blocks, |
|
channel_multipliers=self.channel_multipliers, |
|
norm_type=self.norm_type, |
|
activation_fn_str=self.activation_fn, |
|
embedding_dim=2 * self.codeword_dim, |
|
conv_downsample=self.conv_downsample, |
|
dtype=self.dtype, |
|
name="cnn_encoder", |
|
) |
|
self.decoder = Decoder( |
|
filters=self.filters, |
|
num_res_blocks=self.num_res_blocks, |
|
channel_multipliers=self.channel_multipliers, |
|
norm_type=self.norm_type, |
|
activation_fn_str=self.activation_fn, |
|
output_dim=self.output_dim, |
|
dtype=self.dtype, |
|
name="cnn_decoder", |
|
) |
|
|
|
def _maybe_rescale_input(self, x): |
|
return (x + 1.0) / 2.0 if self.malib_ckpt else x |
|
|
|
def _maybe_rescale_output(self, x): |
|
return 2.0 * x - 1.0 if self.malib_ckpt else x |
|
|
|
def _maybe_clip_logvar(self, logvar): |
|
return jnp.clip(logvar, -30.0, 20.0) if self.malib_ckpt else logvar |
|
|
|
def encode( |
|
self, |
|
x: jax.Array, |
|
*, |
|
train: bool = False, |
|
) -> tuple[jax.Array, jax.Array]: |
|
x = self._maybe_rescale_input(x) |
|
x = self.encoder(x, train=train) |
|
    assert x.shape[1] == x.shape[2], f"Square spatial dims required: {x.shape}"
|
mu, logvar = jnp.split(x, 2, axis=-1) |
|
logvar = self._maybe_clip_logvar(logvar) |
|
|
|
def _space_to_depth(z): |
|
ph, pw = self.pixel_shuffle_patch_size |
|
return einops.rearrange( |
|
z, "b (h ph) (w pw) c -> b (h w) (c ph pw)", |
|
ph=ph, pw=pw |
|
) |
|
|
|
mu, logvar = _space_to_depth(mu), _space_to_depth(logvar) |
|
|
|
return mu, logvar |
|
|
|
def decode(self, x: jax.Array, train: bool = False) -> jax.Array: |
|
|
|
ph, pw = self.pixel_shuffle_patch_size |
|
h, w = get_h_w_pixelshuffle(x.shape[1], (ph, pw)) |
|
|
|
x = einops.rearrange( |
|
x, "b (h w) (c ph pw) -> b (h ph) (w pw) c", |
|
h=h, w=w, |
|
ph=ph, pw=pw |
|
) |
|
x = self.decoder(x, train=train) |
|
x = self._maybe_rescale_output(x) |
|
x = jnp.clip(x, -1.0, 1.0) |
|
|
|
return x |
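

# Editorial sketch of driving the VAE from outside the module. It assumes
# that the parent `vae.Model` defines `codeword_dim` (referenced in `setup`
# above) and provides defaults for its remaining fields; the multipliers and
# image size are illustrative.
def _example_encode_usage():
  model = Model(codeword_dim=16, channel_multipliers=[1, 1, 2, 2, 4])
  x = jnp.ones((1, 256, 256, 3))  # Images in [-1, 1], NHWC.
  variables = model.init(jax.random.PRNGKey(0), x, method=model.encode)
  mu, logvar = model.apply(variables, x, method=model.encode)
  return mu.shape, logvar.shape  # Each (1, 256, 16): a 16x16 latent grid.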
|
|
|
|
|
def load( |
|
init_params: Any, |
|
init_file: str, |
|
model_params: Any = None, |
|
dont_load: Sequence[str] = (), |
|
malib_ckpt: bool = False, |
|
use_ema_params: bool = False, |
|
) -> Any: |
|
"""Loads params from init checkpoint and merges into init_params. |
|
|
|
Args: |
|
init_params: pytree with (previously initialized) model parameters. |
|
init_file: Path of the checkpoint to load. |
|
model_params: Dict containing the model config. |
|
dont_load: Sequence of (flattened) parameter names which should not be |
|
loaded. |
|
malib_ckpt: Whether the given init_file is a malib checkpoint. |
|
use_ema_params: Whether to load the EMA params (for malib checkpoints). |
|
|
|
Returns: |
|
pytree containing the loaded model parameters. |
|
""" |
|
|
|
|
|
del model_params |
|
|
|
assert malib_ckpt or (not use_ema_params), ( |
|
"Loading EMA parameters is only supported for malib checkpoints.") |
|
|
|
if malib_ckpt: |
|
|
|
|
|
with jax.transfer_guard("allow"): |
|
vaegan_params = flax.training.checkpoints.restore_checkpoint( |
|
init_file, None) |
|
vaegan_params_flat = utils.tree_flatten_with_names(vaegan_params)[0] |
|
prefix_old = "ema_params/" if use_ema_params else "g_params/" |
|
vaegan_params_flat = [(k.replace(prefix_old, "cnn_"), v) |
|
for k, v in vaegan_params_flat if prefix_old in k] |
|
params = utils.tree_unflatten(vaegan_params_flat) |
|
else: |
|
params = flax.core.unfreeze(utils.load_params(init_file)) |
|
|
|
if init_params is not None: |
|
params = common.merge_params(params, init_params, dont_load) |
|
return params |
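

# Editorial sketch of a typical call. The checkpoint path is a placeholder and
# `init_params` would come from a prior `model.init(...)`.
def _example_load_usage(init_params):
  return load(init_params, "<path/to/vae_checkpoint>", malib_ckpt=False)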
|
|