asdasdasdasd committed on
Commit e1bfa3e · 1 Parent(s): eadd03c

Upload model_core.py

Files changed (1)
  1. model_core.py +153 -0
model_core.py ADDED
@@ -0,0 +1,153 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from components.attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
+from components.srm_conv import SRMConv2d_simple, SRMConv2d_Separate
+from networks.xception import TransferModel
+
+
+class SRMPixelAttention(nn.Module):
+    """Predicts a spatial attention map from SRM noise residuals."""
+
+    def __init__(self, in_channels):
+        super(SRMPixelAttention, self).__init__()
+        # self.srm = SRMConv2d_simple()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 64, 3, bias=False),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+        )
+
+        self.pa = SpatialAttention()
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, a=1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x_srm):
+        # x_srm = self.srm(x)
+        fea = self.conv(x_srm)
+        att_map = self.pa(fea)
+
+        return att_map
+
+
+class FeatureFusionModule(nn.Module):
+    """Fuses the RGB and SRM streams with channel attention."""
+
+    def __init__(self, in_chan=2048 * 2, out_chan=2048, *args, **kwargs):
+        super(FeatureFusionModule, self).__init__()
+        self.convblk = nn.Sequential(
+            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(out_chan),
+            nn.ReLU()
+        )
+        self.ca = ChannelAttention(out_chan, ratio=16)
+        self.init_weight()
+
+    def forward(self, x, y):
+        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
+        # residual re-weighting by channel attention
+        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
+        return fuse_fea
+
+    def init_weight(self):
+        for ly in self.children():
+            if isinstance(ly, nn.Conv2d):
+                nn.init.kaiming_normal_(ly.weight, a=1)
+                if ly.bias is not None:
+                    nn.init.constant_(ly.bias, 0)
+
+
+class Two_Stream_Net(nn.Module):
+    """Two-stream Xception: an RGB stream and an SRM noise stream."""
+
+    def __init__(self):
+        super().__init__()
+        self.xception_rgb = TransferModel(
+            'xception', dropout=0.5, inc=3, return_fea=True)
+        self.xception_srm = TransferModel(
+            'xception', dropout=0.5, inc=3, return_fea=True)
+
+        self.srm_conv0 = SRMConv2d_simple(inc=3)
+        self.srm_conv1 = SRMConv2d_Separate(32, 32)
+        self.srm_conv2 = SRMConv2d_Separate(64, 64)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.att_map = None
+        self.srm_sa = SRMPixelAttention(3)
+        self.srm_sa_post = nn.Sequential(
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
+        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)
+
+        self.fusion = FeatureFusionModule()
+
+        self.att_dic = {}
+
+    def features(self, x):
+        srm = self.srm_conv0(x)
+
+        x = self.xception_rgb.model.fea_part1_0(x)
+        y = self.xception_srm.model.fea_part1_0(srm) \
+            + self.srm_conv1(x)
+        y = self.relu(y)
+
+        x = self.xception_rgb.model.fea_part1_1(x)
+        y = self.xception_srm.model.fea_part1_1(y) \
+            + self.srm_conv2(x)
+        y = self.relu(y)
+
+        # SRM-guided spatial attention on the RGB stream
+        self.att_map = self.srm_sa(srm)
+        x = x * self.att_map + x
+        x = self.srm_sa_post(x)
+
+        x = self.xception_rgb.model.fea_part2(x)
+        y = self.xception_srm.model.fea_part2(y)
+
+        x, y = self.dual_cma0(x, y)
+
+        x = self.xception_rgb.model.fea_part3(x)
+        y = self.xception_srm.model.fea_part3(y)
+
+        x, y = self.dual_cma1(x, y)
+
+        x = self.xception_rgb.model.fea_part4(x)
+        y = self.xception_srm.model.fea_part4(y)
+
+        x = self.xception_rgb.model.fea_part5(x)
+        y = self.xception_srm.model.fea_part5(y)
+
+        fea = self.fusion(x, y)
+
+        return fea
+
+    def classifier(self, fea):
+        out, fea = self.xception_rgb.classifier(fea)
+        return out, fea
+
+    def forward(self, x):
+        '''
+        x: original RGB image
+
+        Return:
+            out: (B, 2) the output for loss computing
+            fea: (B, 1024) the flattened features before the last FC
+            att_map: SRM spatial attention map
+        '''
+        out, fea = self.classifier(self.features(x))
+
+        return out, fea, self.att_map
+
+
+if __name__ == '__main__':
+    model = Two_Stream_Net()
+    dummy = torch.rand((1, 3, 256, 256))
+    out = model(dummy)
+    print(model)
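As a quick sanity check (not part of the commit itself), the uploaded module can be exercised roughly as follows. This is a minimal sketch, assuming the repository's components/ and networks/ packages are on the import path and that TransferModel can load its Xception weights; the expected output shapes come from the docstring of Two_Stream_Net.forward.

import torch
from model_core import Two_Stream_Net

# Hypothetical smoke test: a batch of two 256x256 RGB crops.
model = Two_Stream_Net().eval()
x = torch.rand(2, 3, 256, 256)

with torch.no_grad():
    out, fea, att_map = model(x)

print(out.shape)      # expected (2, 2): real/fake logits
print(fea.shape)      # expected (2, 1024): features before the last FC
print(att_map.shape)  # SRM-guided spatial attention map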