Tanzeer committed
Commit 8395863 · 1 Parent(s): 61cf4f0

Upload 15 files

TranSalNet_Res.py ADDED
@@ -0,0 +1,171 @@
+ import os
+ import torch
+ import numpy as np
+ import pandas as pd
+ from torch.utils.data import Dataset, DataLoader
+ from skimage import io, transform
+ from PIL import Image
+ import torch.nn as nn
+ from torchvision import transforms, utils, models
+ import torch.nn.functional as F
+ import utils.resnet as resnet
+
+ from utils.TransformerEncoder import Encoder
+
+
+ cfg1 = {
+     "hidden_size" : 768,
+     "mlp_dim" : 768*4,
+     "num_heads" : 12,
+     "num_layers" : 2,
+     "attention_dropout_rate" : 0,
+     "dropout_rate" : 0.0,
+ }
+
+ cfg2 = {
+     "hidden_size" : 768,
+     "mlp_dim" : 768*4,
+     "num_heads" : 12,
+     "num_layers" : 2,
+     "attention_dropout_rate" : 0,
+     "dropout_rate" : 0.0,
+ }
+
+ cfg3 = {
+     "hidden_size" : 512,
+     "mlp_dim" : 512*4,
+     "num_heads" : 8,
+     "num_layers" : 2,
+     "attention_dropout_rate" : 0,
+     "dropout_rate" : 0.0,
+ }
+
+
+ class TranSalNet(nn.Module):
+
+     def __init__(self):
+         super(TranSalNet, self).__init__()
+         self.encoder = _Encoder()
+         self.decoder = _Decoder()
+
+     def forward(self, x):
+         x = self.encoder(x)
+         x = self.decoder(x)
+         return x
+
+
+ class _Encoder(nn.Module):
+     def __init__(self):
+         super(_Encoder, self).__init__()
+         base_model = resnet.resnet50(pretrained=True)
+         base_layers = list(base_model.children())[:8]
+         self.encoder = nn.ModuleList(base_layers).eval()
+
+     def forward(self, x):
+         outputs = []
+         for ii, layer in enumerate(self.encoder):
+             x = layer(x)
+             if ii in {5, 6, 7}:  # keep the feature maps from ResNet layer2, layer3 and layer4
+                 outputs.append(x)
+         return outputs
+
+
+ class _Decoder(nn.Module):
+
+     def __init__(self):
+         super(_Decoder, self).__init__()
+         self.conv1 = nn.Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv2 = nn.Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv3 = nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv4 = nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv5 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv6 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv7 = nn.Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+         self.batchnorm1 = nn.BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+         self.batchnorm2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+         self.batchnorm3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+         self.batchnorm4 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+         self.batchnorm5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+         self.batchnorm6 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+
+         self.TransEncoder1 = TransEncoder(in_channels=2048, spatial_size=9*12, cfg=cfg1)
+         self.TransEncoder2 = TransEncoder(in_channels=1024, spatial_size=18*24, cfg=cfg2)
+         self.TransEncoder3 = TransEncoder(in_channels=512, spatial_size=36*48, cfg=cfg3)
+
+         self.add = torch.add
+         self.relu = nn.ReLU(True)
+         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x3, x4, x5 = x
+
+         x5 = self.TransEncoder1(x5)
+         x5 = self.conv1(x5)
+         x5 = self.batchnorm1(x5)
+         x5 = self.relu(x5)
+         x5 = self.upsample(x5)
+
+         x4_a = self.TransEncoder2(x4)
+         x4 = x5 * x4_a
+         x4 = self.relu(x4)
+         x4 = self.conv2(x4)
+         x4 = self.batchnorm2(x4)
+         x4 = self.relu(x4)
+         x4 = self.upsample(x4)
+
+         x3_a = self.TransEncoder3(x3)
+         x3 = x4 * x3_a
+         x3 = self.relu(x3)
+         x3 = self.conv3(x3)
+         x3 = self.batchnorm3(x3)
+         x3 = self.relu(x3)
+         x3 = self.upsample(x3)
+
+         x2 = self.conv4(x3)
+         x2 = self.batchnorm4(x2)
+         x2 = self.relu(x2)
+         x2 = self.upsample(x2)
+         x2 = self.conv5(x2)
+         x2 = self.batchnorm5(x2)
+         x2 = self.relu(x2)
+
+         x1 = self.upsample(x2)
+         x1 = self.conv6(x1)
+         x1 = self.batchnorm6(x1)
+         x1 = self.relu(x1)
+         x1 = self.conv7(x1)
+         x = self.sigmoid(x1)
+
+         return x
+
+
+ class TransEncoder(nn.Module):
+
+     def __init__(self, in_channels, spatial_size, cfg):
+         super(TransEncoder, self).__init__()
+
+         self.patch_embeddings = nn.Conv2d(in_channels=in_channels,
+                                           out_channels=cfg['hidden_size'],
+                                           kernel_size=1,
+                                           stride=1)
+         self.position_embeddings = nn.Parameter(torch.zeros(1, spatial_size, cfg['hidden_size']))
+
+         self.transformer_encoder = Encoder(cfg)
+
+     def forward(self, x):
+         a, b = x.shape[2], x.shape[3]
+         x = self.patch_embeddings(x)
+         x = x.flatten(2)
+         x = x.transpose(-1, -2)
+
+         embeddings = x + self.position_embeddings
+         x = self.transformer_encoder(embeddings)
+         B, n_patch, hidden = x.shape
+         x = x.permute(0, 2, 1)
+         x = x.contiguous().view(B, hidden, a, b)
+
+         return x
+
__pycache__/TranSalNet_Res.cpython-310.pyc ADDED
Binary file (4.72 kB).
 
