""" Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import torch import torch.nn as nn from einops import rearrange from .convnext import CvnxtBlock class ConvUpsample(nn.Module): def __init__( self, hidden_dim, num_layers: int = 2, expansion: int = 4, layer_scale: float = 1.0, kernel_size: int = 7, **kwargs ): super().__init__() self.convs = nn.ModuleList([]) for _ in range(num_layers): self.convs.append( CvnxtBlock( hidden_dim, kernel_size=kernel_size, expansion=expansion, layer_scale=layer_scale, ) ) self.up = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0), nn.UpsamplingBilinear2d(scale_factor=2), nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1), ) def forward(self, x: torch.Tensor): for conv in self.convs: x = conv(x) x = self.up(x) x = rearrange(x, "b c h w -> b (h w) c") return x class ConvUpsampleShuffle(nn.Module): def __init__( self, hidden_dim, expansion: int = 4, layer_scale: float = 1.0, **kwargs ): super().__init__() self.conv1 = CvnxtBlock( hidden_dim, expansion=expansion, layer_scale=layer_scale ) self.conv2 = CvnxtBlock( hidden_dim, expansion=expansion, layer_scale=layer_scale ) self.up = nn.Sequential( nn.PixelShuffle(2), nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1), ) def forward(self, x: torch.Tensor): x = self.conv1(x) x = self.conv2(x) x = self.up(x) x = rearrange(x, "b c h w -> b (h w) c") return x