|
"""This file contains the definition of the the autoencoder parts""" |
|
|
|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
class Conv2dSame(torch.nn.Conv2d): |
|
"""Convolution wrapper for 2D convolutions using `SAME` padding.""" |
|
|
|
def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: |
|
"""Calculate padding such that the output has the same height/width when stride=1. |
|
|
|
Args: |
|
i -> int: Input size. |
|
k -> int: Kernel size. |
|
s -> int: Stride size. |
|
d -> int: Dilation rate. |
|
""" |
|
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Forward pass of the convolution applying explicit `same` padding. |
|
|
|
Args: |
|
x -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
ih, iw = x.size()[-2:] |
|
|
|
pad_h = self.calc_same_pad( |
|
i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] |
|
) |
|
pad_w = self.calc_same_pad( |
|
i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] |
|
) |
|
|
|
if pad_h > 0 or pad_w > 0: |
|
x = F.pad( |
|
x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] |
|
) |
|
return super().forward(x) |
|
|
|
|
|
def GroupNorm(in_channels): |
|
"""GroupNorm with 32 groups.""" |
|
if in_channels % 32 != 0: |
|
raise ValueError( |
|
f"GroupNorm requires in_channels to be divisible by 32, got {in_channels}." |
|
) |
|
return torch.nn.GroupNorm( |
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True |
|
) |
|
|
|
|
|
class ResidualBlock(torch.nn.Module): |
|
"""Residual block with two convolutional layers.""" |
|
|
|
def __init__(self, in_channels: int, out_channels: int = None, norm_func=GroupNorm): |
|
"""Initializes the residual block. |
|
|
|
Args: |
|
in_channels -> int: Number of input channels. |
|
out_channels -> int: Number of output channels. Default is in_channels. |
|
norm_func -> Callable: Normalization function. Default is GroupNorm. |
|
""" |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = self.in_channels if out_channels is None else out_channels |
|
|
|
self.norm1 = norm_func(self.in_channels) |
|
self.conv1 = Conv2dSame( |
|
self.in_channels, self.out_channels, kernel_size=3, bias=False |
|
) |
|
|
|
self.norm2 = norm_func(self.out_channels) |
|
self.conv2 = Conv2dSame( |
|
self.out_channels, self.out_channels, kernel_size=3, bias=False |
|
) |
|
|
|
if self.in_channels != self.out_channels: |
|
self.nin_shortcut = Conv2dSame( |
|
self.out_channels, self.out_channels, kernel_size=1, bias=False |
|
) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
"""Forward pass of the residual block. |
|
|
|
Args: |
|
hidden_states -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
residual = hidden_states |
|
hidden_states = self.norm1(hidden_states) |
|
hidden_states = F.silu(hidden_states) |
|
hidden_states = self.conv1(hidden_states) |
|
|
|
hidden_states = self.norm2(hidden_states) |
|
hidden_states = F.silu(hidden_states) |
|
hidden_states = self.conv2(hidden_states) |
|
|
|
if self.in_channels != self.out_channels: |
|
residual = self.nin_shortcut(hidden_states) |
|
|
|
return hidden_states + residual |
|
|
|
|
|
class ResidualStage(torch.nn.Module): |
|
"""Residual stage with multiple residual blocks.""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_res_blocks: int, |
|
norm_func=GroupNorm, |
|
): |
|
"""Initializes the residual stage. |
|
|
|
Args: |
|
in_channels -> int: Number of input channels. |
|
out_channels -> int: Number of output channels. |
|
num_res_blocks -> int: Number of residual blocks. |
|
norm_func -> Callable: Normalization function. Default is GroupNorm. |
|
""" |
|
super().__init__() |
|
|
|
self.res_blocks = torch.nn.ModuleList() |
|
for _ in range(num_res_blocks): |
|
self.res_blocks.append( |
|
ResidualBlock(in_channels, out_channels, norm_func=norm_func) |
|
) |
|
in_channels = out_channels |
|
|
|
def forward(self, hidden_states: torch.Tensor, *unused_args) -> torch.Tensor: |
|
"""Forward pass of the residual stage. |
|
|
|
Args: |
|
hidden_states -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
for res_block in self.res_blocks: |
|
hidden_states = res_block(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class DownsamplingStage(torch.nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_res_blocks: int, |
|
sample_with_conv: bool = False, |
|
norm_func=GroupNorm, |
|
): |
|
"""Initializes the downsampling stage. |
|
|
|
Args: |
|
in_channels -> int: Number of input channels. |
|
out_channels -> int: Number of output channels. |
|
num_res_blocks -> int: Number of residual blocks. |
|
sample_with_conv -> bool: Whether to sample with a convolution or with a stride. Default is False. |
|
norm_func -> Callable: Normalization function. Default is GroupNorm. |
|
""" |
|
super().__init__() |
|
|
|
self.res_blocks = torch.nn.ModuleList() |
|
for _ in range(num_res_blocks): |
|
self.res_blocks.append(ResidualBlock(in_channels, out_channels, norm_func)) |
|
in_channels = out_channels |
|
|
|
self.sample_with_conv = sample_with_conv |
|
if self.sample_with_conv: |
|
self.down_conv = Conv2dSame( |
|
in_channels, in_channels, kernel_size=3, stride=2 |
|
) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
"""Forward pass of the downsampling stage. |
|
|
|
Args: |
|
hidden_states -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
for res_block in self.res_blocks: |
|
hidden_states = res_block(hidden_states) |
|
|
|
if self.sample_with_conv: |
|
hidden_states = self.down_conv(hidden_states) |
|
else: |
|
hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2) |
|
|
|
return hidden_states |
|
|
|
|
|
class UpsamplingStage(torch.nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_res_blocks: int, |
|
norm_func=GroupNorm, |
|
): |
|
"""Initializes the upsampling stage. |
|
|
|
Args: |
|
in_channels -> int: Number of input channels. |
|
out_channels -> int: Number of output channels. |
|
num_res_blocks -> int: Number of residual blocks. |
|
norm_func -> Callable: Normalization function. Default is GroupNorm. |
|
""" |
|
super().__init__() |
|
|
|
self.res_blocks = torch.nn.ModuleList() |
|
for _ in range(num_res_blocks): |
|
self.res_blocks.append(ResidualBlock(in_channels, out_channels, norm_func)) |
|
in_channels = out_channels |
|
|
|
self.upsample_conv = Conv2dSame(out_channels, out_channels, kernel_size=3) |
|
|
|
def forward(self, hidden_states: torch.Tensor, *unused_args) -> torch.Tensor: |
|
"""Forward pass of the upsampling stage. |
|
|
|
Args: |
|
hidden_states -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
for res_block in self.res_blocks: |
|
hidden_states = res_block(hidden_states) |
|
|
|
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") |
|
hidden_states = self.upsample_conv(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class ConvEncoder(torch.nn.Module): |
|
def __init__(self, config): |
|
"""Initializes the convolutional encoder. |
|
|
|
Args: |
|
config: Configuration of the model architecture. |
|
""" |
|
super().__init__() |
|
self.config = config |
|
|
|
self.conv_in = Conv2dSame( |
|
self.config.num_channels, |
|
self.config.hidden_channels, |
|
kernel_size=3, |
|
bias=False, |
|
) |
|
|
|
in_channel_mult = (1,) + tuple(self.config.channel_mult) |
|
num_res_blocks = self.config.num_res_blocks |
|
hidden_channels = self.config.hidden_channels |
|
|
|
encoder_blocks = [] |
|
for i_level in range(self.config.num_resolutions): |
|
in_channels = hidden_channels * in_channel_mult[i_level] |
|
out_channels = hidden_channels * in_channel_mult[i_level + 1] |
|
|
|
if i_level < (self.config.num_resolutions - 1): |
|
encoder_blocks.append( |
|
DownsamplingStage( |
|
in_channels, |
|
out_channels, |
|
num_res_blocks, |
|
self.config.sample_with_conv, |
|
) |
|
) |
|
else: |
|
encoder_blocks.append( |
|
ResidualStage(in_channels, out_channels, num_res_blocks) |
|
) |
|
self.down = torch.nn.ModuleList(encoder_blocks) |
|
|
|
|
|
mid_channels = out_channels |
|
self.mid = ResidualStage(mid_channels, mid_channels, num_res_blocks) |
|
|
|
|
|
self.norm_out = GroupNorm(mid_channels) |
|
self.conv_out = Conv2dSame(mid_channels, self.config.token_size, kernel_size=1) |
|
|
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: |
|
"""Forward pass of the convolutional encoder. |
|
|
|
Args: |
|
pixel_values -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
|
|
hidden_states = self.conv_in(pixel_values) |
|
|
|
for block in self.down: |
|
hidden_states = block(hidden_states) |
|
|
|
hidden_states = self.mid(hidden_states) |
|
|
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = F.silu(hidden_states) |
|
hidden_states = self.conv_out(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class ConvDecoderLegacy(torch.nn.Module): |
|
""" |
|
This is a legacy decoder class. It is used to support older weights. |
|
""" |
|
|
|
def __init__(self, config): |
|
"""Initializes the convolutional decoder in a legacy variant. |
|
|
|
Args: |
|
config: Configuration of the model architecture. |
|
""" |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
|
|
block_in = ( |
|
self.config.hidden_channels |
|
* self.config.channel_mult[self.config.num_resolutions - 1] |
|
) |
|
num_res_blocks = self.config.num_res_blocks |
|
hidden_channels = self.config.hidden_channels |
|
in_channel_mult = tuple(self.config.channel_mult) + ( |
|
self.config.channel_mult[-1], |
|
) |
|
|
|
|
|
self.conv_in = Conv2dSame(self.config.token_size, block_in, kernel_size=3) |
|
|
|
|
|
self.mid = ResidualStage(block_in, block_in, num_res_blocks) |
|
|
|
|
|
decoder_blocks = [] |
|
for i_level in reversed(range(self.config.num_resolutions)): |
|
in_channels = hidden_channels * in_channel_mult[i_level + 1] |
|
out_channels = hidden_channels * in_channel_mult[i_level] |
|
if i_level > 0: |
|
decoder_blocks.append( |
|
UpsamplingStage(in_channels, out_channels, num_res_blocks) |
|
) |
|
else: |
|
decoder_blocks.append( |
|
ResidualStage(in_channels, out_channels, num_res_blocks) |
|
) |
|
|
|
self.up = torch.nn.ModuleList(list(reversed(decoder_blocks))) |
|
|
|
|
|
self.norm_out = GroupNorm(out_channels) |
|
self.conv_out = Conv2dSame( |
|
out_channels, self.config.num_channels, kernel_size=3 |
|
) |
|
|
|
def forward(self, z_quantized: torch.Tensor) -> torch.Tensor: |
|
"""Forward pass of the convolutional decoder. |
|
|
|
Args: |
|
z_quantized -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
|
|
hidden_states = self.conv_in(z_quantized) |
|
|
|
|
|
hidden_states = self.mid(hidden_states) |
|
|
|
|
|
for block in reversed(self.up): |
|
hidden_states = block(hidden_states, z_quantized) |
|
|
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = F.silu(hidden_states) |
|
hidden_states = self.conv_out(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class ConvDecoder(torch.nn.Module): |
|
def __init__(self, config): |
|
"""Initializes the convolutional decoder. |
|
|
|
Args: |
|
config: Configuration of the model architecture. |
|
""" |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
|
|
block_in = ( |
|
self.config.hidden_channels |
|
* self.config.channel_mult[self.config.num_resolutions - 1] |
|
) |
|
num_res_blocks = self.config.get( |
|
"num_res_blocks_decoder", self.config.num_res_blocks |
|
) |
|
hidden_channels = self.config.hidden_channels |
|
in_channel_mult = tuple(self.config.channel_mult) + ( |
|
self.config.channel_mult[-1], |
|
) |
|
|
|
|
|
if config.quantizer_type == "vae": |
|
self.conv_in = Conv2dSame( |
|
self.config.token_size // 2, block_in, kernel_size=3 |
|
) |
|
else: |
|
self.conv_in = Conv2dSame(self.config.token_size, block_in, kernel_size=3) |
|
|
|
|
|
self.mid = ResidualStage(block_in, block_in, num_res_blocks) |
|
|
|
|
|
decoder_blocks = [] |
|
for i_level in reversed(range(self.config.num_resolutions)): |
|
in_channels = hidden_channels * in_channel_mult[i_level + 1] |
|
out_channels = hidden_channels * in_channel_mult[i_level] |
|
if i_level > 0: |
|
decoder_blocks.append( |
|
UpsamplingStage(in_channels, out_channels, num_res_blocks) |
|
) |
|
else: |
|
decoder_blocks.append( |
|
ResidualStage(in_channels, out_channels, num_res_blocks) |
|
) |
|
self.up = torch.nn.ModuleList(decoder_blocks) |
|
|
|
|
|
self.norm_out = GroupNorm(out_channels) |
|
self.conv_out = Conv2dSame( |
|
out_channels, self.config.num_channels, kernel_size=3 |
|
) |
|
|
|
def forward(self, z_quantized: torch.Tensor) -> torch.Tensor: |
|
"""Forward pass of the convolutional decoder. |
|
|
|
Args: |
|
z_quantized -> torch.Tensor: Input tensor. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor. |
|
""" |
|
|
|
hidden_states = self.conv_in(z_quantized) |
|
|
|
|
|
hidden_states = self.mid(hidden_states) |
|
|
|
|
|
for block in self.up: |
|
hidden_states = block(hidden_states, z_quantized) |
|
|
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = F.silu(hidden_states) |
|
hidden_states = self.conv_out(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
class Config: |
|
def __init__(self, **kwargs): |
|
for key in kwargs: |
|
setattr(self, key, kwargs[key]) |
|
|
|
def get(self, key, default): |
|
return getattr(self, key, default) |
|
|
|
config_dict = dict( |
|
resolution=256, |
|
num_channels=3, |
|
hidden_channels=128, |
|
channel_mult=(1, 2, 2, 4), |
|
num_res_blocks=2, |
|
codebook_size=1024, |
|
token_size=256, |
|
num_resolutions=4, |
|
sample_with_conv=False, |
|
quantizer_type="lookup", |
|
) |
|
config = Config(**config_dict) |
|
|
|
encoder = ConvEncoder(config) |
|
decoder = ConvDecoder(config) |
|
|
|
config.sample_with_conv = True |
|
encoder_conv_down = ConvEncoder(config) |
|
|
|
print("Encoder:\n{}".format(encoder)) |
|
print("Encoder downsampling with conv:\n{}".format(encoder_conv_down)) |
|
print("Decoder:\n{}".format(decoder)) |
|
|
|
x = torch.randn((1, 3, 256, 256)) |
|
x_enc = encoder(x) |
|
x_enc_down_with_conv = encoder_conv_down(x) |
|
x_dec = decoder(x_enc) |
|
x_dec_down_with_conv = decoder(x_enc_down_with_conv) |
|
|
|
print(f"Input shape: {x.shape}") |
|
print(f"Encoder output shape: {x_enc.shape}") |
|
print(f"Encoder with conv as down output shape: {x_enc_down_with_conv.shape}") |
|
print(f"Decoder output shape: {x_dec.shape}") |
|
print(f"Decoder with conv as down output shape: {x_dec_down_with_conv.shape}") |
|
|