app.py ADDED
@@ -0,0 +1,54 @@
+ import streamlit as st
+ import torch
+ import cv2
+ from PIL import Image
+ import numpy as np
+ from torchvision import transforms
+ from TranSalNet_Res import TranSalNet  # Make sure TranSalNet is accessible from your Streamlit app
+
+ # Load the model and set the device
+ model = TranSalNet()
+ model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
+ model.eval()  # Set the model to evaluation mode
+ device = torch.device('cpu')
+ model.to(device)
+
+ # Define Streamlit app
+ st.title('Saliency Detection App')
+ st.write('Upload an image for saliency detection:')
+ uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_image:
+     image = Image.open(uploaded_image)
+     st.image(image, caption='Uploaded Image', use_column_width=True)
+
+     # Check if the user clicks a button
+     if st.button('Detect Saliency'):
+         # Preprocess the image: 384x288 (width x height); convert ensures 3 channels for RGBA or grayscale uploads
+         img = image.convert('RGB').resize((384, 288))
+         img = np.array(img) / 255.
+         img = np.transpose(img, (2, 0, 1))
+         img = torch.from_numpy(img).unsqueeze(0).float()
+         img = img.to(device)
+
+         # Get saliency prediction
+         with torch.no_grad():
+             pred_saliency = model(img)
+
+         # Convert the result back to a PIL image
+         toPIL = transforms.ToPILImage()
+         pic = toPIL(pred_saliency.squeeze())
+
+         # Colorize the grayscale prediction
+         colorized_img = cv2.applyColorMap(np.uint8(pic), cv2.COLORMAP_JET)
+
+         # Ensure the colorized image has the same dimensions as the original image
+         original_img = np.array(image)
+         colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
+         colorized_img = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB)  # OpenCV returns BGR; Streamlit expects RGB
+
+         # You can add more post-processing here if needed
+
+         # Display the final result
+         st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
+
+         st.write('Finished!')
pretrained_models/.keep ADDED
@@ -0,0 +1 @@
+
pretrained_models/TranSalNet_Res.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3853e24e1e0bf892bdc321dcf269516a58450f9af2d3ca0b620272d6c81fe5c7
+ size 290451767
pretrained_models/resnet50-0676ba61.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0676ba61b6795bbe1773cffd859882e5e297624d384b6993f7c9e683e722fb8a
+ size 102530333
requirements.txt ADDED
Binary file (3.06 kB).
 
