# flux-dev-flax / flux/modules/autoencoder.py
# Source: HuggingFace Space upload by lnyan — commit "Update" (d4607d7)
from dataclasses import dataclass
from einops import rearrange
import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx
from flux.wrapper import TorchWrapper
from flux.math import dot_product_attention
@dataclass
class AutoEncoderParams:
    """Hyperparameters for the FLUX VAE autoencoder."""

    resolution: int       # input image resolution (used to track spatial sizes)
    in_channels: int      # channel count of the input image
    ch: int               # base channel width of the conv stacks
    out_ch: int           # channel count of the decoded output image
    ch_mult: list[int]    # per-resolution-level channel multipliers
    num_res_blocks: int   # resnet blocks per resolution level
    z_channels: int       # latent channels (encoder emits 2x: mean and logvar)
    scale_factor: float   # latent scaling applied after encoding
    shift_factor: float   # latent shift subtracted before scaling
swish = nnx.swish
class AttnBlock(nnx.Module):
    """Residual single-head self-attention over all spatial positions (NHWC)."""

    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.in_channels = in_channels
        # Sub-module creation order is kept identical to the torch reference so
        # that RNG consumption (and thus initialization) is unchanged.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize, project to q/k/v, and attend across the h*w positions."""
        normed = self.norm(h_)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)
        b, h, w, c = q.shape
        # Flatten the spatial grid into one sequence axis and add a singleton
        # head axis: (b, h, w, c) -> (b, 1, h*w, c).
        q = q.reshape(b, 1, h * w, c)
        k = k.reshape(b, 1, h * w, c)
        v = v.reshape(b, 1, h * w, c)
        out = dot_product_attention(q, k, v)
        # Restore the spatial layout.
        return out.reshape(b, h, w, c)

    def __call__(self, x: Tensor) -> Tensor:
        """Apply attention with a residual connection."""
        attended = self.attention(x)
        return x + self.proj_out(attended)
class ResnetBlock(nnx.Module):
    """Residual block: two GroupNorm/swish/3x3-conv stages, with a 1x1
    shortcut projection whenever the channel count changes."""

    def __init__(self, in_channels: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        # Creation order matches the torch reference so RNG consumption (and
        # thus weight initialization) is unchanged.
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def __call__(self, x):
        residual = self.conv1(swish(self.norm1(x)))
        residual = self.conv2(swish(self.norm2(residual)))
        # Project the skip path only when the channel counts differ.
        shortcut = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
        return shortcut + residual
class Downsample(nnx.Module):
    """Halve the spatial resolution with a stride-2 3x3 conv.

    The torch reference pads asymmetrically (right/bottom only), which the
    conv itself cannot express, so the zero padding is applied manually.
    """

    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        # padding=0: the asymmetric pad happens in __call__ instead.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def __call__(self, x: Tensor):
        # NHWC: zero-pad one row at the bottom and one column at the right.
        padded = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)), mode="constant")
        return self.conv(padded)
