# NOTE: removed non-Python web-UI artifact lines ("huaweilin's picture" / "update" / commit hash)
# that were accidentally pasted above the module docstring and broke the file.
"""This file contains the definition of the the autoencoder parts"""
import math
import torch
import torch.nn.functional as F
class Conv2dSame(torch.nn.Conv2d):
    """2D convolution emulating TensorFlow-style `SAME` padding via explicit F.pad."""
    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Return the total padding for one spatial dim so that output size is ceil(i / s).
        Args:
            i -> int: Input size.
            k -> int: Kernel size.
            s -> int: Stride size.
            d -> int: Dilation rate.
        """
        effective_kernel = (k - 1) * d + 1
        total = (math.ceil(i / s) - 1) * s + effective_kernel - i
        return total if total > 0 else 0
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the input (asymmetrically when the total pad is odd), then convolve.
        Args:
            x -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        height, width = x.shape[-2], x.shape[-1]
        total_h = self.calc_same_pad(
            i=height, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]
        )
        total_w = self.calc_same_pad(
            i=width, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]
        )
        if total_h or total_w:
            # F.pad order is (left, right, top, bottom); the extra pixel of an
            # odd total goes on the right/bottom side.
            left, top = total_w // 2, total_h // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return super().forward(x)
def GroupNorm(in_channels):
    """Build a 32-group GroupNorm layer; `in_channels` must be divisible by 32."""
    num_groups = 32
    if in_channels % num_groups:
        raise ValueError(
            f"GroupNorm requires in_channels to be divisible by 32, got {in_channels}."
        )
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )
class ResidualBlock(torch.nn.Module):
    """Pre-activation residual block with two 3x3 convolutions.

    When `in_channels != out_channels`, the skip path is projected to
    `out_channels` with a 1x1 convolution so it can be added to the main path.
    """
    def __init__(self, in_channels: int, out_channels: int = None, norm_func=GroupNorm):
        """Initializes the residual block.
        Args:
            in_channels -> int: Number of input channels.
            out_channels -> int: Number of output channels. Default is in_channels.
            norm_func -> Callable: Normalization function. Default is GroupNorm.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = self.in_channels if out_channels is None else out_channels
        self.norm1 = norm_func(self.in_channels)
        self.conv1 = Conv2dSame(
            self.in_channels, self.out_channels, kernel_size=3, bias=False
        )
        self.norm2 = norm_func(self.out_channels)
        self.conv2 = Conv2dSame(
            self.out_channels, self.out_channels, kernel_size=3, bias=False
        )
        if self.in_channels != self.out_channels:
            # BUGFIX: the 1x1 shortcut must map the block *input* (in_channels)
            # to out_channels; it was previously out_channels -> out_channels,
            # which cannot consume the input tensor at all.
            self.nin_shortcut = Conv2dSame(
                self.in_channels, self.out_channels, kernel_size=1, bias=False
            )
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass of the residual block.
        Args:
            hidden_states -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor (skip connection plus conv branch).
        """
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv2(hidden_states)
        if self.in_channels != self.out_channels:
            # BUGFIX: project the saved input, not the transformed activations,
            # so the addition below is a true skip connection. Previously the
            # input was discarded whenever channels changed.
            residual = self.nin_shortcut(residual)
        return hidden_states + residual
class ResidualStage(torch.nn.Module):
    """A stack of residual blocks applied back to back (no resampling)."""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        norm_func=GroupNorm,
    ):
        """Initializes the residual stage.
        Args:
            in_channels -> int: Number of input channels.
            out_channels -> int: Number of output channels.
            num_res_blocks -> int: Number of residual blocks.
            norm_func -> Callable: Normalization function. Default is GroupNorm.
        """
        super().__init__()
        blocks = []
        channels = in_channels
        for _ in range(num_res_blocks):
            # Only the first block changes the channel count.
            blocks.append(ResidualBlock(channels, out_channels, norm_func=norm_func))
            channels = out_channels
        self.res_blocks = torch.nn.ModuleList(blocks)
    def forward(self, hidden_states: torch.Tensor, *unused_args) -> torch.Tensor:
        """Apply every residual block in sequence.
        Args:
            hidden_states -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        for block in self.res_blocks:
            hidden_states = block(hidden_states)
        return hidden_states