utils/TransformerEncoder.py ADDED
@@ -0,0 +1,137 @@
+ # coding=utf-8
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import copy
+ import logging
+ import math
+
+ from os.path import join as pjoin
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
+ from torch.nn.modules.utils import _pair
+ from scipy import ndimage
+
+
+ ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}
+
+
+ class Attention(nn.Module):
+     def __init__(self, config):
+         super(Attention, self).__init__()
+         self.num_attention_heads = config["num_heads"]  # e.g. 12
+         self.attention_head_size = int(config['hidden_size'] / self.num_attention_heads)  # e.g. 768 // 12 = 64
+         self.all_head_size = self.num_attention_heads * self.attention_head_size  # e.g. 12 * 64 = 768
+
+         self.query = Linear(config['hidden_size'], self.all_head_size)  # (hidden_size, all_head_size)
+         self.key = Linear(config['hidden_size'], self.all_head_size)
+         self.value = Linear(config['hidden_size'], self.all_head_size)
+
+         # self.out = Linear(config['hidden_size'], config['hidden_size'])
+         self.out = Linear(self.all_head_size, config['hidden_size'])
+         self.attn_dropout = Dropout(config["attention_dropout_rate"])
+         self.proj_dropout = Dropout(config["attention_dropout_rate"])
+
+         self.softmax = Softmax(dim=-1)
+
+     def transpose_for_scores(self, x):
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(*new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(self, hidden_states):
+
+         mixed_query_layer = self.query(hidden_states)
+         mixed_key_layer = self.key(hidden_states)
+         mixed_value_layer = self.value(hidden_states)
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+         key_layer = self.transpose_for_scores(mixed_key_layer)
+         value_layer = self.transpose_for_scores(mixed_value_layer)
+
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         attention_probs = self.softmax(attention_scores)
+         attention_probs = self.attn_dropout(attention_probs)
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_context_layer_shape)
+         attention_output = self.out(context_layer)
+         attention_output = self.proj_dropout(attention_output)
+         return attention_output
+
+
+ class Mlp(nn.Module):
+     def __init__(self, config):
+         super(Mlp, self).__init__()
+         self.fc1 = Linear(config['hidden_size'], config["mlp_dim"])
+         self.fc2 = Linear(config["mlp_dim"], config['hidden_size'])
+         self.act_fn = ACT2FN["gelu"]
+         self.dropout = Dropout(config["dropout_rate"])
+         self._init_weights()
+
+     def _init_weights(self):
+         nn.init.xavier_uniform_(self.fc1.weight)
+         nn.init.xavier_uniform_(self.fc2.weight)
+         nn.init.normal_(self.fc1.bias, std=1e-6)
+         nn.init.normal_(self.fc2.bias, std=1e-6)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act_fn(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super(Block, self).__init__()
+         self.flag = config['num_heads']
+         self.hidden_size = config['hidden_size']
+         self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6)
+         self.ffn = Mlp(config)
+         self.attn = Attention(config)
+         self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6)
+
+     def forward(self, x):
+         h = x
+
+         x = self.attention_norm(x)
+         x = self.attn(x)
+         x = x + h
+
+         h = x
+         x = self.ffn_norm(x)
+         x = self.ffn(x)
+         x = x + h
+         return x
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config):
+         super(Encoder, self).__init__()
+
+         self.layer = nn.ModuleList()
+         self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6)
+         for _ in range(config["num_layers"]):
+             layer = Block(config)
+             self.layer.append(copy.deepcopy(layer))
+
+     def forward(self, hidden_states):
+         for layer_block in self.layer:
+             hidden_states = layer_block(hidden_states)
+         encoded = self.encoder_norm(hidden_states)
+
+         return encoded
+
+
utils/__pycache__/TransformerEncoder.cpython-310.pyc ADDED
Binary file (4.54 kB).
 
utils/__pycache__/data_process.cpython-310.pyc ADDED
Binary file (3.05 kB).
 
utils/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (12.1 kB).
 
