import torch
import torch.nn as nn

import utils.resnet as resnet
from utils.TransformerEncoder import Encoder

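# Transformer encoder configurations for the three decoder levels; the hidden
# sizes match the channel widths the decoder expects. cfg1 and cfg2 are
# currently identical but kept as separate dicts so each level can be tuned
# independently.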
cfg1 = {
    "hidden_size": 768,
    "mlp_dim": 768 * 4,
    "num_heads": 12,
    "num_layers": 2,
    "attention_dropout_rate": 0.0,
    "dropout_rate": 0.0,
}

cfg2 = {
    "hidden_size": 768,
    "mlp_dim": 768 * 4,
    "num_heads": 12,
    "num_layers": 2,
    "attention_dropout_rate": 0.0,
    "dropout_rate": 0.0,
}

cfg3 = {
    "hidden_size": 512,
    "mlp_dim": 512 * 4,
    "num_heads": 8,
    "num_layers": 2,
    "attention_dropout_rate": 0.0,
    "dropout_rate": 0.0,
}


class TranSalNet(nn.Module):
    """Saliency prediction network: a CNN encoder followed by a transformer-assisted decoder."""

    def __init__(self):
        super(TranSalNet, self).__init__()
        self.encoder = _Encoder()
        self.decoder = _Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class _Encoder(nn.Module):
    """ResNet-50 backbone returning multi-scale features from its last three stages."""

    def __init__(self):
        super(_Encoder, self).__init__()
        base_model = resnet.resnet50(pretrained=True)
        base_layers = list(base_model.children())[:8]
        self.encoder = nn.ModuleList(base_layers).eval()

    def forward(self, x):
        # Children 5, 6 and 7 are the conv3_x, conv4_x and conv5_x stages,
        # yielding 512-, 1024- and 2048-channel feature maps respectively.
        outputs = []
        for ii, layer in enumerate(self.encoder):
            x = layer(x)
            if ii in {5, 6, 7}:
                outputs.append(x)
        return outputs


class _Decoder(nn.Module):
    """Decoder that fuses the transformer-refined features and upsamples them to a saliency map."""

    def __init__(self):
        super(_Decoder, self).__init__()
        self.conv1 = nn.Conv2d(768, 768, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(768, 512, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(32, 1, kernel_size=3, padding=1)

        self.batchnorm1 = nn.BatchNorm2d(768)
        self.batchnorm2 = nn.BatchNorm2d(512)
        self.batchnorm3 = nn.BatchNorm2d(256)
        self.batchnorm4 = nn.BatchNorm2d(128)
        self.batchnorm5 = nn.BatchNorm2d(64)
        self.batchnorm6 = nn.BatchNorm2d(32)

        # Spatial sizes are hard-coded for 288x384 inputs (feature strides 32, 16 and 8).
        self.TransEncoder1 = TransEncoder(in_channels=2048, spatial_size=9 * 12, cfg=cfg1)
        self.TransEncoder2 = TransEncoder(in_channels=1024, spatial_size=18 * 24, cfg=cfg2)
        self.TransEncoder3 = TransEncoder(in_channels=512, spatial_size=36 * 48, cfg=cfg3)

        self.relu = nn.ReLU(inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x3, x4, x5 = x  # features at strides 8, 16 and 32

        # Deepest level: refine with the transformer, then upsample to match x4.
        x5 = self.TransEncoder1(x5)
        x5 = self.conv1(x5)
        x5 = self.batchnorm1(x5)
        x5 = self.relu(x5)
        x5 = self.upsample(x5)

        # Fuse with the mid-level features by element-wise multiplication.
        x4_a = self.TransEncoder2(x4)
        x4 = x5 * x4_a
        x4 = self.relu(x4)
        x4 = self.conv2(x4)
        x4 = self.batchnorm2(x4)
        x4 = self.relu(x4)
        x4 = self.upsample(x4)

        # Fuse with the shallowest transformer-refined features.
        x3_a = self.TransEncoder3(x3)
        x3 = x4 * x3_a
        x3 = self.relu(x3)
        x3 = self.conv3(x3)
        x3 = self.batchnorm3(x3)
        x3 = self.relu(x3)
        x3 = self.upsample(x3)

        # Two further conv/upsample stages restore the input resolution.
        x2 = self.conv4(x3)
        x2 = self.batchnorm4(x2)
        x2 = self.relu(x2)
        x2 = self.upsample(x2)
        x2 = self.conv5(x2)
        x2 = self.batchnorm5(x2)
        x2 = self.relu(x2)

        x1 = self.upsample(x2)
        x1 = self.conv6(x1)
        x1 = self.batchnorm6(x1)
        x1 = self.relu(x1)
        x1 = self.conv7(x1)
        x = self.sigmoid(x1)  # saliency values in [0, 1]

        return x


class TransEncoder(nn.Module):
    """Applies a transformer Encoder to a CNN feature map: 1x1-conv patch embedding,
    learned position embeddings, then reshaping back to an NCHW feature map."""

    def __init__(self, in_channels, spatial_size, cfg):
        super(TransEncoder, self).__init__()

        # A 1x1 convolution projects each spatial location to a token embedding.
        self.patch_embeddings = nn.Conv2d(in_channels=in_channels,
                                          out_channels=cfg['hidden_size'],
                                          kernel_size=1,
                                          stride=1)
        self.position_embeddings = nn.Parameter(torch.zeros(1, spatial_size, cfg['hidden_size']))

        self.transformer_encoder = Encoder(cfg)

    def forward(self, x):
        a, b = x.shape[2], x.shape[3]
        x = self.patch_embeddings(x)
        x = x.flatten(2)         # (B, hidden, H*W)
        x = x.transpose(-1, -2)  # (B, H*W, hidden): one token per location

        embeddings = x + self.position_embeddings
        x = self.transformer_encoder(embeddings)
        B, n_patch, hidden = x.shape
        # Restore the NCHW layout expected by the convolutional decoder.
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, a, b)

        return x
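

# A minimal smoke test (a sketch): assumes utils.resnet and utils.TransformerEncoder
# are importable and that pretrained ResNet-50 weights are available. The input must
# be 288x384 to match the hard-coded position-embedding sizes above.
if __name__ == "__main__":
    model = TranSalNet()
    model.eval()
    dummy = torch.randn(1, 3, 288, 384)  # one RGB image
    with torch.no_grad():
        saliency = model(dummy)
    print(saliency.shape)  # expected: torch.Size([1, 1, 288, 384])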