"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import torch
import torch.nn as nn
from einops import rearrange
from .convnext import CvnxtBlock
class ConvUpsample(nn.Module):
    """Refines a feature map with ConvNeXt blocks, then upsamples 2x bilinearly."""

    def __init__(
        self,
        hidden_dim,
        num_layers: int = 2,
        expansion: int = 4,
        layer_scale: float = 1.0,
        kernel_size: int = 7,
        **kwargs
    ):
        super().__init__()
        # Stack of ConvNeXt-style blocks operating at a constant channel width.
        self.convs = nn.ModuleList([])
        for _ in range(num_layers):
            self.convs.append(
                CvnxtBlock(
                    hidden_dim,
                    kernel_size=kernel_size,
                    expansion=expansion,
                    layer_scale=layer_scale,
                )
            )
        # Halve the channels, double the resolution, then refine with a 3x3 conv.
        self.up = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor):
        for conv in self.convs:
            x = conv(x)
        x = self.up(x)
        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
        x = rearrange(x, "b c h w -> b (h w) c")
        return x

class ConvUpsampleShuffle(nn.Module):
    """Refines a feature map with ConvNeXt blocks, then upsamples 2x via pixel shuffle."""

    def __init__(
        self, hidden_dim, expansion: int = 4, layer_scale: float = 1.0, **kwargs
    ):
        super().__init__()
        self.conv1 = CvnxtBlock(
            hidden_dim, expansion=expansion, layer_scale=layer_scale
        )
        self.conv2 = CvnxtBlock(
            hidden_dim, expansion=expansion, layer_scale=layer_scale
        )
        # PixelShuffle trades channels for resolution (C -> C/4, spatial x2),
        # then a 3x3 conv maps the result to hidden_dim // 2 channels.
        self.up = nn.Sequential(
            nn.PixelShuffle(2),
            nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.up(x)
        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
        x = rearrange(x, "b c h w -> b (h w) c")
        return x
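

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It
# assumes CvnxtBlock is channel-preserving, as its use above implies: both
# heads take a (B, hidden_dim, H, W) feature map and return a flattened token
# sequence of shape (B, 4 * H * W, hidden_dim // 2) at twice the resolution.
if __name__ == "__main__":
    hidden_dim, H, W = 64, 16, 16
    feats = torch.randn(2, hidden_dim, H, W)

    up = ConvUpsample(hidden_dim, num_layers=2)
    tokens = up(feats)
    assert tokens.shape == (2, 4 * H * W, hidden_dim // 2)

    up_shuffle = ConvUpsampleShuffle(hidden_dim)
    tokens = up_shuffle(feats)
    assert tokens.shape == (2, 4 * H * W, hidden_dim // 2)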