# Provenance (appears to be version-control metadata pasted into the file):
#   wuxulong19950206
#   First model version
#   14d1720
import torch
import torch.nn as nn
import torch.nn.functional as F
class double_res_conv(nn.Module):
    """Two stacked 3x3 convolution + InstanceNorm2d stages, then LeakyReLU(0.1).

    Spatial size is preserved (kernel 3, padding 1). Despite the name, no
    residual/skip connection is added. There is also no activation between
    the two stages; only the final output passes through LeakyReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (used by both convolutions).
        bn: accepted for signature compatibility but ignored. Normalization
            is always InstanceNorm2d with default settings.
    """

    def __init__(self, in_ch, out_ch, bn=False):
        super().__init__()
        # NOTE(review): `bn` has no effect here; the callers' `bn=True` is ignored.
        first_conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv1 = nn.Sequential(first_conv, nn.InstanceNorm2d(out_ch))
        second_conv = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Sequential(second_conv, nn.InstanceNorm2d(out_ch))
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply conv1 -> conv2 -> LeakyReLU and return the activations."""
        return self.relu(self.conv2(self.conv1(x)))
class inconv(nn.Module):
    """Input stem: one double_res_conv mapping ``in_ch`` to ``out_ch`` channels.

    Spatial size is unchanged. ``bn`` is forwarded to double_res_conv.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x):
        """Return ``self.conv(x)``."""
        return self.conv(x)
class down(nn.Module):
    """Encoder step: 2x2 average pooling followed by a double_res_conv.

    The attribute is named ``mpconv`` (the state_dict keys depend on it),
    but the pooling is AvgPool2d, not max pooling.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        pool = nn.AvgPool2d(2)
        block = double_res_conv(in_ch, out_ch, bn)
        self.mpconv = nn.Sequential(pool, block)

    def forward(self, x):
        """Downsample ``x`` by 2 (floor) and convolve it."""
        return self.mpconv(x)
class up(nn.Module):
    """Decoder step: upsample ``x1`` 2x, align it to the skip ``x2``, fuse, convolve.

    Args:
        in_ch: channel count of ``cat([x2, x1])``, i.e. the skip's channels
            plus the upsampled input's channels.
        out_ch: output channels of the fusing double_res_conv.
        bilinear: if True, upsample with bilinear interpolation
            (``align_corners=True``). Otherwise use a learned
            ConvTranspose2d, which requires ``x1`` to have ``in_ch // 2``
            channels.
        bn: forwarded to double_res_conv.
    """

    def __init__(self, in_ch, out_ch, bilinear=True, bn=True):
        super().__init__()
        self.bilinear = bilinear
        if not bilinear:
            half = in_ch // 2
            self.up = nn.ConvTranspose2d(half, half, 2, stride=2)
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x1, x2):
        """Upsample ``x1``, zero-pad it to ``x2``'s H/W, concat ``[x2, x1]`` on dim 1, convolve."""
        if self.bilinear:
            x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True)
        else:
            x1 = self.up(x1)
        # Odd encoder sizes make the upsampled map smaller than the skip. Split
        # the difference, putting the extra pixel on the right/bottom.
        gap_h = x2.size()[2] - x1.size()[2]
        gap_w = x2.size()[3] - x1.size()[3]
        left, top = gap_w // 2, gap_h // 2
        x1 = F.pad(x1, (left, gap_w - left, top, gap_h - top))
        fused = torch.cat([x2, x1], dim=1)
        return self.conv(fused)
class outconv(nn.Module):
    """Pointwise (1x1, no padding) convolution projecting ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0)

    def forward(self, x):
        """Return the 1x1 convolution of ``x``."""
        return self.conv(x)
class PostUNet(nn.Module):
    """Four-level U-Net mapping an ``n_channels`` image to a single-channel map.

    Encoder widths are ``64, 128, 256, 512, 512`` floor-divided by
    ``scale``. Each decoder stage concatenates its upsampled input with the
    encoder feature map of the same resolution. A final 1x1 convolution
    reduces to one channel, and that channel axis is then dropped.

    Args:
        n_channels: number of input channels.
        scale: divisor applied to every layer width (1 = full size).

    Raises:
        ValueError: if ``scale`` < 1, or if it is so large that the narrowest
            layer (``32 // scale``) would have no channels.

    Shape:
        input ``(N, n_channels, H, W)`` -> output ``(N, H, W)``.
    """

    def __init__(self, n_channels=1, scale=1):
        super(PostUNet, self).__init__()
        if scale < 1 or 32 // scale < 1:
            raise ValueError(
                "scale must be between 1 and 32 so every layer keeps at least "
                "one channel, got {!r}".format(scale))
        c1, c2, c3, c4 = (width // scale for width in (64, 128, 256, 512))
        self.inc = inconv(n_channels, c1)
        self.down1 = down(c1, c2)
        self.down2 = down(c2, c3)
        self.down3 = down(c3, c4)
        self.down4 = down(c4, c4)
        # Each decoder stage receives cat([skip, upsampled]), so its input
        # width is the SUM of those two widths. Writing e.g. `1024 // scale`
        # only matches that sum when `scale` divides the widths evenly
        # (scale=3 would give 85 vs. the actual 84 at up3).
        self.up1 = up(c4 + c4, c3)  # x5 (c4) + x4 (c4)
        self.up2 = up(c3 + c3, c2)  # up1 out (c3) + x3 (c3)
        self.up3 = up(c2 + c2, c1)  # up2 out (c2) + x2 (c2)
        self.up4 = up(c1 + c1, 32 // scale)  # up3 out (c1) + x1 (c1)
        self.reduce = outconv(32 // scale, 1)

    def forward(self, x0):
        """Encode ``x0`` four times, decode with skip connections, return ``(N, H, W)``."""
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.reduce(x)
        # reduce() emits exactly one channel; drop that axis.
        x = x[:, 0, :, :]
        return x