Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
7febe9c
1
Parent(s):
1a1cf3c
Initialization.
Browse files- .gitignore +134 -0
- app.py +73 -0
- config.py +107 -0
- models/GCoNet.py +248 -0
- models/modules.py +516 -0
.gitignore
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Custom
|
2 |
+
.vscode
|
3 |
+
*.pth
|
4 |
+
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
pip-wheel-metadata/
|
29 |
+
share/python-wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
MANIFEST
|
34 |
+
|
35 |
+
# PyInstaller
|
36 |
+
# Usually these files are written by a python script from a template
|
37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
38 |
+
*.manifest
|
39 |
+
*.spec
|
40 |
+
|
41 |
+
# Installer logs
|
42 |
+
pip-log.txt
|
43 |
+
pip-delete-this-directory.txt
|
44 |
+
|
45 |
+
# Unit test / coverage reports
|
46 |
+
htmlcov/
|
47 |
+
.tox/
|
48 |
+
.nox/
|
49 |
+
.coverage
|
50 |
+
.coverage.*
|
51 |
+
.cache
|
52 |
+
nosetests.xml
|
53 |
+
coverage.xml
|
54 |
+
*.cover
|
55 |
+
*.py,cover
|
56 |
+
.hypothesis/
|
57 |
+
.pytest_cache/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
.python-version
|
91 |
+
|
92 |
+
# pipenv
|
93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
96 |
+
# install all needed dependencies.
|
97 |
+
#Pipfile.lock
|
98 |
+
|
99 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
100 |
+
__pypackages__/
|
101 |
+
|
102 |
+
# Celery stuff
|
103 |
+
celerybeat-schedule
|
104 |
+
celerybeat.pid
|
105 |
+
|
106 |
+
# SageMath parsed files
|
107 |
+
*.sage.py
|
108 |
+
|
109 |
+
# Environments
|
110 |
+
.env
|
111 |
+
.venv
|
112 |
+
env/
|
113 |
+
venv/
|
114 |
+
ENV/
|
115 |
+
env.bak/
|
116 |
+
venv.bak/
|
117 |
+
|
118 |
+
# Spyder project settings
|
119 |
+
.spyderproject
|
120 |
+
.spyproject
|
121 |
+
|
122 |
+
# Rope project settings
|
123 |
+
.ropeproject
|
124 |
+
|
125 |
+
# mkdocs documentation
|
126 |
+
/site
|
127 |
+
|
128 |
+
# mypy
|
129 |
+
.mypy_cache/
|
130 |
+
.dmypy.json
|
131 |
+
dmypy.json
|
132 |
+
|
133 |
+
# Pyre type checker
|
134 |
+
.pyre/
|
app.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import torch
|
8 |
+
from torchvision import transforms
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from models.GCoNet import GCoNet
|
12 |
+
|
13 |
+
|
14 |
+
device = ['cpu', 'cuda'][0]
|
15 |
+
|
16 |
+
|
17 |
+
class ImagePreprocessor():
|
18 |
+
def __init__(self) -> None:
|
19 |
+
self.transform_image = transforms.Compose([
|
20 |
+
transforms.Resize((256, 256)),
|
21 |
+
transforms.ToTensor(),
|
22 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
23 |
+
])
|
24 |
+
|
25 |
+
def proc(self, image):
|
26 |
+
image = self.transform_image(image)
|
27 |
+
return image
|
28 |
+
|
29 |
+
|
30 |
+
model = GCoNet(bb_pretrained=False).to(device)
|
31 |
+
state_dict = './ultimate_duts_cocoseg (The best one).pth'
|
32 |
+
if os.path.exists(state_dict):
|
33 |
+
gconet_dict = torch.load(state_dict, map_location=device)
|
34 |
+
model.load_state_dict(gconet_dict)
|
35 |
+
model.eval()
|
36 |
+
|
37 |
+
|
38 |
+
def pred_maps(dr):
|
39 |
+
images = [cv2.imread(image_path) for image_path in glob(os.path.join(dr, '*'))]
|
40 |
+
image_shapes = [image.shape[:2] for image in images]
|
41 |
+
images = [Image.fromarray(image) for image in images]
|
42 |
+
|
43 |
+
images_proc = []
|
44 |
+
image_preprocessor = ImagePreprocessor()
|
45 |
+
for image in images:
|
46 |
+
images_proc.append(image_preprocessor.proc(image))
|
47 |
+
images_proc = torch.cat([image_proc.unsqueeze(0) for image_proc in images_proc])
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
scaled_preds_tensor = model(images_proc.to(device))[-1]
|
51 |
+
preds = []
|
52 |
+
for image_shape, pred_tensor in zip(image_shapes, scaled_preds_tensor):
|
53 |
+
if device == 'cuda':
|
54 |
+
pred_tensor = pred_tensor.cpu()
|
55 |
+
preds.append(torch.nn.functional.interpolate(pred_tensor.unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy())
|
56 |
+
image_preds = []
|
57 |
+
for image, pred in zip(images, preds):
|
58 |
+
image_preds.append(
|
59 |
+
cv2.cvtColor(
|
60 |
+
np.hstack([np.array(image.convert('RGB')), cv2.cvtColor((pred*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)]),
|
61 |
+
cv2.COLOR_BGR2RGB
|
62 |
+
))
|
63 |
+
# for image_pred in image_preds:
|
64 |
+
# cv2.imwrite('a.png', cv2.cvtColor(image_pred, cv2.COLOR_RGB2BGR))
|
65 |
+
return image_preds[:]
|
66 |
+
|
67 |
+
demo = gr.Interface(
|
68 |
+
fn=pred_maps,
|
69 |
+
inputs='text',
|
70 |
+
outputs=['image', 'image', 'image', 'image', 'image'],
|
71 |
+
css=".output_image, .input_image {height: 300px !important}",
|
72 |
+
)
|
73 |
+
demo.launch(debug=True)
|
config.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
class Config():
|
5 |
+
def __init__(self) -> None:
|
6 |
+
# Backbone
|
7 |
+
self.bb = ['vgg16', 'vgg16bn', 'resnet50'][1]
|
8 |
+
# BN
|
9 |
+
self.use_bn = 'bn' in self.bb or 'resnet' in self.bb
|
10 |
+
# Augmentation
|
11 |
+
self.preproc_methods = ['flip', 'enhance', 'rotate', 'crop', 'pepper'][:3]
|
12 |
+
|
13 |
+
# Mask
|
14 |
+
losses = ['sal', 'cls', 'contrast', 'cls_mask']
|
15 |
+
self.loss = losses[:]
|
16 |
+
self.cls_mask_operation = ['x', '+', 'c'][0]
|
17 |
+
# Loss + Triplet Loss
|
18 |
+
self.lambdas_sal_last = {
|
19 |
+
# not 0 means opening this loss
|
20 |
+
# original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
|
21 |
+
'bce': 30 * 1, # high performance
|
22 |
+
'iou': 0.5 * 1, # 0 / 255
|
23 |
+
'ssim': 1 * 0, # help contours
|
24 |
+
'mse': 150 * 0, # can smooth the saliency map
|
25 |
+
'reg': 100 * 0,
|
26 |
+
'triplet': 3 * 1 * ('cls' in self.loss),
|
27 |
+
}
|
28 |
+
|
29 |
+
# DB
|
30 |
+
self.db_output_decoder = True
|
31 |
+
self.db_k = 300
|
32 |
+
self.db_k_alpha = 1
|
33 |
+
self.split_mask = True and 'cls_mask' in self.loss
|
34 |
+
self.db_mask = False and self.split_mask
|
35 |
+
|
36 |
+
# Triplet Loss
|
37 |
+
self.triplet = ['_x5', 'mask'][:1]
|
38 |
+
self.triplet_loss_margin = 0.1
|
39 |
+
# Adv
|
40 |
+
self.lambda_adv = 0. # turn to 0 to avoid adv training
|
41 |
+
|
42 |
+
# Refiner
|
43 |
+
self.refine = [0, 1, 4][0] # 0 -- no refinement, 1 -- only output mask for refinement, 4 -- but also raw input.
|
44 |
+
if self.refine:
|
45 |
+
self.batch_size = 16
|
46 |
+
else:
|
47 |
+
if self.bb != 'vgg16':
|
48 |
+
self.batch_size = 26
|
49 |
+
else:
|
50 |
+
self.batch_size = 48
|
51 |
+
self.db_output_refiner = False and self.refine
|
52 |
+
|
53 |
+
# Intermediate Layers
|
54 |
+
self.lambdas_sal_others = {
|
55 |
+
'bce': 0,
|
56 |
+
'iou': 0.,
|
57 |
+
'ssim': 0,
|
58 |
+
'mse': 0,
|
59 |
+
'reg': 0,
|
60 |
+
'triplet': 0,
|
61 |
+
}
|
62 |
+
self.output_number = 1
|
63 |
+
self.loss_sal_layers = 4 # used to be last 4 layers
|
64 |
+
self.loss_cls_mask_last_layers = 1 # used to be last 4 layers
|
65 |
+
if 'keep in range':
|
66 |
+
self.loss_sal_layers = min(self.output_number, self.loss_sal_layers)
|
67 |
+
self.loss_cls_mask_last_layers = min(self.output_number, self.loss_cls_mask_last_layers)
|
68 |
+
self.output_number = min(self.output_number, max(self.loss_sal_layers, self.loss_cls_mask_last_layers))
|
69 |
+
if self.output_number == 1:
|
70 |
+
for cri in self.lambdas_sal_others:
|
71 |
+
self.lambdas_sal_others[cri] = 0
|
72 |
+
self.conv_after_itp = False
|
73 |
+
self.complex_lateral_connection = False
|
74 |
+
|
75 |
+
# to control the quantitive level of each single loss by number of output branches.
|
76 |
+
self.loss_cls_mask_ratio_by_last_layers = 4 / self.loss_cls_mask_last_layers
|
77 |
+
for loss_sal in self.lambdas_sal_last.keys():
|
78 |
+
loss_sal_ratio_by_last_layers = 4 / (int(bool(self.lambdas_sal_others[loss_sal])) * (self.loss_sal_layers - 1) + 1)
|
79 |
+
self.lambdas_sal_last[loss_sal] *= loss_sal_ratio_by_last_layers
|
80 |
+
self.lambdas_sal_others[loss_sal] *= loss_sal_ratio_by_last_layers
|
81 |
+
self.lambda_cls_mask = 2.5 * self.loss_cls_mask_ratio_by_last_layers
|
82 |
+
self.lambda_cls = 3.
|
83 |
+
self.lambda_contrast = 250.
|
84 |
+
|
85 |
+
# Performance of GCoNet
|
86 |
+
self.val_measures = {
|
87 |
+
'Emax': {'CoCA': 0.760, 'CoSOD3k': 0.860, 'CoSal2015': 0.887},
|
88 |
+
'Smeasure': {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845},
|
89 |
+
'Fmax': {'CoCA': 0.544, 'CoSOD3k': 0.777, 'CoSal2015': 0.847},
|
90 |
+
}
|
91 |
+
|
92 |
+
# others
|
93 |
+
self.GAM = True
|
94 |
+
if not self.GAM and 'contrast' in self.loss:
|
95 |
+
self.loss.remove('contrast')
|
96 |
+
self.lr = 1e-4 * (self.batch_size / 16)
|
97 |
+
self.relation_module = ['GAM', 'ICE', 'NonLocal', 'MHA'][0]
|
98 |
+
self.self_supervision = False
|
99 |
+
self.label_smoothing = False
|
100 |
+
self.freeze = True
|
101 |
+
|
102 |
+
self.validation = False
|
103 |
+
self.decay_step_size = 3000
|
104 |
+
self.rand_seed = 7
|
105 |
+
run_sh_file = [f for f in os.listdir('.') if 'gco' in f and '.sh' in f] + [os.path.join('..', f) for f in os.listdir('..') if 'gco' in f and '.sh' in f]
|
106 |
+
# with open(run_sh_file[0], 'r') as f:
|
107 |
+
# self.val_last = int([l.strip() for l in f.readlines() if 'val_last=' in l][0].split('=')[-1])
|
models/GCoNet.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
import torch
|
3 |
+
from torch.functional import norm
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torchvision.models import vgg16, vgg16_bn
|
7 |
+
import fvcore.nn.weight_init as weight_init
|
8 |
+
from torchvision.models import resnet50
|
9 |
+
|
10 |
+
from models.modules import ResBlk, DSLayer, half_DSLayer, CoAttLayer, RefUnet, DBHead
|
11 |
+
|
12 |
+
from config import Config
|
13 |
+
|
14 |
+
|
15 |
+
class GCoNet(nn.Module):
|
16 |
+
def __init__(self, bb_pretrained=True):
|
17 |
+
super(GCoNet, self).__init__()
|
18 |
+
self.config = Config()
|
19 |
+
bb = self.config.bb
|
20 |
+
if bb == 'vgg16':
|
21 |
+
bb_net = list(vgg16(pretrained=bb_pretrained).children())[0]
|
22 |
+
bb_convs = OrderedDict({
|
23 |
+
'conv1': bb_net[:4],
|
24 |
+
'conv2': bb_net[4:9],
|
25 |
+
'conv3': bb_net[9:16],
|
26 |
+
'conv4': bb_net[16:23],
|
27 |
+
'conv5': bb_net[23:30]
|
28 |
+
})
|
29 |
+
channel_scale = 1
|
30 |
+
elif bb == 'resnet50':
|
31 |
+
bb_net = list(resnet50(pretrained=bb_pretrained).children())
|
32 |
+
bb_convs = OrderedDict({
|
33 |
+
'conv1': nn.Sequential(*bb_net[0:3]),
|
34 |
+
'conv2': bb_net[4],
|
35 |
+
'conv3': bb_net[5],
|
36 |
+
'conv4': bb_net[6],
|
37 |
+
'conv5': bb_net[7]
|
38 |
+
})
|
39 |
+
channel_scale = 4
|
40 |
+
elif bb == 'vgg16bn':
|
41 |
+
bb_net = list(vgg16_bn(pretrained=bb_pretrained).children())[0]
|
42 |
+
bb_convs = OrderedDict({
|
43 |
+
'conv1': bb_net[:6],
|
44 |
+
'conv2': bb_net[6:13],
|
45 |
+
'conv3': bb_net[13:23],
|
46 |
+
'conv4': bb_net[23:33],
|
47 |
+
'conv5': bb_net[33:43]
|
48 |
+
})
|
49 |
+
channel_scale = 1
|
50 |
+
self.bb = nn.Sequential(bb_convs)
|
51 |
+
lateral_channels_in = [512, 512, 256, 128, 64] if 'vgg16' in bb else [2048, 1024, 512, 256, 64]
|
52 |
+
|
53 |
+
# channel_scale_latlayer = channel_scale // 2 if bb == 'resnet50' else 1
|
54 |
+
# channel_last = 32
|
55 |
+
|
56 |
+
ch_decoder = lateral_channels_in[0]//2//channel_scale
|
57 |
+
self.top_layer = ResBlk(lateral_channels_in[0], ch_decoder)
|
58 |
+
self.enlayer5 = ResBlk(ch_decoder, ch_decoder)
|
59 |
+
if self.config.conv_after_itp:
|
60 |
+
self.dslayer5 = DSLayer(ch_decoder, ch_decoder)
|
61 |
+
self.latlayer5 = ResBlk(lateral_channels_in[1], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[1], ch_decoder, 1, 1, 0)
|
62 |
+
|
63 |
+
ch_decoder //= 2
|
64 |
+
self.enlayer4 = ResBlk(ch_decoder*2, ch_decoder)
|
65 |
+
if self.config.conv_after_itp:
|
66 |
+
self.dslayer4 = DSLayer(ch_decoder, ch_decoder)
|
67 |
+
self.latlayer4 = ResBlk(lateral_channels_in[2], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[2], ch_decoder, 1, 1, 0)
|
68 |
+
if self.config.output_number >= 4:
|
69 |
+
self.conv_out4 = nn.Sequential(nn.Conv2d(ch_decoder, 32, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(32, 1, 1, 1, 0))
|
70 |
+
|
71 |
+
ch_decoder //= 2
|
72 |
+
self.enlayer3 = ResBlk(ch_decoder*2, ch_decoder)
|
73 |
+
if self.config.conv_after_itp:
|
74 |
+
self.dslayer3 = DSLayer(ch_decoder, ch_decoder)
|
75 |
+
self.latlayer3 = ResBlk(lateral_channels_in[3], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[3], ch_decoder, 1, 1, 0)
|
76 |
+
if self.config.output_number >= 3:
|
77 |
+
self.conv_out3 = nn.Sequential(nn.Conv2d(ch_decoder, 32, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(32, 1, 1, 1, 0))
|
78 |
+
|
79 |
+
ch_decoder //= 2
|
80 |
+
self.enlayer2 = ResBlk(ch_decoder*2, ch_decoder)
|
81 |
+
if self.config.conv_after_itp:
|
82 |
+
self.dslayer2 = DSLayer(ch_decoder, ch_decoder)
|
83 |
+
self.latlayer2 = ResBlk(lateral_channels_in[4], ch_decoder) if self.config.complex_lateral_connection else nn.Conv2d(lateral_channels_in[4], ch_decoder, 1, 1, 0)
|
84 |
+
if self.config.output_number >= 2:
|
85 |
+
self.conv_out2 = nn.Sequential(nn.Conv2d(ch_decoder, 32, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(32, 1, 1, 1, 0))
|
86 |
+
|
87 |
+
self.enlayer1 = ResBlk(ch_decoder, ch_decoder)
|
88 |
+
self.conv_out1 = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))
|
89 |
+
|
90 |
+
if self.config.GAM:
|
91 |
+
self.co_x5 = CoAttLayer(channel_in=lateral_channels_in[0])
|
92 |
+
|
93 |
+
if 'contrast' in self.config.loss:
|
94 |
+
self.pred_layer = half_DSLayer(lateral_channels_in[0])
|
95 |
+
|
96 |
+
if {'cls', 'cls_mask'} & set(self.config.loss):
|
97 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
98 |
+
self.classifier = nn.Linear(lateral_channels_in[0], 291) # DUTS_class has 291 classes
|
99 |
+
for layer in [self.classifier]:
|
100 |
+
weight_init.c2_msra_fill(layer)
|
101 |
+
if self.config.split_mask:
|
102 |
+
self.sgm = nn.Sigmoid()
|
103 |
+
if self.config.refine:
|
104 |
+
self.refiner = nn.Sequential(RefUnet(self.config.refine, 64))
|
105 |
+
if self.config.split_mask:
|
106 |
+
self.conv_out_mask = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))
|
107 |
+
if self.config.db_mask:
|
108 |
+
self.db_mask = DBHead(32)
|
109 |
+
if self.config.db_output_decoder:
|
110 |
+
self.db_output_decoder = DBHead(32)
|
111 |
+
if self.config.cls_mask_operation == 'c':
|
112 |
+
self.conv_cat_mask = nn.Conv2d(4, 3, 1, 1, 0)
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
########## Encoder ##########
|
116 |
+
|
117 |
+
[N, _, H, W] = x.size()
|
118 |
+
x1 = self.bb.conv1(x)
|
119 |
+
x2 = self.bb.conv2(x1)
|
120 |
+
x3 = self.bb.conv3(x2)
|
121 |
+
x4 = self.bb.conv4(x3)
|
122 |
+
x5 = self.bb.conv5(x4)
|
123 |
+
|
124 |
+
if 'cls' in self.config.loss:
|
125 |
+
_x5 = self.avgpool(x5)
|
126 |
+
_x5 = _x5.view(_x5.size(0), -1)
|
127 |
+
pred_cls = self.classifier(_x5)
|
128 |
+
|
129 |
+
if self.config.GAM:
|
130 |
+
weighted_x5, neg_x5 = self.co_x5(x5)
|
131 |
+
if 'contrast' in self.config.loss:
|
132 |
+
if self.training:
|
133 |
+
########## contrastive branch #########
|
134 |
+
cat_x5 = torch.cat([weighted_x5, neg_x5], dim=0)
|
135 |
+
pred_contrast = self.pred_layer(cat_x5)
|
136 |
+
pred_contrast = F.interpolate(pred_contrast, size=(H, W), mode='bilinear', align_corners=True)
|
137 |
+
p5 = self.top_layer(weighted_x5)
|
138 |
+
else:
|
139 |
+
p5 = self.top_layer(x5)
|
140 |
+
|
141 |
+
########## Decoder ##########
|
142 |
+
scaled_preds = []
|
143 |
+
p5 = self.enlayer5(p5)
|
144 |
+
p5 = F.interpolate(p5, size=x4.shape[2:], mode='bilinear', align_corners=True)
|
145 |
+
if self.config.conv_after_itp:
|
146 |
+
p5 = self.dslayer5(p5)
|
147 |
+
p4 = p5 + self.latlayer5(x4)
|
148 |
+
|
149 |
+
p4 = self.enlayer4(p4)
|
150 |
+
p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
151 |
+
if self.config.conv_after_itp:
|
152 |
+
p4 = self.dslayer4(p4)
|
153 |
+
if self.config.output_number >= 4:
|
154 |
+
p4_out = self.conv_out4(p4)
|
155 |
+
scaled_preds.append(p4_out)
|
156 |
+
p3 = p4 + self.latlayer4(x3)
|
157 |
+
|
158 |
+
p3 = self.enlayer3(p3)
|
159 |
+
p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
160 |
+
if self.config.conv_after_itp:
|
161 |
+
p3 = self.dslayer3(p3)
|
162 |
+
if self.config.output_number >= 3:
|
163 |
+
p3_out = self.conv_out3(p3)
|
164 |
+
scaled_preds.append(p3_out)
|
165 |
+
p2 = p3 + self.latlayer3(x2)
|
166 |
+
|
167 |
+
p2 = self.enlayer2(p2)
|
168 |
+
p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
169 |
+
if self.config.conv_after_itp:
|
170 |
+
p2 = self.dslayer2(p2)
|
171 |
+
if self.config.output_number >= 2:
|
172 |
+
p2_out = self.conv_out2(p2)
|
173 |
+
scaled_preds.append(p2_out)
|
174 |
+
p1 = p2 + self.latlayer2(x1)
|
175 |
+
|
176 |
+
p1 = self.enlayer1(p1)
|
177 |
+
p1 = F.interpolate(p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
178 |
+
if self.config.db_output_decoder:
|
179 |
+
p1_out = self.db_output_decoder(p1)
|
180 |
+
else:
|
181 |
+
p1_out = self.conv_out1(p1)
|
182 |
+
scaled_preds.append(p1_out)
|
183 |
+
|
184 |
+
if self.config.refine == 1:
|
185 |
+
scaled_preds.append(self.refiner(p1_out))
|
186 |
+
elif self.config.refine == 4:
|
187 |
+
scaled_preds.append(self.refiner(torch.cat([x, p1_out], dim=1)))
|
188 |
+
|
189 |
+
if 'cls_mask' in self.config.loss:
|
190 |
+
pred_cls_masks = []
|
191 |
+
norm_features_mask = []
|
192 |
+
input_features = [x, x1, x2, x3][:self.config.loss_cls_mask_last_layers]
|
193 |
+
bb_lst = [self.bb.conv1, self.bb.conv2, self.bb.conv3, self.bb.conv4, self.bb.conv5]
|
194 |
+
for idx_out in range(self.config.loss_cls_mask_last_layers):
|
195 |
+
if idx_out:
|
196 |
+
mask_output = scaled_preds[-(idx_out+1+int(bool(self.config.refine)))]
|
197 |
+
else:
|
198 |
+
if self.config.split_mask:
|
199 |
+
if self.config.db_mask:
|
200 |
+
mask_output = self.db_mask(p1)
|
201 |
+
else:
|
202 |
+
mask_output = self.sgm(self.conv_out_mask(p1))
|
203 |
+
|
204 |
+
if self.config.cls_mask_operation == 'x':
|
205 |
+
masked_features = input_features[idx_out] * mask_output
|
206 |
+
elif self.config.cls_mask_operation == '+':
|
207 |
+
masked_features = input_features[idx_out] + mask_output
|
208 |
+
elif self.config.cls_mask_operation == 'c':
|
209 |
+
masked_features = self.conv_cat_mask(torch.cat((input_features[idx_out], mask_output), dim=1))
|
210 |
+
norm_feature_mask = self.avgpool(
|
211 |
+
nn.Sequential(*bb_lst[idx_out:])(
|
212 |
+
masked_features
|
213 |
+
)
|
214 |
+
).view(N, -1)
|
215 |
+
norm_features_mask.append(norm_feature_mask)
|
216 |
+
pred_cls_masks.append(
|
217 |
+
self.classifier(
|
218 |
+
norm_feature_mask
|
219 |
+
)
|
220 |
+
)
|
221 |
+
|
222 |
+
if self.training:
|
223 |
+
return_values = []
|
224 |
+
if {'sal', 'cls', 'contrast', 'cls_mask'} == set(self.config.loss):
|
225 |
+
return_values = [scaled_preds, pred_cls, pred_contrast, pred_cls_masks]
|
226 |
+
elif {'sal', 'cls', 'contrast'} == set(self.config.loss):
|
227 |
+
return_values = [scaled_preds, pred_cls, pred_contrast]
|
228 |
+
elif {'sal', 'cls', 'cls_mask'} == set(self.config.loss):
|
229 |
+
return_values = [scaled_preds, pred_cls, pred_cls_masks]
|
230 |
+
elif {'sal', 'cls'} == set(self.config.loss):
|
231 |
+
return_values = [scaled_preds, pred_cls]
|
232 |
+
elif {'sal', 'contrast'} == set(self.config.loss):
|
233 |
+
return_values = [scaled_preds, pred_contrast]
|
234 |
+
elif {'sal', 'cls_mask'} == set(self.config.loss):
|
235 |
+
return_values = [scaled_preds, pred_cls_masks]
|
236 |
+
else:
|
237 |
+
return_values = [scaled_preds]
|
238 |
+
|
239 |
+
if self.config.lambdas_sal_last['triplet']:
|
240 |
+
norm_features = []
|
241 |
+
if '_x5' in self.config.triplet:
|
242 |
+
norm_features.append(_x5)
|
243 |
+
if 'mask' in self.config.triplet:
|
244 |
+
norm_features.append(norm_features_mask[0])
|
245 |
+
return_values.append(norm_features)
|
246 |
+
return return_values
|
247 |
+
else:
|
248 |
+
return scaled_preds
|
models/modules.py
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import fvcore.nn.weight_init as weight_init
|
6 |
+
|
7 |
+
from config import Config
|
8 |
+
|
9 |
+
|
10 |
+
config = Config()
|
11 |
+
|
12 |
+
|
13 |
+
class ResBlk(nn.Module):
|
14 |
+
def __init__(self, channel_in=64, channel_out=64):
|
15 |
+
super(ResBlk, self).__init__()
|
16 |
+
self.conv_in = nn.Conv2d(channel_in, 64, 3, 1, 1)
|
17 |
+
self.relu_in = nn.ReLU(inplace=True)
|
18 |
+
self.conv_out = nn.Conv2d(64, channel_out, 3, 1, 1)
|
19 |
+
if config.use_bn:
|
20 |
+
self.bn_in = nn.BatchNorm2d(64)
|
21 |
+
self.bn_out = nn.BatchNorm2d(channel_out)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
x = self.conv_in(x)
|
25 |
+
if config.use_bn:
|
26 |
+
x = self.bn_in(x)
|
27 |
+
x = self.relu_in(x)
|
28 |
+
x = self.conv_out(x)
|
29 |
+
if config.use_bn:
|
30 |
+
x = self.bn_out(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class DSLayer(nn.Module):
|
35 |
+
def __init__(self, channel_in=64, channel_out=1, activation_out='relu'):
|
36 |
+
super(DSLayer, self).__init__()
|
37 |
+
self.activation_out = activation_out
|
38 |
+
self.conv1 = nn.Conv2d(channel_in, 64, kernel_size=3, stride=1, padding=1)
|
39 |
+
self.relu1 = nn.ReLU(inplace=True)
|
40 |
+
|
41 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
42 |
+
self.relu2 = nn.ReLU(inplace=True)
|
43 |
+
if activation_out:
|
44 |
+
self.pred_conv = nn.Conv2d(64, channel_out, kernel_size=1, stride=1, padding=0)
|
45 |
+
self.pred_relu = nn.ReLU(inplace=True)
|
46 |
+
else:
|
47 |
+
self.pred_conv = nn.Conv2d(64, channel_out, kernel_size=1, stride=1, padding=0)
|
48 |
+
|
49 |
+
if config.use_bn:
|
50 |
+
self.bn1 = nn.BatchNorm2d(64)
|
51 |
+
self.bn2 = nn.BatchNorm2d(64)
|
52 |
+
self.pred_bn = nn.BatchNorm2d(channel_out)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.conv1(x)
|
56 |
+
if config.use_bn:
|
57 |
+
x = self.bn1(x)
|
58 |
+
x = self.relu1(x)
|
59 |
+
x = self.conv2(x)
|
60 |
+
if config.use_bn:
|
61 |
+
x = self.bn2(x)
|
62 |
+
x = self.relu2(x)
|
63 |
+
|
64 |
+
x = self.pred_conv(x)
|
65 |
+
if config.use_bn:
|
66 |
+
x = self.pred_bn(x)
|
67 |
+
if self.activation_out:
|
68 |
+
x = self.pred_relu(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class half_DSLayer(nn.Module):
|
73 |
+
def __init__(self, channel_in=512):
|
74 |
+
super(half_DSLayer, self).__init__()
|
75 |
+
self.enlayer = nn.Sequential(
|
76 |
+
nn.Conv2d(channel_in, int(channel_in//4), kernel_size=3, stride=1, padding=1),
|
77 |
+
nn.ReLU(inplace=True)
|
78 |
+
)
|
79 |
+
self.predlayer = nn.Sequential(
|
80 |
+
nn.Conv2d(int(channel_in//4), 1, kernel_size=1, stride=1, padding=0),
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = self.enlayer(x)
|
85 |
+
x = self.predlayer(x)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class CoAttLayer(nn.Module):
|
90 |
+
def __init__(self, channel_in=512):
|
91 |
+
super(CoAttLayer, self).__init__()
|
92 |
+
|
93 |
+
self.all_attention = eval(Config().relation_module + '(channel_in)')
|
94 |
+
self.conv_output = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
95 |
+
self.conv_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
96 |
+
self.fc_transform = nn.Linear(channel_in, channel_in)
|
97 |
+
|
98 |
+
for layer in [self.conv_output, self.conv_transform, self.fc_transform]:
|
99 |
+
weight_init.c2_msra_fill(layer)
|
100 |
+
|
101 |
+
def forward(self, x5):
|
102 |
+
if self.training:
|
103 |
+
f_begin = 0
|
104 |
+
f_end = int(x5.shape[0] / 2)
|
105 |
+
s_begin = f_end
|
106 |
+
s_end = int(x5.shape[0])
|
107 |
+
|
108 |
+
x5_1 = x5[f_begin: f_end]
|
109 |
+
x5_2 = x5[s_begin: s_end]
|
110 |
+
|
111 |
+
x5_new_1 = self.all_attention(x5_1)
|
112 |
+
x5_new_2 = self.all_attention(x5_2)
|
113 |
+
|
114 |
+
x5_1_proto = torch.mean(x5_new_1, (0, 2, 3), True).view(1, -1)
|
115 |
+
x5_1_proto = x5_1_proto.unsqueeze(-1).unsqueeze(-1) # 1, C, 1, 1
|
116 |
+
|
117 |
+
x5_2_proto = torch.mean(x5_new_2, (0, 2, 3), True).view(1, -1)
|
118 |
+
x5_2_proto = x5_2_proto.unsqueeze(-1).unsqueeze(-1) # 1, C, 1, 1
|
119 |
+
|
120 |
+
x5_11 = x5_1 * x5_1_proto
|
121 |
+
x5_22 = x5_2 * x5_2_proto
|
122 |
+
weighted_x5 = torch.cat([x5_11, x5_22], dim=0)
|
123 |
+
|
124 |
+
x5_12 = x5_1 * x5_2_proto
|
125 |
+
x5_21 = x5_2 * x5_1_proto
|
126 |
+
neg_x5 = torch.cat([x5_12, x5_21], dim=0)
|
127 |
+
else:
|
128 |
+
|
129 |
+
x5_new = self.all_attention(x5)
|
130 |
+
x5_proto = torch.mean(x5_new, (0, 2, 3), True).view(1, -1)
|
131 |
+
x5_proto = x5_proto.unsqueeze(-1).unsqueeze(-1) # 1, C, 1, 1
|
132 |
+
|
133 |
+
weighted_x5 = x5 * x5_proto #* cweight
|
134 |
+
neg_x5 = None
|
135 |
+
return weighted_x5, neg_x5
|
136 |
+
|
137 |
+
|
138 |
+
class ICE(nn.Module):
|
139 |
+
# The Integrity Channel Enhancement (ICE) module
|
140 |
+
# _X means in X-th column
|
141 |
+
def __init__(self, channel_in=512):
|
142 |
+
super(ICE, self).__init__()
|
143 |
+
self.conv_1 = nn.Conv2d(channel_in, channel_in, 3, 1, 1)
|
144 |
+
self.conv_2 = nn.Conv1d(channel_in, channel_in, 3, 1, 1)
|
145 |
+
self.conv_3 = nn.Conv2d(channel_in*3, channel_in, 3, 1, 1)
|
146 |
+
|
147 |
+
self.fc_2 = nn.Linear(channel_in, channel_in)
|
148 |
+
self.fc_3 = nn.Linear(channel_in, channel_in)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
x_1, x_2, x_3 = x, x, x
|
152 |
+
|
153 |
+
x_1 = x_1 * x_2 * x_3
|
154 |
+
x_2 = x_1 + x_2 + x_3
|
155 |
+
x_3 = torch.cat((x_1, x_2, x_3), dim=1)
|
156 |
+
|
157 |
+
V = self.conv_1(x_1)
|
158 |
+
|
159 |
+
bs, c, h, w = x_2.shape
|
160 |
+
K = self.conv_2(x_2.view(bs, c, h*w))
|
161 |
+
Q_prime = self.conv_3(x_3)
|
162 |
+
Q_prime = torch.norm(Q_prime, dim=(-2, -1)).view(bs, c, 1, 1)
|
163 |
+
Q_prime = Q_prime.view(bs, -1)
|
164 |
+
Q_prime = self.fc_3(Q_prime)
|
165 |
+
Q_prime = torch.softmax(Q_prime, dim=-1)
|
166 |
+
Q_prime = Q_prime.unsqueeze(1)
|
167 |
+
|
168 |
+
Q = torch.matmul(Q_prime, K)
|
169 |
+
|
170 |
+
x_2 = torch.nn.functional.cosine_similarity(K, Q, dim=-1)
|
171 |
+
x_2 = torch.sigmoid(x_2)
|
172 |
+
x_2 = self.fc_2(x_2)
|
173 |
+
x_2 = x_2.unsqueeze(-1).unsqueeze(-1)
|
174 |
+
x_1 = V * x_2 + V
|
175 |
+
|
176 |
+
return x_1
|
177 |
+
|
178 |
+
|
179 |
+
class GAM(nn.Module):
|
180 |
+
def __init__(self, channel_in=512):
|
181 |
+
|
182 |
+
super(GAM, self).__init__()
|
183 |
+
self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
184 |
+
self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
185 |
+
|
186 |
+
self.scale = 1.0 / (channel_in ** 0.5)
|
187 |
+
|
188 |
+
self.conv6 = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
189 |
+
|
190 |
+
for layer in [self.query_transform, self.key_transform, self.conv6]:
|
191 |
+
weight_init.c2_msra_fill(layer)
|
192 |
+
|
193 |
+
def forward(self, x5):
|
194 |
+
# x: B,C,H,W
|
195 |
+
# x_query: B,C,HW
|
196 |
+
B, C, H5, W5 = x5.size()
|
197 |
+
|
198 |
+
x_query = self.query_transform(x5).view(B, C, -1)
|
199 |
+
|
200 |
+
# x_query: B,HW,C
|
201 |
+
x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C) # BHW, C
|
202 |
+
# x_key: B,C,HW
|
203 |
+
x_key = self.key_transform(x5).view(B, C, -1)
|
204 |
+
|
205 |
+
x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1) # C, BHW
|
206 |
+
|
207 |
+
# W = Q^T K: B,HW,HW
|
208 |
+
x_w = torch.matmul(x_query, x_key) #* self.scale # BHW, BHW
|
209 |
+
x_w = x_w.view(B*H5*W5, B, H5*W5)
|
210 |
+
x_w = torch.max(x_w, -1).values # BHW, B
|
211 |
+
x_w = x_w.mean(-1)
|
212 |
+
#x_w = torch.mean(x_w, -1).values # BHW
|
213 |
+
x_w = x_w.view(B, -1) * self.scale # B, HW
|
214 |
+
x_w = F.softmax(x_w, dim=-1) # B, HW
|
215 |
+
x_w = x_w.view(B, H5, W5).unsqueeze(1) # B, 1, H, W
|
216 |
+
|
217 |
+
x5 = x5 * x_w
|
218 |
+
x5 = self.conv6(x5)
|
219 |
+
|
220 |
+
return x5
|
221 |
+
|
222 |
+
|
223 |
+
class MHA(nn.Module):
|
224 |
+
'''
|
225 |
+
Scaled dot-product attention
|
226 |
+
'''
|
227 |
+
|
228 |
+
def __init__(self, d_model=512, d_k=512, d_v=512, h=8, dropout=.1, channel_in=512):
|
229 |
+
'''
|
230 |
+
:param d_model: Output dimensionality of the model
|
231 |
+
:param d_k: Dimensionality of queries and keys
|
232 |
+
:param d_v: Dimensionality of values
|
233 |
+
:param h: Number of heads
|
234 |
+
'''
|
235 |
+
super(MHA, self).__init__()
|
236 |
+
self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
237 |
+
self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
238 |
+
self.value_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
|
239 |
+
self.fc_q = nn.Linear(d_model, h * d_k)
|
240 |
+
self.fc_k = nn.Linear(d_model, h * d_k)
|
241 |
+
self.fc_v = nn.Linear(d_model, h * d_v)
|
242 |
+
self.fc_o = nn.Linear(h * d_v, d_model)
|
243 |
+
self.dropout = nn.Dropout(dropout)
|
244 |
+
|
245 |
+
self.d_model = d_model
|
246 |
+
self.d_k = d_k
|
247 |
+
self.d_v = d_v
|
248 |
+
self.h = h
|
249 |
+
|
250 |
+
self.init_weights()
|
251 |
+
|
252 |
+
def init_weights(self):
|
253 |
+
for m in self.modules():
|
254 |
+
if isinstance(m, nn.Conv2d):
|
255 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
256 |
+
if m.bias is not None:
|
257 |
+
nn.init.constant_(m.bias, 0)
|
258 |
+
elif isinstance(m, nn.BatchNorm2d):
|
259 |
+
nn.init.constant_(m.weight, 1)
|
260 |
+
nn.init.constant_(m.bias, 0)
|
261 |
+
elif isinstance(m, nn.Linear):
|
262 |
+
nn.init.normal_(m.weight, std=0.001)
|
263 |
+
if m.bias is not None:
|
264 |
+
nn.init.constant_(m.bias, 0)
|
265 |
+
|
266 |
+
def forward(self, x, attention_mask=None, attention_weights=None):
|
267 |
+
'''
|
268 |
+
Computes
|
269 |
+
:param queries: Queries (b_s, nq, d_model)
|
270 |
+
:param keys: Keys (b_s, nk, d_model)
|
271 |
+
:param values: Values (b_s, nk, d_model)
|
272 |
+
:param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
|
273 |
+
:param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
|
274 |
+
:return:
|
275 |
+
'''
|
276 |
+
B, C, H, W = x.size()
|
277 |
+
queries = self.query_transform(x).view(B, -1, C)
|
278 |
+
keys = self.query_transform(x).view(B, -1, C)
|
279 |
+
values = self.query_transform(x).view(B, -1, C)
|
280 |
+
|
281 |
+
b_s, nq = queries.shape[:2]
|
282 |
+
nk = keys.shape[1]
|
283 |
+
|
284 |
+
q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
|
285 |
+
k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
|
286 |
+
v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
|
287 |
+
|
288 |
+
att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
|
289 |
+
if attention_weights is not None:
|
290 |
+
att = att * attention_weights
|
291 |
+
if attention_mask is not None:
|
292 |
+
att = att.masked_fill(attention_mask, -np.inf)
|
293 |
+
att = torch.softmax(att, -1)
|
294 |
+
att = self.dropout(att)
|
295 |
+
|
296 |
+
out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
|
297 |
+
out = self.fc_o(out).view(B, C, H, W) # (b_s, nq, d_model)
|
298 |
+
return out
|
299 |
+
|
300 |
+
|
301 |
+
class NonLocal(nn.Module):
|
302 |
+
def __init__(self, channel_in=512, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
|
303 |
+
super(NonLocal, self).__init__()
|
304 |
+
|
305 |
+
assert dimension in [1, 2, 3]
|
306 |
+
self.dimension = dimension
|
307 |
+
self.sub_sample = sub_sample
|
308 |
+
|
309 |
+
self.channel_in = channel_in
|
310 |
+
self.inter_channels = inter_channels
|
311 |
+
|
312 |
+
if self.inter_channels is None:
|
313 |
+
self.inter_channels = channel_in // 2
|
314 |
+
if self.inter_channels == 0:
|
315 |
+
self.inter_channels = 1
|
316 |
+
|
317 |
+
self.g = nn.Conv2d(self.channel_in, self.inter_channels, 1, 1, 0)
|
318 |
+
|
319 |
+
if bn_layer:
|
320 |
+
self.W = nn.Sequential(
|
321 |
+
nn.Conv2d(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0),
|
322 |
+
nn.BatchNorm2d(self.channel_in)
|
323 |
+
)
|
324 |
+
nn.init.constant_(self.W[1].weight, 0)
|
325 |
+
nn.init.constant_(self.W[1].bias, 0)
|
326 |
+
else:
|
327 |
+
self.W = nn.Conv2d(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0)
|
328 |
+
nn.init.constant_(self.W.weight, 0)
|
329 |
+
nn.init.constant_(self.W.bias, 0)
|
330 |
+
|
331 |
+
self.theta = nn.Conv2d(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)
|
332 |
+
self.phi = nn.Conv2d(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)
|
333 |
+
|
334 |
+
if sub_sample:
|
335 |
+
self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
|
336 |
+
self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))
|
337 |
+
|
338 |
+
def forward(self, x, return_nl_map=False):
|
339 |
+
"""
|
340 |
+
:param x: (b, c, t, h, w)
|
341 |
+
:param return_nl_map: if True return z, nl_map, else only return z.
|
342 |
+
:return:
|
343 |
+
"""
|
344 |
+
|
345 |
+
batch_size = x.size(0)
|
346 |
+
|
347 |
+
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
|
348 |
+
g_x = g_x.permute(0, 2, 1)
|
349 |
+
|
350 |
+
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
|
351 |
+
theta_x = theta_x.permute(0, 2, 1)
|
352 |
+
phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
|
353 |
+
f = torch.matmul(theta_x, phi_x)
|
354 |
+
f_div_C = F.softmax(f, dim=-1)
|
355 |
+
|
356 |
+
y = torch.matmul(f_div_C, g_x)
|
357 |
+
y = y.permute(0, 2, 1).contiguous()
|
358 |
+
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
|
359 |
+
W_y = self.W(y)
|
360 |
+
z = W_y + x
|
361 |
+
|
362 |
+
if return_nl_map:
|
363 |
+
return z, f_div_C
|
364 |
+
return z
|
365 |
+
|
366 |
+
|
367 |
+
class DBHead(nn.Module):
|
368 |
+
def __init__(self, channel_in=32, channel_out=1, k=config.db_k):
|
369 |
+
super().__init__()
|
370 |
+
self.k = k
|
371 |
+
self.binarize = nn.Sequential(
|
372 |
+
nn.Conv2d(channel_in, channel_in, 3, 1, 1),
|
373 |
+
*[nn.BatchNorm2d(channel_in), nn.ReLU(inplace=True)] if config.use_bn else nn.ReLU(inplace=True),
|
374 |
+
nn.Conv2d(channel_in, channel_in, 3, 1, 1),
|
375 |
+
*[nn.BatchNorm2d(channel_in), nn.ReLU(inplace=True)] if config.use_bn else nn.ReLU(inplace=True),
|
376 |
+
nn.Conv2d(channel_in, channel_out, 1, 1, 0),
|
377 |
+
nn.Sigmoid()
|
378 |
+
)
|
379 |
+
|
380 |
+
self.thresh = nn.Sequential(
|
381 |
+
nn.Conv2d(channel_in, channel_in, 3, padding=1),
|
382 |
+
*[nn.BatchNorm2d(channel_in), nn.ReLU(inplace=True)] if config.use_bn else nn.ReLU(inplace=True),
|
383 |
+
nn.Conv2d(channel_in, channel_in, 3, 1, 1),
|
384 |
+
*[nn.BatchNorm2d(channel_in), nn.ReLU(inplace=True)] if config.use_bn else nn.ReLU(inplace=True),
|
385 |
+
nn.Conv2d(channel_in, channel_out, 1, 1, 0),
|
386 |
+
nn.Sigmoid()
|
387 |
+
)
|
388 |
+
|
389 |
+
def forward(self, x):
|
390 |
+
shrink_maps = self.binarize(x)
|
391 |
+
threshold_maps = self.thresh(x)
|
392 |
+
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
393 |
+
return binary_maps
|
394 |
+
|
395 |
+
def step_function(self, x, y):
|
396 |
+
if config.db_k_alpha != 1:
|
397 |
+
z = x - y
|
398 |
+
mask_neg_inv = 1 - 2 * (z < 0)
|
399 |
+
a = torch.exp(-self.k * (torch.pow(z * mask_neg_inv + 1e-16, 1/config.k_alpha) * mask_neg_inv))
|
400 |
+
else:
|
401 |
+
a = torch.exp(-self.k * (x - y))
|
402 |
+
if torch.isinf(a).any():
|
403 |
+
a = torch.exp(-50 * (x - y))
|
404 |
+
return torch.reciprocal(1 + a)
|
405 |
+
|
406 |
+
|
407 |
+
class RefUnet(nn.Module):
|
408 |
+
# Refinement
|
409 |
+
def __init__(self, in_ch, inc_ch):
|
410 |
+
super(RefUnet, self).__init__()
|
411 |
+
self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)
|
412 |
+
self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
|
413 |
+
if config.use_bn:
|
414 |
+
self.bn1 = nn.BatchNorm2d(64)
|
415 |
+
self.relu1 = nn.ReLU(inplace=True)
|
416 |
+
|
417 |
+
self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
|
418 |
+
self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
|
419 |
+
if config.use_bn:
|
420 |
+
self.bn2 = nn.BatchNorm2d(64)
|
421 |
+
self.relu2 = nn.ReLU(inplace=True)
|
422 |
+
|
423 |
+
self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
|
424 |
+
self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
|
425 |
+
if config.use_bn:
|
426 |
+
self.bn3 = nn.BatchNorm2d(64)
|
427 |
+
self.relu3 = nn.ReLU(inplace=True)
|
428 |
+
|
429 |
+
self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
|
430 |
+
self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
|
431 |
+
if config.use_bn:
|
432 |
+
self.bn4 = nn.BatchNorm2d(64)
|
433 |
+
self.relu4 = nn.ReLU(inplace=True)
|
434 |
+
|
435 |
+
self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
|
436 |
+
#####
|
437 |
+
self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
|
438 |
+
if config.use_bn:
|
439 |
+
self.bn5 = nn.BatchNorm2d(64)
|
440 |
+
self.relu5 = nn.ReLU(inplace=True)
|
441 |
+
#####
|
442 |
+
self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
|
443 |
+
if config.use_bn:
|
444 |
+
self.bn_d4 = nn.BatchNorm2d(64)
|
445 |
+
self.relu_d4 = nn.ReLU(inplace=True)
|
446 |
+
|
447 |
+
self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
|
448 |
+
if config.use_bn:
|
449 |
+
self.bn_d3 = nn.BatchNorm2d(64)
|
450 |
+
self.relu_d3 = nn.ReLU(inplace=True)
|
451 |
+
|
452 |
+
self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
|
453 |
+
if config.use_bn:
|
454 |
+
self.bn_d2 = nn.BatchNorm2d(64)
|
455 |
+
self.relu_d2 = nn.ReLU(inplace=True)
|
456 |
+
|
457 |
+
self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
|
458 |
+
if config.use_bn:
|
459 |
+
self.bn_d1 = nn.BatchNorm2d(64)
|
460 |
+
self.relu_d1 = nn.ReLU(inplace=True)
|
461 |
+
|
462 |
+
self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)
|
463 |
+
|
464 |
+
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
465 |
+
if config.db_output_refiner:
|
466 |
+
self.db_output_refiner = DBHead(64)
|
467 |
+
|
468 |
+
|
469 |
+
def forward(self, x):
|
470 |
+
hx = x
|
471 |
+
hx = self.conv1(self.conv0(hx))
|
472 |
+
if config.use_bn:
|
473 |
+
hx = self.bn1(hx)
|
474 |
+
hx1 = self.relu1(hx)
|
475 |
+
hx = self.conv2(self.pool1(hx1))
|
476 |
+
if config.use_bn:
|
477 |
+
hx = self.bn2(hx)
|
478 |
+
hx2 = self.relu2(hx)
|
479 |
+
hx = self.conv3(self.pool2(hx2))
|
480 |
+
if config.use_bn:
|
481 |
+
hx = self.bn3(hx)
|
482 |
+
hx3 = self.relu3(hx)
|
483 |
+
hx = self.conv4(self.pool3(hx3))
|
484 |
+
if config.use_bn:
|
485 |
+
hx = self.bn4(hx)
|
486 |
+
hx4 = self.relu4(hx)
|
487 |
+
hx = self.conv5(self.pool4(hx4))
|
488 |
+
if config.use_bn:
|
489 |
+
hx = self.bn5(hx)
|
490 |
+
hx5 = self.relu5(hx)
|
491 |
+
hx = self.upscore2(hx5)
|
492 |
+
d4 = self.conv_d4(torch.cat((hx, hx4), 1))
|
493 |
+
if config.use_bn:
|
494 |
+
d4 = self.bn_d4(d4)
|
495 |
+
d4 = self.relu_d4(d4)
|
496 |
+
hx = self.upscore2(d4)
|
497 |
+
d3 = self.conv_d3(torch.cat((hx, hx3), 1))
|
498 |
+
if config.use_bn:
|
499 |
+
d3 = self.bn_d3(d3)
|
500 |
+
d3 = self.relu_d3(d3)
|
501 |
+
hx = self.upscore2(d3)
|
502 |
+
d2 = self.conv_d2(torch.cat((hx, hx2), 1))
|
503 |
+
if config.use_bn:
|
504 |
+
d2 = self.bn_d2(d2)
|
505 |
+
d2 = self.relu_d2(d2)
|
506 |
+
hx = self.upscore2(d2)
|
507 |
+
d1 = self.conv_d1(torch.cat((hx, hx1), 1))
|
508 |
+
if config.use_bn:
|
509 |
+
d1 = self.bn_d1(d1)
|
510 |
+
d1 = self.relu_d1(d1)
|
511 |
+
if config.db_output_refiner:
|
512 |
+
x = self.db_output_refiner(d1)
|
513 |
+
else:
|
514 |
+
residual = self.conv_d0(d1)
|
515 |
+
x = x + residual
|
516 |
+
return x
|