import torch
from torch import nn


class UNet(nn.Module):
    """U-Net style encoder-decoder with an 18-channel output head."""

    def __init__(self) -> None:
        super().__init__()
        # ImageNet normalization statistics; stored for callers, not applied in forward().
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Encoder: each stage is three 3x3 conv blocks followed by 2x2 max pooling.
        self.enc_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64)
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck at 1/16 of the input resolution.
        self.bottleneck_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024)
        )

        # Decoder: bilinear upsampling + 3x3 conv, then concatenation with the
        # matching encoder feature map (skip connection) before each conv stage.
        self.upsample0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
        )
        self.dec_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )

        self.upsample1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
        )
        self.dec_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )

        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
        )
        self.dec_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )

        self.upsample3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
        )
        # Final stage ends with a 1x1 conv producing the 18-channel output map.
        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=18, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder: keep each stage's output for the skip connections.
        e0 = self.enc_conv0(x)
        e1 = self.pool0(e0)
        e1 = self.enc_conv1(e1)
        e2 = self.pool1(e1)
        e2 = self.enc_conv2(e2)
        e3 = self.pool2(e2)
        e3 = self.enc_conv3(e3)

        # Bottleneck at 1/16 of the input resolution.
        b = self.pool3(e3)
        b = self.bottleneck_conv(b)

        # Decoder: upsample, concatenate the matching encoder feature map, then convolve.
        d0 = self.upsample0(b)
        d0 = torch.cat([d0, e3], dim=1)
        d0 = self.dec_conv0(d0)

        d1 = self.upsample1(d0)
        d1 = torch.cat([d1, e2], dim=1)
        d1 = self.dec_conv1(d1)

        d2 = self.upsample2(d1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec_conv2(d2)

        d3 = self.upsample3(d2)
        d3 = torch.cat([d3, e0], dim=1)
        d3 = self.dec_conv3(d3)
        return d3
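

# Minimal smoke test (a sketch, not part of the original module). The 256x256 input size
# is an assumption: any spatial size divisible by 16 works, since the four pooling stages
# halve the resolution and the four upsampling stages restore it.
if __name__ == "__main__":
    model = UNet().eval()
    dummy = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 18, 256, 256])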