# Source: hma/magvit2/modules/diffusionmodules/improved_model.py
# (captured from a repository web view — commit 246c106 by LeroyWaa, message "draft", 7.73 kB)
import torch
import torch.nn as nn
from magvit2.config import VQConfig
def swish(x):
    """Swish (a.k.a. SiLU) activation, applied elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class ResBlock(nn.Module):
    """Pre-activation residual block.

    Main branch: GroupNorm -> swish -> 3x3 conv, applied twice (bias-free convs).
    Skip branch: identity when the channel count is unchanged; otherwise a
    projection to ``out_filters`` channels, either a 3x3 conv
    (``use_conv_shortcut=True``) or a 1x1 conv (the default).

    Args:
        in_filters: number of input channels. Must be divisible by 32, because
            ``norm1`` is a 32-group GroupNorm.
        out_filters: number of output channels. Must be divisible by 32
            (``norm2``).
        use_conv_shortcut: select the 3x3 projection instead of the 1x1 one
            when ``in_filters != out_filters``.
    """

    def __init__(self, in_filters, out_filters, use_conv_shortcut=False) -> None:
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut

        # Registration order (norms, convs, shortcut) is kept stable so that
        # state_dict ordering and parameter-init RNG consumption do not change.
        self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)

        # The skip path only needs a projection when the channel count changes.
        if in_filters == out_filters:
            return
        if use_conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        else:
            self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)

    def forward(self, x, **kwargs):
        """Return ``branch(x) + skip(x)``. Extra keyword arguments are accepted and ignored."""
        skip = x
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))
        if self.in_filters != self.out_filters:
            projection = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            skip = projection(skip)
        return h + skip
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a ``config.z_channels`` feature map.

    Pipeline:
      1. 3x3 ``conv_in`` (no bias): ``config.in_channels`` -> ``config.base_channels``.
      2. One stage per entry of ``config.ch_mult``, stored in ``self.down``.
         Stage ``level`` holds ``config.num_res_blocks`` ResBlocks that produce
         ``base_channels * ch_mult[level]`` channels. Every stage except the
         last then applies a stride-2 3x3 conv (padding 1), which halves H and W
         (rounding up).
      3. ``config.num_res_blocks`` mid ResBlocks at the final width.
      4. GroupNorm(32) -> swish -> 1x1 ``conv_out`` to ``config.z_channels``.

    Config fields read: in_channels, z_channels, num_res_blocks, ch_mult,
    base_channels. Every stage width must be divisible by 32, because of the
    32-group GroupNorms.
    """

    def __init__(self, config: VQConfig):
        super().__init__()
        self.in_channels = config.in_channels
        self.z_channels = config.z_channels
        self.num_res_blocks = config.num_res_blocks
        self.num_blocks = len(config.ch_mult)

        base = config.base_channels
        self.conv_in = nn.Conv2d(
            config.in_channels, base, kernel_size=(3, 3), padding=1, bias=False
        )

        self.down = nn.ModuleList()
        # Width multiplier entering each stage: x1 for the stem, then the
        # previous stage's multiplier.
        in_mults = (1,) + tuple(config.ch_mult)
        last_stage = self.num_blocks - 1
        for level, mult in enumerate(config.ch_mult):
            ch_in = base * in_mults[level]
            ch_out = base * mult
            stage_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResBlock(ch_in, ch_out))
                ch_in = ch_out
            stage = nn.Module()
            stage.block = stage_blocks
            if level < last_stage:
                stage.downsample = nn.Conv2d(
                    ch_out, ch_out, kernel_size=(3, 3), stride=(2, 2), padding=1
                )
            self.down.append(stage)

        # ch_in / ch_out carry over from the last stage of the loop above.
        self.mid_block = nn.ModuleList(
            ResBlock(ch_in, ch_in) for _ in range(self.num_res_blocks)
        )

        self.norm_out = nn.GroupNorm(32, ch_out, eps=1e-6)
        self.conv_out = nn.Conv2d(ch_out, config.z_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Encode ``x`` (channels-first image batch) into the latent feature map."""
        h = self.conv_in(x)
        final_level = self.num_blocks - 1
        for level in range(self.num_blocks):
            stage = self.down[level]
            for idx in range(self.num_res_blocks):
                h = stage.block[idx](h)
            if level < final_level:
                h = stage.downsample(h)
        for idx in range(self.num_res_blocks):
            h = self.mid_block[idx](h)
        return self.conv_out(swish(self.norm_out(h)))
class Decoder(nn.Module):
    """Convolutional decoder mapping a ``config.z_channels`` latent back to ``config.out_channels``.

    Pipeline:
      1. 3x3 ``conv_in`` (with bias): ``z_channels`` -> ``base_channels * ch_mult[-1]``.
      2. ``config.num_res_blocks`` mid ResBlocks at that width.
      3. Stages for ``level = len(ch_mult) - 1 .. 0``, each holding
         ``num_res_blocks`` ResBlocks that produce ``base_channels * ch_mult[level]``
         channels. Every stage with ``level > 0`` then applies an ``Upsampler``
         (conv + depth-to-space with block_size 2, i.e. 2x in H and W).
      4. GroupNorm(32) -> swish -> 3x3 ``conv_out`` to ``config.out_channels``.

    Config fields read: base_channels, ch_mult, num_res_blocks, z_channels,
    out_channels.
    """

    def __init__(self, config: VQConfig) -> None:
        super().__init__()
        self.base_channels = config.base_channels
        self.num_blocks = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks

        ch = self.base_channels * config.ch_mult[-1]
        self.conv_in = nn.Conv2d(
            config.z_channels, ch, kernel_size=(3, 3), padding=1, bias=True
        )

        self.mid_block = nn.ModuleList(
            ResBlock(ch, ch) for _ in range(self.num_res_blocks)
        )

        self.up = nn.ModuleList()
        for level in reversed(range(self.num_blocks)):
            ch_out = self.base_channels * config.ch_mult[level]
            stage_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResBlock(ch, ch_out))
                ch = ch_out
            stage = nn.Module()
            stage.block = stage_blocks
            if level > 0:
                stage.upsample = Upsampler(ch)
            # Stages are built deepest-first but prepended, so self.up[level]
            # is the stage built for config.ch_mult[level].
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(32, ch, eps=1e-6)
        self.conv_out = nn.Conv2d(ch, config.out_channels, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode latent ``z`` into an output tensor with ``config.out_channels`` channels."""
        h = self.conv_in(z)
        for idx in range(self.num_res_blocks):
            h = self.mid_block[idx](h)
        for level in reversed(range(self.num_blocks)):
            stage = self.up[level]
            for idx in range(self.num_res_blocks):
                h = stage.block[idx](h)
            if level > 0:
                h = stage.upsample(h)
        return self.conv_out(swish(self.norm_out(h)))
def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Depth-to-Space in DCR mode (depth-column-row).

    The channel axis is read as ``(block_size, block_size, C // block_size**2)``,
    and the two block axes are interleaved into H and W. With ``b = block_size``
    and ``C' = C // b**2``:

        out[..., c, h * b + i, w * b + j] = x[..., (i * b + j) * C' + c, h, w]

    Args:
        x: channels-first tensor ``(*, C, H, W)`` with any number of leading
            dims, including none. Non-contiguous inputs are accepted.
        block_size: spatial upscale factor ``b``; must be >= 1.

    Returns:
        A contiguous tensor of shape ``(*, C // b**2, H * b, W * b)``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims, ``block_size < 1``, or C is
            not divisible by ``block_size**2``.
    """
    if x.dim() < 3:
        raise ValueError(
            "Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
        )
    if block_size < 1:
        # Otherwise this surfaces as ZeroDivisionError (0) or an opaque
        # reshape error (negative sizes).
        raise ValueError(f"Expecting a positive block_size, but got {block_size}")

    c, h, w = x.shape[-3:]
    s = block_size ** 2
    if c % s != 0:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
        )
    outer_dims = x.shape[:-3]
    # Split C into (b, b, C'). Use reshape rather than view: it returns a view
    # whenever view would, and copies only when the leading dims cannot be
    # merged in place (e.g. a transposed batch axis), where view would raise.
    x = x.reshape(-1, block_size, block_size, c // s, h, w)
    # (N, b_row, b_col, C', H, W) -> (N, C', H, b_row, W, b_col)
    x = x.permute(0, 3, 4, 1, 5, 2)
    # Merge the block axes into H and W and restore the leading dims.
    x = x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size)
    return x
class Upsampler(nn.Module):
    """2x spatial upsampler: 3x3 conv, then depth-to-space with block_size 2.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; defaults to ``dim``.

    Maps ``[B, dim, H, W]`` -> ``[B, dim_out, 2H, 2W]``.
    """

    def __init__(
        self,
        dim,
        dim_out = None
    ):
        super().__init__()
        # Honor the caller's dim_out. It used to be overwritten with dim * 4
        # and ignored, so the output always had ``dim`` channels. Default
        # (dim_out=None) behaviour is unchanged.
        if dim_out is None:
            dim_out = dim
        # depth_to_space(block_size=2) folds 2*2 = 4 channels into each output
        # pixel, so the conv must emit 4x the requested output channels.
        self.conv1 = nn.Conv2d(dim, dim_out * 4, (3, 3), padding=1)
        self.depth2space = depth_to_space

    def forward(self, x):
        """Upsample ``x`` of shape [B, C, H, W] by 2 in both spatial dims."""
        out = self.conv1(x)
        out = self.depth2space(out, block_size=2)
        return out
# Example usage (outdated: Encoder and Decoder now take a single VQConfig
# argument instead of the keyword arguments shown below):
# if __name__ == "__main__":
#     x = torch.randn(size = (2, 3, 128, 128))
#     encoder = Encoder(ch=128, in_channels=3, num_res_blocks=2, z_channels=18, out_ch=3, resolution=128)
#     decoder = Decoder(out_ch=3, z_channels=18, num_res_blocks=2, ch=128, in_channels=3, resolution=128)
#     z = encoder(x)
#     out = decoder(z)