import torch
import torch.nn as nn

from magvit2.config import VQConfig


def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


class ResBlock(nn.Module):
    """Residual block: (GroupNorm -> swish -> 3x3 conv) twice, plus a shortcut.

    The shortcut is the identity when ``in_filters == out_filters``. Otherwise
    it projects the input with a 3x3 conv (``use_conv_shortcut=True``) or a
    1x1 conv (default). Both convs in the main path are bias-free.

    GroupNorm is built with 32 groups, so ``in_filters`` and ``out_filters``
    must both be divisible by 32.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        use_conv_shortcut: use a 3x3 instead of a 1x1 projection shortcut when
            the channel counts differ.
    """

    def __init__(self, in_filters, out_filters, use_conv_shortcut=False) -> None:
        super().__init__()

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut

        self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)

        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)

        if in_filters != out_filters:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
            else:
                self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)

    def forward(self, x, **kwargs):
        """Apply the block. Extra keyword arguments are accepted and ignored."""
        residual = x

        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)

        if self.in_filters != self.out_filters:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return x + residual


class Encoder(nn.Module):
    """Convolutional encoder built from a ``VQConfig``.

    Reads ``config.in_channels``, ``config.base_channels``, ``config.ch_mult``,
    ``config.num_res_blocks`` and ``config.z_channels``.

    Layout:
    1. A 3x3 input conv to ``base_channels``.
    2. One level per ``ch_mult`` entry. Each level has ``num_res_blocks``
       ResBlocks at ``base_channels * ch_mult[level]`` channels. Every level
       except the last ends with a stride-2 3x3 conv.
    3. ``num_res_blocks`` mid ResBlocks.
    4. GroupNorm -> swish -> 1x1 conv to ``z_channels``.
    """

    def __init__(
        self,
        config: VQConfig,
    ):
        super().__init__()

        self.in_channels = config.in_channels
        self.z_channels = config.z_channels

        self.num_res_blocks = config.num_res_blocks
        self.num_blocks = len(config.ch_mult)

        self.conv_in = nn.Conv2d(
            config.in_channels, config.base_channels, kernel_size=(3, 3), padding=1, bias=False
        )

        ## construct the model
        self.down = nn.ModuleList()

        # Prepending 1 makes in_ch_mult[level] the multiplier of the previous level's
        # output (or of conv_in's output for level 0).
        in_ch_mult = (1,) + tuple(config.ch_mult)
        for i_level in range(self.num_blocks):
            block = nn.ModuleList()
            block_in = config.base_channels * in_ch_mult[i_level]
            block_out = config.base_channels * config.ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out

            down = nn.Module()
            down.block = block
            if i_level < self.num_blocks - 1:
                down.downsample = nn.Conv2d(block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1)

            self.down.append(down)

        ### mid
        self.mid_block = nn.ModuleList()
        for _ in range(self.num_res_blocks):
            self.mid_block.append(ResBlock(block_in, block_in))

        ### end
        self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6)
        self.conv_out = nn.Conv2d(block_out, config.z_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Encode ``x`` into a ``z_channels``-channel latent map."""
        ## down
        x = self.conv_in(x)
        for i_level, down in enumerate(self.down):
            for res_block in down.block:
                x = res_block(x)
            if i_level < self.num_blocks - 1:
                x = down.downsample(x)

        ## mid
        for res_block in self.mid_block:
            x = res_block(x)

        x = self.norm_out(x)
        x = swish(x)
        x = self.conv_out(x)

        return x


class Decoder(nn.Module):
    """Convolutional decoder built from a ``VQConfig``; mirrors ``Encoder``.

    Reads ``config.base_channels``, ``config.ch_mult``, ``config.num_res_blocks``,
    ``config.z_channels`` and ``config.out_channels``.

    Layout:
    1. A 3x3 input conv from ``z_channels`` to ``base_channels * ch_mult[-1]``.
    2. ``num_res_blocks`` mid ResBlocks.
    3. One level per ``ch_mult`` entry, processed from the last level down to
       level 0. Each level has ``num_res_blocks`` ResBlocks. Every level except
       level 0 ends with an ``Upsampler`` (3x3 conv + depth_to_space, 2x).
    4. GroupNorm -> swish -> 3x3 conv to ``out_channels``.
    """

    def __init__(self, config: VQConfig) -> None:
        super().__init__()

        self.base_channels = config.base_channels
        self.num_blocks = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks

        block_in = self.base_channels * config.ch_mult[self.num_blocks - 1]

        self.conv_in = nn.Conv2d(
            config.z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True
        )

        self.mid_block = nn.ModuleList()
        for _ in range(self.num_res_blocks):
            self.mid_block.append(ResBlock(block_in, block_in))

        self.up = nn.ModuleList()

        for i_level in reversed(range(self.num_blocks)):
            block = nn.ModuleList()
            block_out = self.base_channels * config.ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out

            up = nn.Module()
            up.block = block
            if i_level > 0:
                up.upsample = Upsampler(block_in)
            # Levels are built last-to-first; inserting at the front keeps
            # self.up[i] aligned with level i (as forward() indexes it).
            self.up.insert(0, up)

        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)

        self.conv_out = nn.Conv2d(block_in, config.out_channels, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode latent map ``z`` into an ``out_channels``-channel output."""
        z = self.conv_in(z)

        ## mid
        for res_block in self.mid_block:
            z = res_block(z)

        ## upsample
        for i_level in reversed(range(self.num_blocks)):
            up = self.up[i_level]
            for res_block in up.block:
                z = res_block(z)
            if i_level > 0:
                z = up.upsample(z)

        z = self.norm_out(z)
        z = swish(z)
        z = self.conv_out(z)

        return z


def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Depth-to-space in DCR (depth-column-row) mode.

    Moves channel blocks into spatial blocks. Input channel
    ``(i * block_size + j) * (C // block_size**2) + k`` becomes output channel
    ``k`` at row offset ``i`` and column offset ``j`` inside each
    ``block_size x block_size`` output tile.

    Args:
        x (torch.Tensor): channels-first tensor of shape ``(*, C, H, W)``. Any
            leading dims (including a zero-size batch) are preserved, and the
            input does not need to be contiguous.
        block_size (int): spatial upscale factor; must be >= 1.

    Returns:
        torch.Tensor: tensor of shape
        ``(*, C // block_size**2, H * block_size, W * block_size)``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims, ``block_size < 1``, or ``C``
            is not divisible by ``block_size**2``.
    """
    # check inputs
    if x.dim() < 3:
        raise ValueError(
            "Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
        )
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")
    c, h, w = x.shape[-3:]

    s = block_size**2
    if c % s != 0:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
        )

    outer_dims = x.shape[:-3]
    n = len(outer_dims)

    # Split C into (row offset, column offset, C // s), keeping the leading dims
    # explicit. Using -1 here would be ambiguous for a zero-size batch.
    x = x.reshape(*outer_dims, block_size, block_size, c // s, h, w)

    # (*, bs_row, bs_col, C', H, W) -> (*, C', H, bs_row, W, bs_col)
    x = x.permute(*range(n), n + 2, n + 3, n, n + 4, n + 1)

    # Merge each block offset into its spatial axis. reshape copies only when
    # the permuted layout requires it.
    return x.reshape(*outer_dims, c // s, h * block_size, w * block_size)


class Upsampler(nn.Module):
    """2x spatial upsampler: 3x3 conv to ``4 * dim_out`` channels, then depth_to_space.

    Args:
        dim: number of input channels.
        dim_out: number of output channels after upsampling. Defaults to
            ``dim``, which reproduces the original fixed behavior.
    """

    def __init__(
        self,
        dim,
        dim_out=None
    ):
        super().__init__()
        # dim_out used to be overwritten unconditionally and therefore ignored.
        if dim_out is None:
            dim_out = dim
        # depth_to_space(block_size=2) divides the channel count by 4, so the conv
        # must produce 4x the requested output channels.
        self.conv1 = nn.Conv2d(dim, dim_out * 4, (3, 3), padding=1)
        self.depth2space = depth_to_space

    def forward(self, x):
        """
        input_image: [B C H W]
        """
        out = self.conv1(x)
        out = self.depth2space(out, block_size=2)
        return out


# NOTE(review): this smoke test predates the VQConfig-based constructors
# (Encoder and Decoder now take a single ``config`` argument), so it will not
# run as written. Rebuild it around a VQConfig before re-enabling.
# if __name__ == "__main__":
#     x = torch.randn(size = (2, 3, 128, 128))
#     encoder = Encoder(ch=128, in_channels=3, num_res_blocks=2, z_channels=18, out_ch=3, resolution=128)
#     decoder = Decoder(out_ch=3, z_channels=18, num_res_blocks=2, ch=128, in_channels=3, resolution=128)
#     z = encoder(x)
#     out = decoder(z)