class Upsample(nnx.Module):
    """Double the spatial resolution via nearest-neighbor resize + 3x3 conv."""

    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def __call__(self, x: Tensor):
        batch, height, width, channels = x.shape
        # Nearest-neighbor upsampling, equivalent to torch's interpolate(scale_factor=2).
        upsampled = jax.image.resize(
            x, (batch, 2 * height, 2 * width, channels), method="nearest"
        )
        return self.conv(upsampled)
ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class = ResnetBlock, Downsample, Upsample, AttnBlock
class Encoder(nnx.Module):
    """Downsampling VAE encoder: NHWC image -> feature map with 2*z_channels
    channels (mean and logvar concatenated; split later by DiagonalGaussian).

    Structure mirrors the torch FLUX reference: conv_in, a pyramid of resnet
    blocks with a Downsample between levels, a middle resnet/attention/resnet
    stack, then GroupNorm + swish + conv_out.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None
    ):
        nn = TorchWrapper(rngs, dtype=dtype)
        # Rebind the sub-module classes so they receive this module's rng
        # stream implicitly; creation order below drives RNG consumption.
        ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        curr_res = resolution
        # Prepend 1 so level i reads its input width from the previous level.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # kept empty here; mirrors the reference layout
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # No Downsample after the last (lowest-resolution) level.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # end: 2 * z_channels because mean and logvar are emitted together.
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        """Encode an NHWC image batch into the (mean|logvar) latent map."""
        # downsampling: hs accumulates each stage's output; hs[-1] feeds the next.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
class Decoder(nnx.Module):
    """Upsampling VAE decoder: latent map (z_channels) -> NHWC image (out_ch).

    Structure mirrors the torch FLUX reference: conv_in, a middle
    resnet/attention/resnet stack, a pyramid of resnet blocks with an
    Upsample between levels, then GroupNorm + swish + conv_out.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None
    ):
        nn = TorchWrapper(rngs, dtype=dtype)
        # Rebind the sub-module classes so they receive this module's rng
        # stream implicitly; creation order below drives RNG consumption.
        ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # Total spatial upscale factor from latent to image.
        self.ffactor = 2 ** (self.num_resolutions - 1)
        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # NOTE: z_shape is kept in the reference's NCHW order — informational only.
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # upsampling: built from lowest resolution up (reversed), then
        # prepended so self.up is indexed by level like the encoder's self.down.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # kept empty here; mirrors the reference layout
            block_out = ch * ch_mult[i_level]
            # One extra resnet block per level compared to the encoder.
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            # No Upsample at level 0 (full resolution).
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, z: Tensor) -> Tensor:
        """Decode a latent map back into an NHWC image batch."""
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # upsampling: walk the levels from lowest resolution to highest.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
class DiagonalGaussian(nnx.Module):
    """Split a tensor into (mean, logvar) halves along `chunk_dim` and either
    draw a reparameterized sample or return the mean."""

    def __init__(self, sample: bool = True, chunk_dim: int = -1, dtype=jnp.float32, rngs: nnx.Rngs = None):
        self.sample = sample
        self.chunk_dim = chunk_dim
        self.rngs = rngs
        # NOTE(review): dtype is stored but never used below — confirm whether
        # the sampled noise should be generated/cast in this dtype.
        self.dtype = dtype

    def __call__(self, z: Tensor) -> Tensor:
        mean, logvar = jnp.split(z, 2, axis=self.chunk_dim)
        if not self.sample:
            return mean
        # Reparameterization trick: mean + std * eps, eps ~ N(0, I).
        noise = jax.random.normal(self.rngs(), mean.shape)
        return mean + jnp.exp(0.5 * logvar) * noise
Encoder_class, Decoder_class, DiagonalGaussian_class = Encoder, Decoder, DiagonalGaussian
class AutoEncoder(nnx.Module):
    """Full VAE: Encoder -> DiagonalGaussian sampling -> Decoder, with the
    FLUX latent scale/shift applied around the latent space."""

    def __init__(self, params: AutoEncoderParams, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        # Rebind sub-module classes to this module's rng stream; creation
        # order (encoder, decoder, reg) is preserved so RNG use is unchanged.
        Encoder, Decoder, DiagonalGaussian = nn.declare_with_rng(Encoder_class, Decoder_class, DiagonalGaussian_class)
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()
        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Image -> sampled, shift/scale-normalized latent."""
        moments = self.encoder(x)
        latent = self.reg(moments)
        return self.scale_factor * (latent - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        """Normalized latent -> reconstructed image (undoes encode's affine)."""
        raw_latent = z / self.scale_factor + self.shift_factor
        return self.decoder(raw_latent)

    def __call__(self, x: Tensor) -> Tensor:
        """Round-trip reconstruction: decode(encode(x))."""
        return self.decode(self.encode(x))