utils/data_process.py ADDED
@@ -0,0 +1,116 @@
+ import cv2
+ from PIL import Image
+ import numpy as np
+ import os
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+
+
+ def preprocess_img(img_dir, channels=3):
+
+     if channels == 1:
+         img = cv2.imread(img_dir, 0)
+     elif channels == 3:
+         img = cv2.imread(img_dir)
+
+     shape_r = 288
+     shape_c = 384
+     img_padded = np.ones((shape_r, shape_c, channels), dtype=np.uint8)
+     if channels == 1:
+         img_padded = np.zeros((shape_r, shape_c), dtype=np.uint8)
+     original_shape = img.shape
+     rows_rate = original_shape[0] / shape_r
+     cols_rate = original_shape[1] / shape_c
+     if rows_rate > cols_rate:
+         new_cols = (original_shape[1] * shape_r) // original_shape[0]
+         img = cv2.resize(img, (new_cols, shape_r))
+         if new_cols > shape_c:
+             new_cols = shape_c
+         img_padded[:,
+                    ((img_padded.shape[1] - new_cols) // 2):((img_padded.shape[1] - new_cols) // 2 + new_cols)] = img
+     else:
+         new_rows = (original_shape[0] * shape_c) // original_shape[1]
+         img = cv2.resize(img, (shape_c, new_rows))
+
+         if new_rows > shape_r:
+             new_rows = shape_r
+         img_padded[((img_padded.shape[0] - new_rows) // 2):((img_padded.shape[0] - new_rows) // 2 + new_rows),
+                    :] = img
+
+     return img_padded
+
+
+ def postprocess_img(pred, org_dir):
+     pred = np.array(pred)
+     org = cv2.imread(org_dir, 0)
+     shape_r = org.shape[0]
+     shape_c = org.shape[1]
+     predictions_shape = pred.shape
+
+     rows_rate = shape_r / predictions_shape[0]
+     cols_rate = shape_c / predictions_shape[1]
+
+     if rows_rate > cols_rate:
+         new_cols = (predictions_shape[1] * shape_r) // predictions_shape[0]
+         pred = cv2.resize(pred, (new_cols, shape_r))
+         img = pred[:, ((pred.shape[1] - shape_c) // 2):((pred.shape[1] - shape_c) // 2 + shape_c)]
+     else:
+         new_rows = (predictions_shape[0] * shape_c) // predictions_shape[1]
+         pred = cv2.resize(pred, (shape_c, new_rows))
+         img = pred[((pred.shape[0] - shape_r) // 2):((pred.shape[0] - shape_r) // 2 + shape_r), :]
+
+     return img
+
+
+ class MyDataset(Dataset):
+     """Load dataset."""
+
+     def __init__(self, ids, stimuli_dir, saliency_dir, fixation_dir, transform=None):
+         """
+         Args:
+             ids (DataFrame): Table whose columns hold the stimulus, saliency-map
+                 and fixation-map file names for each sample.
+             stimuli_dir (string): Directory with the stimulus images.
+             saliency_dir (string): Directory with the saliency maps.
+             fixation_dir (string): Directory with the fixation maps.
+             transform (callable, optional): Optional transform to be applied
+                 on a sample.
+         """
+         self.ids = ids
+         self.stimuli_dir = stimuli_dir
+         self.saliency_dir = saliency_dir
+         self.fixation_dir = fixation_dir
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         im_path = self.stimuli_dir + self.ids.iloc[idx, 0]
+         image = Image.open(im_path).convert('RGB')
+         img = np.array(image) / 255.
+         img = np.transpose(img, (2, 0, 1))
+         img = torch.from_numpy(img)
+         # if self.transform:
+         #     img = self.transform(image)
+
+         smap_path = self.saliency_dir + self.ids.iloc[idx, 1]
+         saliency = Image.open(smap_path)
+
+         smap = np.expand_dims(np.array(saliency) / 255., axis=0)
+         smap = torch.from_numpy(smap)
+
+         fmap_path = self.fixation_dir + self.ids.iloc[idx, 2]
+         fixation = Image.open(fmap_path)
+
+         fmap = np.expand_dims(np.array(fixation) / 255., axis=0)
+         fmap = torch.from_numpy(fmap)
+
+         sample = {'image': img, 'saliency': smap, 'fixation': fmap}
+
+         return sample
+
+
+
+
+
utils/densenet.py ADDED
@@ -0,0 +1,287 @@
+ import re
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as cp
+ from collections import OrderedDict
+ # from .utils import load_state_dict_from_url
+ from torch.hub import load_state_dict_from_url  # used when weights are fetched from model_urls instead of a local file
+ from torch import Tensor
+ from torch.jit.annotations import List
+
+
+ __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
+
+ model_urls = {
+     'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
+     'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
+     'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
+     'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
+ }
+
+
+ class _DenseLayer(nn.Module):
+     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
+         super(_DenseLayer, self).__init__()
+         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
+         self.add_module('relu1', nn.ReLU(inplace=True)),
+         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
+                                            growth_rate, kernel_size=1, stride=1,
+                                            bias=False)),
+         self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
+         self.add_module('relu2', nn.ReLU(inplace=True)),
+         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
+                                            kernel_size=3, stride=1, padding=1,
+                                            bias=False)),
+         self.drop_rate = float(drop_rate)
+         self.memory_efficient = memory_efficient
+
+     def bn_function(self, inputs):
+         # type: (List[Tensor]) -> Tensor
+         concated_features = torch.cat(inputs, 1)
+         bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
+         return bottleneck_output
+
+     # todo: rewrite when torchscript supports any
+     def any_requires_grad(self, input):
+         # type: (List[Tensor]) -> bool
+         for tensor in input:
+             if tensor.requires_grad:
+                 return True
+         return False
+
+     @torch.jit.unused  # noqa: T484
+     def call_checkpoint_bottleneck(self, input):
+         # type: (List[Tensor]) -> Tensor
+         def closure(*inputs):
+             return self.bn_function(inputs)
+
+         return cp.checkpoint(closure, *input)
+
+     @torch.jit._overload_method  # noqa: F811
+     def forward(self, input):
+         # type: (List[Tensor]) -> (Tensor)
+         pass
+
+     @torch.jit._overload_method  # noqa: F811
+     def forward(self, input):
+         # type: (Tensor) -> (Tensor)
+         pass
+
+     # torchscript does not yet support *args, so we overload method
+     # allowing it to take either a List[Tensor] or single Tensor
+     def forward(self, input):  # noqa: F811
+         if isinstance(input, Tensor):
+             prev_features = [input]
+         else:
+             prev_features = input
+
+         if self.memory_efficient and self.any_requires_grad(prev_features):
+             if torch.jit.is_scripting():
+                 raise Exception("Memory Efficient not supported in JIT")
+
+             bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+         else:
+             bottleneck_output = self.bn_function(prev_features)
+
+         new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
+         if self.drop_rate > 0:
+             new_features = F.dropout(new_features, p=self.drop_rate,
+                                      training=self.training)
+         return new_features
+
+
+ class _DenseBlock(nn.ModuleDict):
+     _version = 2
+
+     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
+         super(_DenseBlock, self).__init__()
+         for i in range(num_layers):
+             layer = _DenseLayer(
+                 num_input_features + i * growth_rate,
+                 growth_rate=growth_rate,
+                 bn_size=bn_size,
+                 drop_rate=drop_rate,
+                 memory_efficient=memory_efficient,
+             )
+             self.add_module('denselayer%d' % (i + 1), layer)
+
+     def forward(self, init_features):
+         features = [init_features]
+         for name, layer in self.items():
+             new_features = layer(features)
+             features.append(new_features)
+         return torch.cat(features, 1)
+
+
+ class _Transition(nn.Sequential):
+     def __init__(self, num_input_features, num_output_features):
+         super(_Transition, self).__init__()
+         self.add_module('norm', nn.BatchNorm2d(num_input_features))
+         self.add_module('relu', nn.ReLU(inplace=True))
+         self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
+                                           kernel_size=1, stride=1, bias=False))
+         self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+
+
+ class DenseNet(nn.Module):
+     r"""Densenet-BC model class, based on
+     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+     Args:
+         growth_rate (int) - how many filters to add each layer (`k` in paper)
+         block_config (list of 4 ints) - how many layers in each pooling block
+         num_init_features (int) - the number of filters to learn in the first convolution layer
+         bn_size (int) - multiplicative factor for number of bottle neck layers
+           (i.e. bn_size * k features in the bottleneck layer)
+         drop_rate (float) - dropout rate after each dense layer
+         num_classes (int) - number of classification classes
+         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
+     """
+
+     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
+                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
+
+         super(DenseNet, self).__init__()
+
+         # First convolution
+         self.features = nn.Sequential(OrderedDict([
+             ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
+                                 padding=3, bias=False)),
+             ('norm0', nn.BatchNorm2d(num_init_features)),
+             ('relu0', nn.ReLU(inplace=True)),
+             ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
+         ]))
+
+         # Each denseblock
+         num_features = num_init_features
+         for i, num_layers in enumerate(block_config):
+             block = _DenseBlock(
+                 num_layers=num_layers,
+                 num_input_features=num_features,
+                 bn_size=bn_size,
+                 growth_rate=growth_rate,
+                 drop_rate=drop_rate,
+                 memory_efficient=memory_efficient
+             )
+             self.features.add_module('denseblock%d' % (i + 1), block)
+             num_features = num_features + num_layers * growth_rate
+             if i != len(block_config) - 1:
+                 trans = _Transition(num_input_features=num_features,
+                                     num_output_features=num_features // 2)
+                 self.features.add_module('transition%d' % (i + 1), trans)
+                 num_features = num_features // 2
+
+         # Final batch norm
+         self.features.add_module('norm5', nn.BatchNorm2d(num_features))
+
+         # Linear layer
+         self.classifier = nn.Linear(num_features, num_classes)
+
+         # Official init from torch repo.
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         features = self.features(x)
+         out = F.relu(features, inplace=True)
+         out = F.adaptive_avg_pool2d(out, (1, 1))
+         out = torch.flatten(out, 1)
+         out = self.classifier(out)
+         return out
+
+
+ def _load_state_dict(model, model_url, progress, flag):
+     # '.'s are no longer allowed in module names, but previous _DenseLayer
+     # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+     # They are also in the checkpoints in model_urls. This pattern is used
+     # to find such keys.
+     pattern = re.compile(
+         r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+     if flag == "densenet161":
+         state_dict = torch.load(r'pretrained_models/densenet161-8d451a50.pth')
+     else:
+         state_dict = load_state_dict_from_url(model_url, progress=progress)
+     for key in list(state_dict.keys()):
+         res = pattern.match(key)
+         if res:
+             new_key = res.group(1) + res.group(2)
+             state_dict[new_key] = state_dict[key]
+             del state_dict[key]
+     model.load_state_dict(state_dict)
+
+
+ def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
+               **kwargs):
+     model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
+     if pretrained:
+         if arch == 'densenet161':
+             _load_state_dict(model, model_urls[arch], progress, 'densenet161')
+         else:
+             _load_state_dict(model, model_urls[arch], progress, 0)
+     return model
+
+
+ def densenet121(pretrained=False, progress=True, **kwargs):
+     r"""Densenet-121 model from
+     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
+     """
+     return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
+                      **kwargs)
+
+
+
+ def densenet161(pretrained=False, progress=True, **kwargs):
+     r"""Densenet-161 model from
+     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
+     """
+     return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
+                      **kwargs)
+
+
+
+ def densenet169(pretrained=False, progress=True, **kwargs):
+     r"""Densenet-169 model from
+     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
+     """
+     return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
+                      **kwargs)
+
+
+
+ def densenet201(pretrained=False, progress=True, **kwargs):
+     r"""Densenet-201 model from
+     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
+     """
+     return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
+                      **kwargs)
utils/loss_function.py ADDED
@@ -0,0 +1,69 @@
+ import torch as t
+ import torch.nn as nn
+ #import numpy as np
+
+
+ class SaliencyLoss(nn.Module):
+     def __init__(self):
+         super(SaliencyLoss, self).__init__()
+
+     def forward(self, preds, labels, loss_type='cc'):
+         losses = []
+         if loss_type == 'cc':
+             for i in range(labels.shape[0]):  # labels.shape[0] is batch size
+                 loss = loss_CC(preds[i], labels[i])
+                 losses.append(loss)
+
+         elif loss_type == 'kldiv':
+             for i in range(labels.shape[0]):
+                 loss = loss_KLdiv(preds[i], labels[i])
+                 losses.append(loss)
+
+         elif loss_type == 'sim':
+             for i in range(labels.shape[0]):
+                 loss = loss_similarity(preds[i], labels[i])
+                 losses.append(loss)
+
+         elif loss_type == 'nss':
+             for i in range(labels.shape[0]):
+                 loss = loss_NSS(preds[i], labels[i])
+                 losses.append(loss)
+
+         return t.stack(losses).mean(dim=0, keepdim=True)
+
+
+ def loss_KLdiv(pred_map, gt_map):
+     eps = 2.2204e-16
+     pred_map = pred_map/t.sum(pred_map)
+     gt_map = gt_map/t.sum(gt_map)
+     div = t.sum(t.mul(gt_map, t.log(eps + t.div(gt_map, pred_map+eps))))
+     return div
+
+
+ def loss_CC(pred_map, gt_map):
+     gt_map_ = (gt_map - t.mean(gt_map))
+     pred_map_ = (pred_map - t.mean(pred_map))
+     cc = t.sum(t.mul(gt_map_, pred_map_))/t.sqrt(t.sum(t.mul(gt_map_, gt_map_))*t.sum(t.mul(pred_map_, pred_map_)))
+     return cc
+
+
+ def loss_similarity(pred_map, gt_map):
+     gt_map = (gt_map - t.min(gt_map))/(t.max(gt_map)-t.min(gt_map))
+     gt_map = gt_map/t.sum(gt_map)
+
+     pred_map = (pred_map - t.min(pred_map))/(t.max(pred_map)-t.min(pred_map))
+     pred_map = pred_map/t.sum(pred_map)
+
+     diff = t.min(gt_map, pred_map)
+     score = t.sum(diff)
+
+     return score
+
+
+ def loss_NSS(pred_map, fix_map):
+     '''ground truth here is fixation map'''
+
+     pred_map_ = (pred_map - t.mean(pred_map))/t.std(pred_map)
+     mask = fix_map.gt(0)
+     score = t.mean(t.masked_select(pred_map_, mask))
+     return score
utils/resnet.py ADDED
@@ -0,0 +1,419 @@
+ from typing import Type, Any, Callable, Union, List, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ # from .._internally_replaced_utils import load_state_dict_from_url
+ # from ..utils import _log_api_usage_once
+ from torch.hub import load_state_dict_from_url  # used when weights are fetched from model_urls instead of a local file
+
+
+ __all__ = [
+     "ResNet",
+     "resnet18",
+     "resnet34",
+     "resnet50",
+     "resnet101",
+     "resnet152",
+     "resnext50_32x4d",
+     "resnext101_32x8d",
+     "wide_resnet50_2",
+     "wide_resnet101_2",
+ ]
+
+
+ model_urls = {
+     "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
+     "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
+     "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
+     "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
+     "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
+     "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+     "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+     "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+     "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
+ }
+
+
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+     """3x3 convolution with padding"""
+     return nn.Conv2d(
+         in_planes,
+         out_planes,
+         kernel_size=3,
+         stride=stride,
+         padding=dilation,
+         groups=groups,
+         bias=False,
+         dilation=dilation,
+     )
+
+
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion: int = 1
+
+     def __init__(
+         self,
+         inplanes: int,
+         planes: int,
+         stride: int = 1,
+         downsample: Optional[nn.Module] = None,
+         groups: int = 1,
+         base_width: int = 64,
+         dilation: int = 1,
+         norm_layer: Optional[Callable[..., nn.Module]] = None,
+     ) -> None:
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         if groups != 1 or base_width != 64:
+             raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+         if dilation > 1:
+             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = norm_layer(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = norm_layer(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x: Tensor) -> Tensor:
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+     # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+     # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
+     # This variant is also known as ResNet V1.5 and improves accuracy according to
+     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+     expansion: int = 4
+
+     def __init__(
+         self,
+         inplanes: int,
+         planes: int,
+         stride: int = 1,
+         downsample: Optional[nn.Module] = None,
+         groups: int = 1,
+         base_width: int = 64,
+         dilation: int = 1,
+         norm_layer: Optional[Callable[..., nn.Module]] = None,
+     ) -> None:
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         width = int(planes * (base_width / 64.0)) * groups
+         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv1x1(inplanes, width)
+         self.bn1 = norm_layer(width)
+         self.conv2 = conv3x3(width, width, stride, groups, dilation)
+         self.bn2 = norm_layer(width)
+         self.conv3 = conv1x1(width, planes * self.expansion)
+         self.bn3 = norm_layer(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x: Tensor) -> Tensor:
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(
+         self,
+         block: Type[Union[BasicBlock, Bottleneck]],
+         layers: List[int],
+         num_classes: int = 1000,
+         zero_init_residual: bool = False,
+         groups: int = 1,
+         width_per_group: int = 64,
+         replace_stride_with_dilation: Optional[List[bool]] = None,
+         norm_layer: Optional[Callable[..., nn.Module]] = None,
+     ) -> None:
+         super().__init__()
+         # _log_api_usage_once(self)
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         self._norm_layer = norm_layer
+
+         self.inplanes = 64
+         self.dilation = 1
+         if replace_stride_with_dilation is None:
+             # each element in the tuple indicates if we should replace
+             # the 2x2 stride with a dilated convolution instead
+             replace_stride_with_dilation = [False, False, False]
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError(
+                 "replace_stride_with_dilation should be None "
+                 f"or a 3-element tuple, got {replace_stride_with_dilation}"
+             )
+         self.groups = groups
+         self.base_width = width_per_group
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = norm_layer(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # Zero-initialize the last BN in each residual branch,
+         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, Bottleneck):
+                     nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
+                 elif isinstance(m, BasicBlock):
+                     nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+     def _make_layer(
+         self,
+         block: Type[Union[BasicBlock, Bottleneck]],
+         planes: int,
+         blocks: int,
+         stride: int = 1,
+         dilate: bool = False,
+     ) -> nn.Sequential:
+         norm_layer = self._norm_layer
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 norm_layer(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(
+             block(
+                 self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+             )
+         )
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(
+                 block(
+                     self.inplanes,
+                     planes,
+                     groups=self.groups,
+                     base_width=self.base_width,
+                     dilation=self.dilation,
+                     norm_layer=norm_layer,
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def _forward_impl(self, x: Tensor) -> Tensor:
+         # See note [TorchScript super()]
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         x = self.fc(x)
+
+         return x
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self._forward_impl(x)
+
+
+ def _resnet(
+     arch: str,
+     block: Type[Union[BasicBlock, Bottleneck]],
+     layers: List[int],
+     pretrained: bool,
+     progress: bool,
+     **kwargs: Any,
+ ) -> ResNet:
+     model = ResNet(block, layers, **kwargs)
+     if pretrained:
+         if arch == 'resnet50':
+             state_dict = torch.load(r'pretrained_models/resnet50-0676ba61.pth')
+         else:
+             state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+         model.load_state_dict(state_dict)
+     return model
+
+
+ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNet-18 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+
+
+ def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNet-34 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNet-50 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNet-101 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
+
+
+ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNet-152 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
+
+
+ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNeXt-50 32x4d model from
+     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs["groups"] = 32
+     kwargs["width_per_group"] = 4
+     return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""ResNeXt-101 32x8d model from
+     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs["groups"] = 32
+     kwargs["width_per_group"] = 8
+     return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
+
+
+
+ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""Wide ResNet-50-2 model from
+     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
+
+     The model is the same as ResNet except for the bottleneck number of channels
+     which is twice larger in every block. The number of channels in outer 1x1
+     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs["width_per_group"] = 64 * 2
+     return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+
+ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+     r"""Wide ResNet-101-2 model from
+     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
+
+     The model is the same as ResNet except for the bottleneck number of channels
+     which is twice larger in every block. The number of channels in outer 1x1
+     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs["width_per_group"] = 64 * 2
+     return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)