import torch
import torch.nn as nn
import torch.nn.functional as F


class double_res_conv(nn.Module):
    """Two 3x3 convolutions, each followed by InstanceNorm, then a LeakyReLU.

    Spatial size is preserved (``padding=1``); channels go ``in_ch -> out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        bn: Accepted for interface compatibility but unused; normalization is
            always ``nn.InstanceNorm2d``.

    NOTE(review): despite the name there is no residual (skip) connection,
    and there is no activation between ``conv1`` and ``conv2`` -- only one
    LeakyReLU after the second normalization. Confirm this is intended.
    """

    def __init__(self, in_ch, out_ch, bn=False):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
        )
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply conv1 -> conv2 -> LeakyReLU(0.1)."""
        return self.relu(self.conv2(self.conv1(x)))


class inconv(nn.Module):
    """Input stem of the U-Net: a single ``double_res_conv`` at full resolution."""

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Encoder step: 2x average-pool downsampling, then ``double_res_conv``."""

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.mpconv = nn.Sequential(nn.AvgPool2d(2), double_res_conv(in_ch, out_ch, bn))

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    """Decoder step: upsample ``x1`` 2x, match it to ``x2``, concatenate, convolve.

    Args:
        in_ch: Channels of the concatenation ``cat([x2, x1])``.
        out_ch: Output channels.
        bilinear: If True, upsample with bilinear interpolation
            (``align_corners=True``); otherwise with a learned
            ``ConvTranspose2d`` that expects ``x1`` to carry ``in_ch // 2``
            channels.
        bn: Forwarded to ``double_res_conv`` (where it is unused).
    """

    def __init__(self, in_ch, out_ch, bilinear=True, bn=True):
        super().__init__()
        self.bilinear = bilinear
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, and fuse the two."""
        if self.bilinear:
            x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True)
        else:
            x1 = self.up(x1)
        # AvgPool2d(2) floors odd sizes, so the upsampled x1 can be one pixel
        # short of the skip tensor x2 per axis; pad x1 (split as evenly as
        # possible between the two sides) so the shapes line up for cat.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, (diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2))
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(nn.Module):
    """1x1 convolution projecting ``in_ch`` channels to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1, padding=0)

    def forward(self, x):
        return self.conv(x)


class PostUNet(nn.Module):
    """Four-level U-Net mapping an ``(N, n_channels, H, W)`` input to ``(N, H, W)``.

    Encoder widths are ``64, 128, 256, 512, 512`` divided (floor) by ``scale``.
    The output has the same height/width as the input.

    Args:
        n_channels: Channels of the input tensor.
        scale: Integer divisor applied to every layer width. Must satisfy
            ``1 <= scale <= 32`` so the narrowest layer (``32 // scale``) keeps
            at least one channel.

    Raises:
        ValueError: if ``scale`` is outside ``[1, 32]``.

    Note: ``nn.InstanceNorm2d`` here normalizes with per-input statistics, which
    PyTorch rejects for a single spatial element, so each spatial dim should be
    at least 32 to survive the four 2x downsamplings.
    """

    def __init__(self, n_channels=1, scale=1):
        super().__init__()
        if scale < 1 or 32 // scale < 1:
            raise ValueError(f"scale must be in [1, 32], got {scale!r}")
        c1 = 64 // scale
        c2 = 128 // scale
        c3 = 256 // scale
        c4 = 512 // scale
        c_out = 32 // scale

        self.inc = inconv(n_channels, c1)
        self.down1 = down(c1, c2)
        self.down2 = down(c2, c3)
        self.down3 = down(c3, c4)
        self.down4 = down(c4, c4)
        # Each decoder block sees cat([skip, upsampled]); size it from the real
        # widths. The old `1024 // scale`-style constants only matched when
        # `scale` divided the widths evenly (e.g. scale=3 broke up1, scale=5
        # broke up3); for every previously-working scale these are identical.
        self.up1 = up(c4 + c4, c3)
        self.up2 = up(c3 + c3, c2)
        self.up3 = up(c2 + c2, c1)
        self.up4 = up(c1 + c1, c_out)
        self.reduce = outconv(c_out, 1)

    def forward(self, x0):
        """Run the U-Net and return channel 0 of the 1-channel output, (N, H, W)."""
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.reduce(x)
        # Drop the singleton channel axis.
        return x[:, 0, :, :]