|
import torch |
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Recon(nn.Module):
    """Residual convolutional block.

    A three-convolution main path (3x3, padding 1, ReLU between convs)
    is added to a shortcut branch and passed through a final ReLU. When
    the channel counts differ, the shortcut is a bias-free 1x1 conv
    projection; otherwise it is the identity.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Main path: conv-relu-conv-relu-conv, all 3x3 with padding 1
        # so the spatial size is preserved.
        layers = [nn.Conv2d(ch_in, ch_out, 3, padding=1), nn.ReLU()]
        layers += [nn.Conv2d(ch_out, ch_out, 3, padding=1), nn.ReLU()]
        layers.append(nn.Conv2d(ch_out, ch_out, 3, padding=1))
        self.long = nn.Sequential(*layers)
        # Shortcut: project with a 1x1 conv only when channels change.
        self.short = (
            nn.Identity()
            if ch_in == ch_out
            else nn.Conv2d(ch_in, ch_out, 1, bias=False)
        )
        self.fuse = nn.ReLU()

    def forward(self, x):
        # Sum main path and shortcut, then apply the fusing ReLU.
        shortcut = self.short(x)
        return self.fuse(self.long(x) + shortcut)
|
|
|
class TeaDecoder(nn.Module):
    """Decode a ``ch_in``-channel feature map into a 3-channel image.

    The input is first squashed to (-1, 1) with ``tanh(x / temperature)``,
    run through an entry conv, a middle stack of residual blocks with two
    2x nearest-neighbor upsampling stages (4x spatial growth overall),
    and a final residual block plus 3-channel output conv.

    Args:
        ch_in: number of channels of the input feature map.
        temperature: scale applied inside the tanh squash. The original
            code hard-coded ``x / 1``; ``temperature=1.0`` (the default)
            reproduces that behavior exactly, while larger values soften
            the squash.
    """

    def __init__(self, ch_in, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.block_in = nn.Sequential(
            nn.Conv2d(ch_in, 64, 3, padding=1),
            nn.ReLU(),
        )
        # Two stages of: 3 residual blocks -> 2x upsample -> 3x3 conv.
        # The convs after upsampling are bias-free.
        self.middle = nn.Sequential(
            *[Recon(64, 64) for _ in range(3)],
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            *[Recon(64, 64) for _ in range(3)],
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
        )
        self.block_out = nn.Sequential(
            Recon(64, 64),
            nn.Conv2d(64, 3, 3, padding=1),
        )

    def forward(self, x):
        # Bound the (possibly unnormalized) input to (-1, 1) before decoding.
        clamped = torch.tanh(x / self.temperature)
        cooked = self.middle(self.block_in(clamped))
        return self.block_out(cooked)
|
|