import torch
import torch.nn as nn
import torch.nn.functional as F


class double_res_conv(nn.Module):
    """Two 3x3 conv + InstanceNorm layers followed by a single LeakyReLU.

    Note: despite the name, no residual skip connection is added, and the
    `bn` flag is currently unused; InstanceNorm2d is always applied.
    """

    def __init__(self, in_ch, out_ch, bn=False):
        super(double_res_conv, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
        )
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        return self.relu(x2)


class inconv(nn.Module):
    """Input block: a single double_res_conv applied to the raw input."""

    def __init__(self, in_ch, out_ch, bn=True):
        super(inconv, self).__init__()
        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Downsampling block: 2x average pooling followed by a double conv.

    (The attribute keeps the name `mpconv` even though average pooling,
    not max pooling, is used; renaming it would break saved state dicts.)
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(nn.AvgPool2d(2), double_res_conv(in_ch, out_ch, bn))

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    """Upsampling block: 2x upsample, pad to match the skip connection,
    concatenate along channels, then apply a double conv."""

    def __init__(self, in_ch, out_ch, bilinear=True, bn=True):
        super(up, self).__init__()
        self.bilinear = bilinear
        if not bilinear:
            # Learned upsampling; in_ch counts the concatenated channels,
            # so the transposed conv operates on in_ch // 2 of them.
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = double_res_conv(in_ch, out_ch, bn)

    def forward(self, x1, x2):
        if not self.bilinear:
            x1 = self.up(x1)
        else:
            x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True)

        # Pad x1 so its spatial size matches the skip tensor x2 (handles
        # odd input dimensions that round down during pooling).
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(nn.Module):
    """Output head: a 1x1 conv mapping features to the target channels."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1, padding=0)

    def forward(self, x):
        return self.conv(x)


class PostUNet(nn.Module):
    """U-Net-style encoder/decoder producing a single-channel output.

    `scale` divides every channel width, so scale=2 halves the model width.
    """

    def __init__(self, n_channels=1, scale=1):
        super(PostUNet, self).__init__()

        # Encoder: input block plus four 2x downsampling stages.
        self.inc = inconv(n_channels, 64 // scale)
        self.down1 = down(64 // scale, 128 // scale)
        self.down2 = down(128 // scale, 256 // scale)
        self.down3 = down(256 // scale, 512 // scale)
        self.down4 = down(512 // scale, 512 // scale)

        # Decoder: each `up` consumes the concatenation of the upsampled
        # features and the matching encoder skip connection.
        self.up1 = up(1024 // scale, 256 // scale)
        self.up2 = up(512 // scale, 128 // scale)
        self.up3 = up(256 // scale, 64 // scale)
        self.up4 = up(128 // scale, 32 // scale)

        self.reduce = outconv(32 // scale, 1)

    def forward(self, x0):
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.reduce(x)
        # Drop the singleton channel dimension: (N, 1, H, W) -> (N, H, W).
        x = x[:, 0, :, :]
        return x
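

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): instantiate
    # the network and run a dummy forward pass. The batch size, 1-channel
    # input, and 64x64 spatial size are illustrative assumptions; any
    # spatial size divisible by 16 (four 2x poolings) avoids the padding
    # path inside the `up` blocks.
    net = PostUNet(n_channels=1, scale=1)
    dummy = torch.randn(2, 1, 64, 64)
    out = net(dummy)
    print(out.shape)  # expected: torch.Size([2, 64, 64])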