Spaces:
Sleeping
Sleeping
File size: 3,149 Bytes
14d1720 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class double_res_conv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by instance normalisation.

    A single LeakyReLU (negative slope 0.1) is applied after the second
    stage only; there is no activation between the two stages and, despite
    the class name, no skip connection is added.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by both convolutions.
        bn: Accepted but not used by this block.
    """

    def __init__(self, in_ch, out_ch, bn=False):
        super().__init__()
        # A 3x3 kernel with padding=1 (default stride 1) keeps the spatial
        # size of the feature map unchanged.
        first_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv1 = nn.Sequential(first_conv, nn.InstanceNorm2d(out_ch))

        second_conv = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Sequential(second_conv, nn.InstanceNorm2d(out_ch))

        self.relu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Run conv1 -> conv2 -> LeakyReLU on ``x`` and return the result."""
        return self.relu(self.conv2(self.conv1(x)))
class inconv(nn.Module):
    """Input stage of the network: a single ``double_res_conv`` block.

    Args:
        in_ch: Number of channels of the incoming tensor.
        out_ch: Number of channels produced by the block.
        bn: Forwarded unchanged to ``double_res_conv``.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x):
        """Apply the wrapped convolution block to ``x``."""
        return self.conv(x)
class down(nn.Module):
    """Encoder stage: 2x average pooling followed by ``double_res_conv``.

    The attribute is called ``mpconv``, but the pooling layer it contains
    is ``nn.AvgPool2d(2)``, i.e. average pooling.

    Args:
        in_ch: Number of channels of the incoming tensor.
        out_ch: Number of channels produced by the convolution block.
        bn: Forwarded unchanged to ``double_res_conv``.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        pool = nn.AvgPool2d(2)
        conv_block = double_res_conv(in_ch, out_ch, bn)
        # Kept as a two-element Sequential: index 0 is the pooling layer,
        # index 1 is the convolution block.
        self.mpconv = nn.Sequential(pool, conv_block)

    def forward(self, x):
        """Pool ``x`` and pass it through the convolution block."""
        return self.mpconv(x)
class up(nn.Module):
    """Decoder stage: upsample, align with a skip tensor, concatenate, convolve.

    Args:
        in_ch: Channel count of the concatenated tensor fed to the
            convolution block (skip tensor plus upsampled tensor).
        out_ch: Number of channels produced by the convolution block.
        bilinear: If true, upsample by bilinear interpolation; otherwise a
            learned ``nn.ConvTranspose2d`` over ``in_ch // 2`` channels is
            created and used instead.
        bn: Forwarded unchanged to ``double_res_conv``.
    """

    def __init__(self, in_ch, out_ch, bilinear=True, bn=True):
        super().__init__()
        self.bilinear = bilinear
        if not bilinear:
            # The transposed convolution only exists in the learned mode.
            half_ch = in_ch // 2
            self.up = nn.ConvTranspose2d(half_ch, half_ch, 2, stride=2)
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x1, x2):
        """Upsample ``x1`` by 2, fit it to ``x2`` and fuse the two tensors.

        Args:
            x1: Tensor to be upsampled.
            x2: Skip tensor whose height and width (dims 2 and 3) ``x1`` is
                padded to match.

        Returns:
            Output of the convolution block applied to ``cat([x2, x1], dim=1)``.
        """
        if self.bilinear:
            x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True)
        else:
            x1 = self.up(x1)

        # Height/width difference between the skip tensor and the upsampled
        # tensor; it is split between the two sides of each axis.
        gap_h = x2.size(2) - x1.size(2)
        gap_w = x2.size(3) - x1.size(3)
        left, top = gap_w // 2, gap_h // 2
        right, bottom = gap_w - left, gap_h - top
        # F.pad takes the last dimension first: (left, right, top, bottom).
        x1 = F.pad(x1, (left, right, top, bottom))

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
class outconv(nn.Module):
    """Output projection: a 1x1 convolution without padding.

    Args:
        in_ch: Number of channels of the incoming tensor.
        out_ch: Number of channels to produce.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)
class PostUNet(nn.Module):
    """U-Net style encoder/decoder producing a single-channel output map.

    The encoder is one ``inconv`` stage followed by four ``down`` stages;
    the decoder is four ``up`` stages, each fused with the encoder output
    of matching depth, followed by a 1x1 ``outconv`` to one channel. The
    channel axis of that final tensor is then indexed away.

    Args:
        n_channels: Number of channels of the input tensor.
        scale: Divisor applied (with floor division) to the base channel
            widths 64, 128, 256, 512 and 32. Larger values give a narrower
            network.

    Raises:
        ValueError: If ``scale`` is smaller than 1, or so large that the
            narrowest layer (``32 // scale``) would have no channels.
    """

    def __init__(self, n_channels=1, scale=1):
        super().__init__()
        if scale < 1 or 32 // scale < 1:
            raise ValueError(
                f"scale must satisfy 1 <= scale <= 32, got {scale!r}"
            )

        # Encoder widths after applying the divisor.
        c1, c2, c3, c4 = (width // scale for width in (64, 128, 256, 512))
        c_last = 32 // scale

        self.inc = inconv(n_channels, c1)
        self.down1 = down(c1, c2)
        self.down2 = down(c2, c3)
        self.down3 = down(c3, c4)
        self.down4 = down(c4, c4)

        # Each decoder stage receives the concatenation of a skip tensor and
        # an upsampled tensor of equal width, so its input width is derived
        # from the actual encoder widths. Computing e.g. ``1024 // scale``
        # independently disagrees with ``2 * (512 // scale)`` whenever
        # ``scale`` does not divide the base widths evenly, which made the
        # forward pass fail with a channel mismatch.
        self.up1 = up(c4 + c4, c3)
        self.up2 = up(c3 + c3, c2)
        self.up3 = up(c2 + c2, c1)
        self.up4 = up(c1 + c1, c_last)
        self.reduce = outconv(c_last, 1)

    def forward(self, x0):
        """Run the encoder/decoder on ``x0``.

        Args:
            x0: Input tensor with ``n_channels`` channels in dim 1.

        Returns:
            The network output with the single output channel (dim 1)
            indexed away.
        """
        # Encoder: keep every intermediate result for the skip connections.
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: fuse with encoder outputs from deepest to shallowest.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.reduce(x)
        # ``reduce`` emits exactly one channel; select it to drop that axis.
        return x[:, 0, :, :]
|