class DownsamplingStage(torch.nn.Module):
    """Residual blocks followed by a 2x spatial downsampling step."""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        sample_with_conv: bool = False,
        norm_func=GroupNorm,
    ):
        """Initializes the downsampling stage.
        Args:
            in_channels -> int: Number of input channels.
            out_channels -> int: Number of output channels.
            num_res_blocks -> int: Number of residual blocks.
            sample_with_conv -> bool: Whether to sample with a convolution or with a stride. Default is False.
            norm_func -> Callable: Normalization function. Default is GroupNorm.
        """
        super().__init__()
        blocks = []
        channels = in_channels
        for _ in range(num_res_blocks):
            blocks.append(ResidualBlock(channels, out_channels, norm_func))
            channels = out_channels
        self.res_blocks = torch.nn.ModuleList(blocks)
        self.sample_with_conv = sample_with_conv
        if self.sample_with_conv:
            # Strided SAME conv halves the spatial resolution.
            self.down_conv = Conv2dSame(channels, channels, kernel_size=3, stride=2)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the residual blocks, then downsample by a factor of 2.
        Args:
            hidden_states -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        for block in self.res_blocks:
            hidden_states = block(hidden_states)
        if self.sample_with_conv:
            return self.down_conv(hidden_states)
        return F.avg_pool2d(hidden_states, kernel_size=2, stride=2)
class UpsamplingStage(torch.nn.Module):
    """Residual blocks followed by nearest-neighbor 2x upsampling and a smoothing conv."""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        norm_func=GroupNorm,
    ):
        """Initializes the upsampling stage.
        Args:
            in_channels -> int: Number of input channels.
            out_channels -> int: Number of output channels.
            num_res_blocks -> int: Number of residual blocks.
            norm_func -> Callable: Normalization function. Default is GroupNorm.
        """
        super().__init__()
        blocks = []
        channels = in_channels
        for _ in range(num_res_blocks):
            blocks.append(ResidualBlock(channels, out_channels, norm_func))
            channels = out_channels
        self.res_blocks = torch.nn.ModuleList(blocks)
        self.upsample_conv = Conv2dSame(out_channels, out_channels, kernel_size=3)
    def forward(self, hidden_states: torch.Tensor, *unused_args) -> torch.Tensor:
        """Apply the residual blocks, then upsample by a factor of 2.
        Args:
            hidden_states -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        for block in self.res_blocks:
            hidden_states = block(hidden_states)
        upsampled = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.upsample_conv(upsampled)
class ConvEncoder(torch.nn.Module):
    """Convolutional encoder: stem conv, downsampling pyramid, middle stage, 1x1 projection."""
    def __init__(self, config):
        """Initializes the convolutional encoder.
        Args:
            config: Configuration of the model architecture.
        """
        super().__init__()
        self.config = config
        base = config.hidden_channels
        self.conv_in = Conv2dSame(
            config.num_channels, base, kernel_size=3, bias=False
        )
        # Multiplier for the stem output prepended so stage i maps
        # base*mults[i] -> base*mults[i + 1].
        mults = (1,) + tuple(config.channel_mult)
        n_blocks = config.num_res_blocks
        stages = []
        for level in range(config.num_resolutions):
            c_in = base * mults[level]
            c_out = base * mults[level + 1]
            if level < config.num_resolutions - 1:
                stages.append(
                    DownsamplingStage(c_in, c_out, n_blocks, config.sample_with_conv)
                )
            else:
                # Deepest level keeps the resolution.
                stages.append(ResidualStage(c_in, c_out, n_blocks))
        self.down = torch.nn.ModuleList(stages)
        # middle
        mid_channels = base * mults[-1]
        self.mid = ResidualStage(mid_channels, mid_channels, n_blocks)
        # end
        self.norm_out = GroupNorm(mid_channels)
        self.conv_out = Conv2dSame(mid_channels, config.token_size, kernel_size=1)
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Forward pass of the convolutional encoder.
        Args:
            pixel_values -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        # downsampling
        hidden_states = self.conv_in(pixel_values)
        for stage in self.down:
            hidden_states = stage(hidden_states)
        # middle
        hidden_states = self.mid(hidden_states)
        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        return self.conv_out(hidden_states)
class ConvDecoderLegacy(torch.nn.Module):
    """
    This is a legacy decoder class. It is used to support older weights.
    """
    def __init__(self, config):
        """Initializes the convolutional decoder in a legacy variant.
        Args:
            config: Configuration of the model architecture.
        """
        super().__init__()
        self.config = config
        base = config.hidden_channels
        n_blocks = config.num_res_blocks
        # Deepest multiplier repeated so stage i maps base*mults[i+1] -> base*mults[i].
        mults = tuple(config.channel_mult) + (config.channel_mult[-1],)
        deepest = base * config.channel_mult[config.num_resolutions - 1]
        # z to block_in
        self.conv_in = Conv2dSame(config.token_size, deepest, kernel_size=3)
        # middle
        self.mid = ResidualStage(deepest, deepest, n_blocks)
        # upsampling: build deep-to-shallow (construction order kept for RNG /
        # weight-init parity), but store shallow-to-deep so checkpoint keys
        # `up.{i}` correspond to resolution level i.
        deep_to_shallow = []
        for level in reversed(range(config.num_resolutions)):
            c_in = base * mults[level + 1]
            c_out = base * mults[level]
            stage_cls = UpsamplingStage if level > 0 else ResidualStage
            deep_to_shallow.append(stage_cls(c_in, c_out, n_blocks))
        self.up = torch.nn.ModuleList(deep_to_shallow[::-1])
        # end
        final_channels = base * mults[0]
        self.norm_out = GroupNorm(final_channels)
        self.conv_out = Conv2dSame(final_channels, config.num_channels, kernel_size=3)
    def forward(self, z_quantized: torch.Tensor) -> torch.Tensor:
        """Forward pass of the convolutional decoder.
        Args:
            z_quantized -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        # z to block_in
        hidden_states = self.conv_in(z_quantized)
        # middle
        hidden_states = self.mid(hidden_states)
        # upsampling decoder: deepest stage first, hence the reversed walk.
        for stage in reversed(self.up):
            hidden_states = stage(hidden_states, z_quantized)
        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        return self.conv_out(hidden_states)
