import torch
import torch.nn as nn
from torchinfo import summary
import torchvision

# NOTE(review): this module-level network is not used by anything in this file
# (the encoder below builds its own ResNet-50 in ``__init__``), yet it loads
# pretrained weights at import time. It is kept only so that ``resnet`` stays
# importable from this module; consider removing it once callers are checked.
resnet = torchvision.models.resnet.resnet50(pretrained=True)


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> (optional) ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        padding: Padding passed to ``nn.Conv2d``.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d``.
        with_nonlinearity: When False the ReLU is skipped in ``forward``.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3,
                 stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding,
                              kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        # The ReLU module is always created; whether it is applied is decided
        # per call by ``with_nonlinearity``.
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        """Apply convolution and batch norm, then ReLU if enabled."""
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """Middle layer of the UNet: two stacked ``ConvBlock`` modules.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both conv blocks.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        )

    def forward(self, x):
        """Run the input through both conv blocks."""
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """One decoder step: Upsample -> concat skip -> ConvBlock -> ConvBlock.

    Args:
        in_channels: Channels entering the first ``ConvBlock``, i.e. the
            channel count *after* concatenating the upsampled map with the
            skip connection.
        out_channels: Channels produced by both conv blocks.
        up_conv_in_channels: Channels of the tensor fed to the upsampling
            layer. Defaults to ``in_channels``.
        up_conv_out_channels: Channels produced by the upsampling layer.
            Defaults to ``out_channels``.
        upsampling_method: ``"conv_transpose"`` (stride-2 transposed
            convolution) or ``"bilinear"`` (x2 bilinear upsample followed by a
            1x1 convolution).

    Raises:
        ValueError: If ``upsampling_method`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None,
                 up_conv_out_channels=None, upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(
                up_conv_in_channels, up_conv_out_channels,
                kernel_size=2, stride=2,
            )
        elif upsampling_method == "bilinear":
            # The 1x1 conv must use the up-conv channel sizes: the tensor it
            # receives is the previous decoder output, not the concatenated
            # map that ``in_channels`` describes.
            self.upsample = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels,
                          kernel_size=1, stride=1),
            )
        else:
            # Fail at construction time instead of leaving ``self.upsample``
            # undefined and crashing later inside ``forward``.
            raise ValueError(
                f"Unsupported upsampling_method: {upsampling_method!r}; "
                "expected 'conv_transpose' or 'bilinear'"
            )

        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """Upsample ``up_x``, concatenate it with ``down_x`` and refine.

        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        # Concatenate along the channel dimension (dim 1).
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    """UNet whose encoder is a pretrained torchvision ResNet-50.

    The network output is passed through a sigmoid and scaled by
    ``max_depth``; ``forward`` returns it in a dict under the key ``'pred_d'``.

    Args:
        max_depth: Scale applied to the sigmoid output.
        n_classes: Number of output channels of the final 1x1 convolution.
    """

    # Number of resolution levels; used to derive the skip-connection keys.
    DEPTH = 6

    def __init__(self, max_depth, n_classes=1):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        resnet_children = list(resnet.children())

        # First three ResNet children form the stem; the fourth is kept
        # separate so the pre-pool activation can be stored as a skip.
        self.input_block = nn.Sequential(*resnet_children[:3])
        self.input_pool = resnet_children[3]

        # Every ``nn.Sequential`` child of the ResNet is an encoder stage.
        self.down_blocks = nn.ModuleList(
            [child for child in resnet_children
             if isinstance(child, nn.Sequential)]
        )

        self.bridge = Bridge(2048, 2048)

        # The last two blocks concatenate with the stem output (64 channels)
        # and the raw input (3 channels), so their conv input size differs
        # from the channel count fed to the upsampling layer.
        self.up_blocks = nn.ModuleList([
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256,
                                       up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128,
                                       up_conv_out_channels=64),
        ])

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)
        self.max_depth = max_depth

    def forward(self, x, with_output_feature_map=False):
        """Run the encoder-decoder and return the scaled prediction.

        Args:
            x: Input batch; it is also used as the outermost skip connection.
            with_output_feature_map: Accepted for API compatibility but
                currently ignored.

        Returns:
            ``{'pred_d': tensor}`` where the tensor is
            ``sigmoid(logits) * max_depth``.
        """
        last_level = self.DEPTH - 1

        # Skip connections, keyed "layer_<level>"; level 0 is the raw input.
        pre_pools = {"layer_0": x}
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)

        for level, block in enumerate(self.down_blocks, 2):
            x = block(x)
            # The deepest stage feeds the bridge directly and is not a skip.
            if level != last_level:
                pre_pools[f"layer_{level}"] = x

        x = self.bridge(x)

        # Walk back up, pairing each up block with the matching skip.
        for step, block in enumerate(self.up_blocks, 1):
            x = block(x, pre_pools[f"layer_{last_level - step}"])

        x = self.out(x)
        out_depth = torch.sigmoid(x) * self.max_depth
        return {'pred_d': out_depth}


if __name__ == "__main__":
    # Fall back to CPU so the summary also works on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNetWithResnet50Encoder(max_depth=10).to(device)
    summary(model, input_size=(1, 3, 256, 256))