Spaces:
Build error
Build error
Commit
·
e1bfa3e
1
Parent(s):
eadd03c
Upload model_core.py
Browse files- model_core.py +153 -0
model_core.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from components.attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
|
6 |
+
from components.srm_conv import SRMConv2d_simple, SRMConv2d_Separate
|
7 |
+
from networks.xception import TransferModel
|
8 |
+
|
9 |
+
|
10 |
+
class SRMPixelAttention(nn.Module):
|
11 |
+
def __init__(self, in_channels):
|
12 |
+
super(SRMPixelAttention, self).__init__()
|
13 |
+
# self.srm = SRMConv2d_simple()
|
14 |
+
self.conv = nn.Sequential(
|
15 |
+
nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
|
16 |
+
nn.BatchNorm2d(32),
|
17 |
+
nn.ReLU(inplace=True),
|
18 |
+
nn.Conv2d(32, 64, 3, bias=False),
|
19 |
+
nn.BatchNorm2d(64),
|
20 |
+
nn.ReLU(inplace=True),
|
21 |
+
)
|
22 |
+
|
23 |
+
self.pa = SpatialAttention()
|
24 |
+
|
25 |
+
for m in self.modules():
|
26 |
+
if isinstance(m, nn.Conv2d):
|
27 |
+
nn.init.kaiming_normal_(m.weight, a=1)
|
28 |
+
if not m.bias is None:
|
29 |
+
nn.init.constant_(m.bias, 0)
|
30 |
+
|
31 |
+
def forward(self, x_srm):
|
32 |
+
# x_srm = self.srm(x)
|
33 |
+
fea = self.conv(x_srm)
|
34 |
+
att_map = self.pa(fea)
|
35 |
+
|
36 |
+
return att_map
|
37 |
+
|
38 |
+
|
39 |
+
class FeatureFusionModule(nn.Module):
|
40 |
+
def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs):
|
41 |
+
super(FeatureFusionModule, self).__init__()
|
42 |
+
self.convblk = nn.Sequential(
|
43 |
+
nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
|
44 |
+
nn.BatchNorm2d(out_chan),
|
45 |
+
nn.ReLU()
|
46 |
+
)
|
47 |
+
self.ca = ChannelAttention(out_chan, ratio=16)
|
48 |
+
self.init_weight()
|
49 |
+
|
50 |
+
def forward(self, x, y):
|
51 |
+
fuse_fea = self.convblk(torch.cat((x, y), dim=1))
|
52 |
+
fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
|
53 |
+
return fuse_fea
|
54 |
+
|
55 |
+
def init_weight(self):
|
56 |
+
for ly in self.children():
|
57 |
+
if isinstance(ly, nn.Conv2d):
|
58 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
59 |
+
if not ly.bias is None:
|
60 |
+
nn.init.constant_(ly.bias, 0)
|
61 |
+
|
62 |
+
|
63 |
+
class Two_Stream_Net(nn.Module):
|
64 |
+
def __init__(self):
|
65 |
+
super().__init__()
|
66 |
+
self.xception_rgb = TransferModel(
|
67 |
+
'xception', dropout=0.5, inc=3, return_fea=True)
|
68 |
+
self.xception_srm = TransferModel(
|
69 |
+
'xception', dropout=0.5, inc=3, return_fea=True)
|
70 |
+
|
71 |
+
self.srm_conv0 = SRMConv2d_simple(inc=3)
|
72 |
+
self.srm_conv1 = SRMConv2d_Separate(32, 32)
|
73 |
+
self.srm_conv2 = SRMConv2d_Separate(64, 64)
|
74 |
+
self.relu = nn.ReLU(inplace=True)
|
75 |
+
|
76 |
+
self.att_map = None
|
77 |
+
self.srm_sa = SRMPixelAttention(3)
|
78 |
+
self.srm_sa_post = nn.Sequential(
|
79 |
+
nn.BatchNorm2d(64),
|
80 |
+
nn.ReLU(inplace=True)
|
81 |
+
)
|
82 |
+
|
83 |
+
self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
|
84 |
+
self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)
|
85 |
+
|
86 |
+
self.fusion = FeatureFusionModule()
|
87 |
+
|
88 |
+
self.att_dic = {}
|
89 |
+
|
90 |
+
def features(self, x):
|
91 |
+
srm = self.srm_conv0(x)
|
92 |
+
|
93 |
+
x = self.xception_rgb.model.fea_part1_0(x)
|
94 |
+
y = self.xception_srm.model.fea_part1_0(srm) \
|
95 |
+
+ self.srm_conv1(x)
|
96 |
+
y = self.relu(y)
|
97 |
+
|
98 |
+
x = self.xception_rgb.model.fea_part1_1(x)
|
99 |
+
y = self.xception_srm.model.fea_part1_1(y) \
|
100 |
+
+ self.srm_conv2(x)
|
101 |
+
y = self.relu(y)
|
102 |
+
|
103 |
+
# srm guided spatial attention
|
104 |
+
self.att_map = self.srm_sa(srm)
|
105 |
+
x = x * self.att_map + x
|
106 |
+
x = self.srm_sa_post(x)
|
107 |
+
|
108 |
+
x = self.xception_rgb.model.fea_part2(x)
|
109 |
+
y = self.xception_srm.model.fea_part2(y)
|
110 |
+
|
111 |
+
x, y = self.dual_cma0(x, y)
|
112 |
+
|
113 |
+
|
114 |
+
x = self.xception_rgb.model.fea_part3(x)
|
115 |
+
y = self.xception_srm.model.fea_part3(y)
|
116 |
+
|
117 |
+
|
118 |
+
x, y = self.dual_cma1(x, y)
|
119 |
+
|
120 |
+
x = self.xception_rgb.model.fea_part4(x)
|
121 |
+
y = self.xception_srm.model.fea_part4(y)
|
122 |
+
|
123 |
+
x = self.xception_rgb.model.fea_part5(x)
|
124 |
+
y = self.xception_srm.model.fea_part5(y)
|
125 |
+
|
126 |
+
fea = self.fusion(x, y)
|
127 |
+
|
128 |
+
|
129 |
+
return fea
|
130 |
+
|
131 |
+
def classifier(self, fea):
|
132 |
+
out, fea = self.xception_rgb.classifier(fea)
|
133 |
+
return out, fea
|
134 |
+
|
135 |
+
def forward(self, x):
|
136 |
+
'''
|
137 |
+
x: original rgb
|
138 |
+
|
139 |
+
Return:
|
140 |
+
out: (B, 2) the output for loss computing
|
141 |
+
fea: (B, 1024) the flattened features before the last FC
|
142 |
+
att_map: srm spatial attention map
|
143 |
+
'''
|
144 |
+
out, fea = self.classifier(self.features(x))
|
145 |
+
|
146 |
+
return out, fea, self.att_map
|
147 |
+
|
148 |
+
if __name__ == '__main__':
|
149 |
+
model = Two_Stream_Net()
|
150 |
+
dummy = torch.rand((1,3,256,256))
|
151 |
+
out = model(dummy)
|
152 |
+
print(model)
|
153 |
+
|