class ConvDecoder(torch.nn.Module):
    """Convolutional decoder: stem conv, middle stage, upsampling pyramid, output projection."""
    def __init__(self, config):
        """Initializes the convolutional decoder.
        Args:
            config: Configuration of the model architecture.
        """
        super().__init__()
        self.config = config
        base = config.hidden_channels
        # Deepest multiplier repeated so stage i maps base*mults[i+1] -> base*mults[i].
        mults = tuple(config.channel_mult) + (config.channel_mult[-1],)
        deepest = base * config.channel_mult[config.num_resolutions - 1]
        # Decoder may use its own block count; falls back to the shared one.
        n_blocks = config.get("num_res_blocks_decoder", config.num_res_blocks)
        # z to block_in: a VAE latent carries mean+logvar, so only half the
        # token dimension reaches the decoder.
        if config.quantizer_type == "vae":
            latent_dim = config.token_size // 2
        else:
            latent_dim = config.token_size
        self.conv_in = Conv2dSame(latent_dim, deepest, kernel_size=3)
        # middle
        self.mid = ResidualStage(deepest, deepest, n_blocks)
        # upsampling: stored deep-to-shallow; forward iterates in order.
        stages = []
        for level in reversed(range(config.num_resolutions)):
            c_in = base * mults[level + 1]
            c_out = base * mults[level]
            stage_cls = UpsamplingStage if level > 0 else ResidualStage
            stages.append(stage_cls(c_in, c_out, n_blocks))
        self.up = torch.nn.ModuleList(stages)
        # end
        final_channels = base * mults[0]
        self.norm_out = GroupNorm(final_channels)
        self.conv_out = Conv2dSame(final_channels, config.num_channels, kernel_size=3)
    def forward(self, z_quantized: torch.Tensor) -> torch.Tensor:
        """Forward pass of the convolutional decoder.
        Args:
            z_quantized -> torch.Tensor: Input tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        # z to block_in
        hidden_states = self.conv_in(z_quantized)
        # middle
        hidden_states = self.mid(hidden_states)
        # upsampling decoder (already ordered deepest-first)
        for stage in self.up:
            hidden_states = stage(hidden_states, z_quantized)
        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        return self.conv_out(hidden_states)
if __name__ == "__main__":
class Config:
def __init__(self, **kwargs):
for key in kwargs:
setattr(self, key, kwargs[key])
def get(self, key, default):
return getattr(self, key, default)
config_dict = dict(
resolution=256,
num_channels=3,
hidden_channels=128,
channel_mult=(1, 2, 2, 4),
num_res_blocks=2,
codebook_size=1024,
token_size=256,
num_resolutions=4,
sample_with_conv=False,
quantizer_type="lookup",
)
config = Config(**config_dict)
encoder = ConvEncoder(config)
decoder = ConvDecoder(config)
config.sample_with_conv = True
encoder_conv_down = ConvEncoder(config)
print("Encoder:\n{}".format(encoder))
print("Encoder downsampling with conv:\n{}".format(encoder_conv_down))
print("Decoder:\n{}".format(decoder))
x = torch.randn((1, 3, 256, 256))
x_enc = encoder(x)
x_enc_down_with_conv = encoder_conv_down(x)
x_dec = decoder(x_enc)
x_dec_down_with_conv = decoder(x_enc_down_with_conv)
print(f"Input shape: {x.shape}")
print(f"Encoder output shape: {x_enc.shape}")
print(f"Encoder with conv as down output shape: {x_enc_down_with_conv.shape}")
print(f"Decoder output shape: {x_dec.shape}")
print(f"Decoder with conv as down output shape: {x_dec_down_with_conv.shape}")