"""Invertible adaptor based on iRevNet. |
|
|
|
Based on the PyTorch version from: |
|
https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/iRevNet.py |
|
""" |
|
|
|
from typing import Any, Optional, Sequence |
|
|
|
from big_vision import utils |
|
from big_vision.models import common |
|
from big_vision.models.proj.givt import cnn |
|
import einops |
|
import flax.core |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
def _split(x: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Splits an NHWC array into two halves along the channel axis."""
  n = x.shape[-1] // 2
  x1 = x[:, :, :, :n]
  x2 = x[:, :, :, n:]
  return x1, x2


def _merge(x1: jax.Array, x2: jax.Array) -> jax.Array:
  """Concatenates the two channel halves back along the last axis."""
  return jnp.concatenate((x1, x2), axis=-1)

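# _merge(*_split(x)) == x for any NHWC array with an even number of channels;
# the coupling blocks below rely on this exact split/merge round trip.
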
class IRevNetBlock(nn.Module):
  """iRevNet Block."""
  first: bool = False
  dropout_rate: float = 0.
  num_channels: int = 2
  num_channels_bottleneck: Optional[int] = None
  num_grps_norm: int = 32

  @nn.compact
  def _fx2(self, x: jax.Array, train: bool = True) -> jax.Array:
    """Residual function F: a small norm/ReLU/conv bottleneck CNN."""
    if not self.first:
      y = nn.GroupNorm(num_groups=self.num_grps_norm, name="gn_0")(x)
      y = nn.relu(y)
    else:
      y = x

    ks = (3, 3)
    y = nn.Conv(self.num_channels_bottleneck or self.num_channels,
                kernel_size=ks, padding=1, use_bias=False)(y)
    y = nn.GroupNorm(num_groups=self.num_grps_norm, name="gn_1")(y)
    y = nn.relu(y)

    y = nn.Conv(self.num_channels_bottleneck or self.num_channels,
                kernel_size=ks, padding=1, use_bias=False)(y)
    y = nn.Dropout(rate=self.dropout_rate, deterministic=(not train))(y)
    y = nn.GroupNorm(num_groups=self.num_grps_norm, name="gn_2")(y)
    y = nn.relu(y)

    y = nn.Conv(self.num_channels, kernel_size=ks, padding=1, use_bias=False)(y)

    return y

  def forward(
      self,
      x: tuple[jax.Array, jax.Array],
      train: bool = True,
  ) -> tuple[jax.Array, jax.Array]:
    """Bijective block forward: (x1, x2) -> (x2, x1 + F(x2))."""
    x1, x2 = x[0], x[1]
    fx2 = self._fx2(x2, train=train)
    y1 = fx2 + x1
    return (x2, y1)

  def inverse(self,
              x: tuple[jax.Array, jax.Array],
              train: bool = True
              ) -> tuple[jax.Array, jax.Array]:
    """Bijective block inverse: (x2, y1) -> (y1 - F(x2), x2)."""
    x2, y1 = x[0], x[1]
    fx2 = -self._fx2(x2, train=train)
    x1 = fx2 + y1
    return (x1, x2)

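# Note on the coupling scheme above: both directions only ever evaluate F at
# x2, so F itself is never inverted. This is why F can be an arbitrary
# (non-invertible) CNN while each block remains bijective.
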
class IRevNet(nn.Module):
  """iRevNet."""
  num_blocks: int = 4
  num_channels: int = 4
  num_channels_bottleneck: Optional[int] = None
  dropout_rate: float = 0.0

  def setup(self) -> None:
    num_grps_norm = min(32, self.num_channels // 2)
    # Each block operates on one half of the channels, hence the // 2 below.
    self.modules = [
        IRevNetBlock(
            first=(i == 0),
            num_channels=self.num_channels // 2,
            num_channels_bottleneck=(
                self.num_channels_bottleneck or self.num_channels) // 2,
            num_grps_norm=num_grps_norm,
            dropout_rate=self.dropout_rate,
        )
        for i in range(self.num_blocks)
    ]

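  # Note: num_channels must be even, since inputs are split into two halves;
  # e.g. num_channels=32 yields blocks over 16 channels with 16 GroupNorm
  # groups.
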
  def forward(self, x: jax.Array, train: bool = True) -> jax.Array:
    """Applies all blocks in order to an NHWC input."""
    out = _split(x)
    for m in self.modules:
      out = m.forward(out, train=train)
    out_bij = _merge(out[0], out[1])
    return out_bij

  def inverse(self, out_bij: jax.Array, train: bool = True) -> jax.Array:
    """Applies the block inverses in reverse order."""
    out = _split(out_bij)
    for m in reversed(self.modules):
      out = m.inverse(out, train=train)
    out = _merge(out[0], out[1])
    return out

  def __call__(self, x: jax.Array, train: bool = True) -> jax.Array:
    return self.forward(x, train=train)

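# Hedged round-trip sketch (shapes illustrative, not part of the library API):
#
#   model = IRevNet(num_blocks=4, num_channels=32)
#   x = jnp.zeros([2, 8, 8, 32])  # NHWC input with an even channel count.
#   params = model.init(jax.random.PRNGKey(0), x, train=False)
#   y = model.apply(params, x, train=False)
#   x_rec = model.apply(params, y, train=False, method=IRevNet.inverse)
#   # x_rec should match x up to floating-point error.
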
class Model(IRevNet):
  """Wrapper for IRevNet to function as an adaptor in our setup."""

  pixel_shuffle_patch_size: tuple[int, int] = (1, 1)

  def forward(self, x: jax.Array, train: bool = True) -> jax.Array:
    """Reshapes a (b, n, c) sequence to NHWC, applies forward, reshapes back."""
    h, w = cnn.get_h_w_pixelshuffle(x.shape[1], self.pixel_shuffle_patch_size)
    x = einops.rearrange(x, "b (h w) c -> b h w c", h=h, w=w)
    x = super().forward(x, train)
    x = einops.rearrange(x, "b h w c -> b (h w) c")
    return x

  def inverse(self, out_bij: jax.Array, train: bool = True) -> jax.Array:
    """Reshapes a (b, n, c) sequence to NHWC, applies inverse, reshapes back."""
    h, w = cnn.get_h_w_pixelshuffle(
        out_bij.shape[1], self.pixel_shuffle_patch_size)
    out_bij = einops.rearrange(out_bij, "b (h w) c -> b h w c", h=h, w=w)
    out_bij = super().inverse(out_bij, train)
    out_bij = einops.rearrange(out_bij, "b h w c -> b (h w) c")
    return out_bij

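# Hedged adaptor sketch (token count and width are illustrative; the sequence
# length must map to a valid h x w grid via cnn.get_h_w_pixelshuffle):
#
#   adaptor = Model(num_blocks=4, num_channels=32)
#   z = jnp.zeros([2, 64, 32])  # (batch, seq_len, channels)
#   params = adaptor.init(jax.random.PRNGKey(0), z, train=False)
#   z_seq = adaptor.apply(params, z, train=False)
#   z_rec = adaptor.apply(params, z_seq, train=False, method=Model.inverse)
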
def load(
    init_params: Any,
    init_file: str,
    model_params: Any = None,
    dont_load: Sequence[str] = (),
) -> Any:
  """Loads params from init checkpoint and merges into init_params."""
  del model_params
  ckpt_params = flax.core.unfreeze(utils.load_params(init_file))
  if init_params is not None:
    ckpt_params = common.merge_params(ckpt_params, init_params, dont_load)
  return ckpt_params
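
# Hedged usage sketch for load (the checkpoint path is a placeholder):
#
#   params = load(init_params=None, init_file="gs://bucket/adaptor-ckpt.npz")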