"""ResNet-18-style encoder that fuses multi-scale features at 32x32 and
predicts a per-pixel Gaussian (mu, logvar), sampling an embedding from it
via the reparameterization trick."""

import torch
import torch.nn.functional as F
from torch import nn


class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv-BN stages with an additive shortcut."""

    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        # Identity shortcut unless the spatial size or channel count changes;
        # then a 1x1 projection matches x to the shape of the residual branch.
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet-18-style backbone whose multi-scale features are fused at 32x32
    and mapped to per-pixel mu and logvar heads."""

    def __init__(self, block=ResidualBlock):
        super(ResNet, self).__init__()
        self.inchannel = 64
        # Stem: stride 1 keeps full resolution; downsampling happens in the
        # residual stages below.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(block, 64, 2, stride=1)
        self.layer2 = self.make_layer(block, 128, 2, stride=2)
        self.layer3 = self.make_layer(block, 196, 2, stride=2)
        self.layer4 = self.make_layer(block, 256, 2, stride=2)
        # Fuse the concatenated multi-scale features (128 + 196 + 256 = 580
        # channels) back down toward a single-channel prediction.
        self.lastconv1 = nn.Sequential(
            nn.Conv2d(128 + 196 + 256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.lastconv2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        # Note: lastconv3 is defined but not used in forward(); the mu/logvar
        # heads below take its place.
        self.lastconv3 = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
        )
        # Resamples each stage's feature map to a common 32x32 grid. Despite
        # the name, this is an nn.Upsample and can also enlarge small maps.
        self.downsample32x32 = nn.Upsample(size=(32, 32), mode='bilinear', align_corners=False)

        # Heads predicting the mean and log-variance of a per-pixel Gaussian.
        self.mu_head = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.logvar_head = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )

    def make_layer(self, block, channels, num_blocks, stride):
        # The first block applies the requested stride (and any channel
        # change); the remaining blocks keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.inchannel, channels, s))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def _reparameterize(self, mu, logvar):
        # Reparameterization trick: sample z = mu + sigma * eps with
        # eps ~ N(0, I), so gradients flow through mu and logvar.
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def forward(self, x):
        x_input = x
        x = self.conv1(x)
        x_block1 = self.layer1(x)
        x_block2 = self.layer2(x_block1)
        x_block2_32 = self.downsample32x32(x_block2)
        x_block3 = self.layer3(x_block2)
        x_block3_32 = self.downsample32x32(x_block3)
        x_block4 = self.layer4(x_block3)
        x_block4_32 = self.downsample32x32(x_block4)

        # Concatenate the three scales, all resampled to 32x32, along channels.
        x_concat = torch.cat((x_block2_32, x_block3_32, x_block4_32), dim=1)

        x = self.lastconv1(x_concat)
        x = self.lastconv2(x)

        # Predict the per-pixel Gaussian parameters and drop the channel dim:
        # (B, 1, 32, 32) -> (B, 32, 32).
        mu = self.mu_head(x).squeeze(1)
        logvar = self.logvar_head(x).squeeze(1)
        embedding = self._reparameterize(mu, logvar)

        return mu, logvar, embedding, x_concat, x_block2, x_block3, x_block4, x_input


def ResNet18_u():
    """Factory for the uncertainty-aware ResNet-18 variant defined above."""
    return ResNet(ResidualBlock)
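
# A minimal smoke test (illustrative sketch, not part of the original module).
# The 256x256 input size is an assumption; other sizes work too, since the
# three stage outputs are resampled to a fixed 32x32 grid before they are
# concatenated (128 + 196 + 256 = 580 channels).
if __name__ == "__main__":
    model = ResNet18_u()
    dummy = torch.randn(2, 3, 256, 256)
    mu, logvar, embedding, x_concat, b2, b3, b4, x_in = model(dummy)
    print(mu.shape)         # torch.Size([2, 32, 32])
    print(logvar.shape)     # torch.Size([2, 32, 32])
    print(embedding.shape)  # torch.Size([2, 32, 32])
    print(x_concat.shape)   # torch.Size([2, 580, 32, 32])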