# U-Net architectures: a customized UNet and a pix2pix-style UnetGenerator.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F

from detection_models.antialiasing import Downsample
from detection_models.sync_batchnorm import DataParallelWithCallback
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        Using the default arguments will yield the exact version used
        in the original paper.

        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            depth (int): depth of the network (number of down/up stages)
            conv_num (int): number of convolutions per UNetConvBlock
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
            with_tanh (bool): append a Tanh after the final 3x3 projection
            sync_bn (bool): retained for backward compatibility only; see the
                            note at the end of this constructor
            antialiasing (bool): downsample with a blur-pool (Downsample)
                                 instead of a strided 4x4 convolution
        """
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1
        prev_channels = in_channels

        # Stem: 7x7 conv with reflection padding keeps the spatial size.
        self.first = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 2 ** wf, kernel_size=7),
            nn.LeakyReLU(0.2, True),
        )
        prev_channels = 2 ** wf

        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            # NOTE(review): `depth > 0` is always true inside this loop, so the
            # condition is effectively `if antialiasing:` — kept as-is.
            if antialiasing and depth > 0:
                # 3x3 stride-1 conv followed by an anti-aliased blur downsample.
                self.down_sample.append(
                    nn.Sequential(
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                        nn.BatchNorm2d(prev_channels),
                        nn.LeakyReLU(0.2, True),
                        Downsample(channels=prev_channels, stride=2),
                    )
                )
            else:
                # Plain strided 4x4 convolution.
                self.down_sample.append(
                    nn.Sequential(
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                        nn.BatchNorm2d(prev_channels),
                        nn.LeakyReLU(0.2, True),
                    )
                )
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # Final projection: 1-pixel reflection pad + 3x3 conv (+ optional Tanh).
        if with_tanh:
            self.last = nn.Sequential(
                nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()
            )
        else:
            self.last = nn.Sequential(
                nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)
            )

        # BUG FIX: the original code did `self = DataParallelWithCallback(self)`
        # when `sync_bn` was set. Rebinding the local name `self` has no effect
        # on the module being constructed — the wrapper was built and
        # immediately discarded. Wrap at the call site instead:
        #     model = DataParallelWithCallback(UNet(...))
        # The `sync_bn` parameter is kept so existing callers do not break.

    def forward(self, x):
        """Encoder pass, then decoder pass with center-cropped skip connections."""
        x = self.first(x)
        blocks = []
        for i, down_block in enumerate(self.down_path):
            blocks.append(x)  # save pre-downsample features for the skip path
            x = self.down_sample[i](x)
            x = down_block(x)
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])  # deepest skip first
        return self.last(x)
class UNetConvBlock(nn.Module):
    """A stack of `conv_num` [ReflectionPad -> 3x3 Conv -> (BatchNorm) -> LeakyReLU] units.

    When `padding` is truthy the reflection pad is 1 pixel and spatial size is
    preserved; otherwise each 3x3 convolution shrinks each side by 2.
    """

    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        layers = []
        channels = in_size
        for _ in range(conv_num):
            layers.append(nn.ReflectionPad2d(padding=int(padding)))
            layers.append(nn.Conv2d(channels, out_size, kernel_size=3, padding=0))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2, True))
            channels = out_size  # subsequent convs operate on out_size channels
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class UNetUpBlock(nn.Module):
    """Upsample 2x, center-crop the skip tensor to match, concatenate, convolve."""

    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            # Learned 2x upsampling via a transposed convolution.
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            # Fixed bilinear 2x upsampling followed by a 3x3 projection.
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )
        # After concatenation channels = out_size + out_size == in_size.
        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return a centered spatial crop of `layer` of size (height, width) = target_size."""
        _, _, height, width = layer.size()
        top = (height - target_size[0]) // 2
        left = (width - target_size[1]) // 2
        return layer[:, :, top : top + target_size[0], left : left + target_size[1]]

    def forward(self, x, bridge):
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        return self.conv_block(torch.cat([upsampled, skip], 1))
class UnetGenerator(nn.Module):
    """Create a Unet-based generator, built recursively from UnetSkipConnectionBlock."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
        """Construct a Unet generator.

        Parameters:
            input_nc (int)   -- the number of channels in input images
            output_nc (int)  -- the number of channels in output images
            num_downs (int)  -- the number of downsamplings in UNet. For example,
                                if num_downs == 7, a 128x128 image becomes 1x1
                                at the bottleneck
            ngf (int)        -- the number of filters in the last conv layer
            norm_type (str)  -- "BN" (BatchNorm) or "IN" (InstanceNorm)
            use_dropout (bool) -- apply dropout in the intermediate blocks

        The U-Net is assembled from the innermost layer outward.
        """
        super().__init__()
        if norm_type == "BN":
            norm_layer = nn.BatchNorm2d
        elif norm_type == "IN":
            norm_layer = nn.InstanceNorm2d
        else:
            raise NameError("Unknown norm layer")

        # Innermost (bottleneck) block.
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
        )
        # Intermediate blocks at ngf * 8 filters.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # Taper the filter count ngf*8 -> ngf*4 -> ngf*2 -> ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(
                ngf * mult, ngf * mult * 2, input_nc=None, submodule=block, norm_layer=norm_layer
            )
        # Outermost block maps to output_nc.
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer
        )

    def forward(self, input):
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

        -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int)  -- the number of filters in the outer conv layer
            inner_nc (int)  -- the number of filters in the inner conv layer
            input_nc (int)  -- the number of channels in input images/features
                               (defaults to outer_nc when None)
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool) -- if this module is the outermost module
            innermost (bool) -- if this module is the innermost module
            norm_layer       -- normalization layer class
            use_dropout (bool) -- if True, add a Dropout(0.5) to intermediate blocks
        """
        super().__init__()
        self.outermost = outermost
        # InstanceNorm keeps no affine shift by default, so the convs keep a bias.
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        down_conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        down_act = nn.LeakyReLU(0.2, True)
        up_act = nn.LeakyReLU(0.2, True)

        if outermost:
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [down_act, down_conv, up_act, up_conv, norm_layer(outer_nc)]
        else:
            up_conv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            layers = [
                down_act,
                down_conv,
                norm_layer(inner_nc),
                submodule,
                up_act,
                up_conv,
                norm_layer(outer_nc),
            ]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Skip connection: concatenate the input with the submodule output.
        return torch.cat([x, self.model(x)], 1)
# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # BUG FIX: the original script instantiated `UNet_two_decoders`, which is
    # not defined anywhere in this file (NameError). Exercise the UNet defined
    # above instead, with equivalent hyperparameters.
    model = UNet(
        in_channels=3,
        out_channels=3,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=False,
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    # torchsummary is an optional third-party tool; skip gracefully if absent.
    try:
        from torchsummary import summary

        print("customized unet:")
        summary(model, (3, 256, 256))
        print("cyclegan unet:")
        summary(model_pix2pix, (3, 256, 256))
    except ImportError:
        print("torchsummary not installed; skipping model summaries")

    # BUG FIX: `make_dot` was never imported (NameError) and `.cuda()` crashed
    # on CPU-only machines; import lazily and reuse `device` instead.
    try:
        from torchviz import make_dot

        x = torch.zeros(1, 3, 256, 256, device=device, requires_grad=True)
        g = make_dot(model(x))
        g.render("models/Digraph.gv", view=False)
    except ImportError:
        print("torchviz not installed; skipping graph rendering")