Commit 2f0673d · Parent: b0df336
Upload attention.py

attention.py · ADDED · +252 -0
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Channel Attention and Spatial Attention from
Woo, S., Park, J., Lee, J.Y., & Kweon, I. CBAM: Convolutional Block Attention Module. ECCV 2018.
"""


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)
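

# --- Illustrative usage sketch (added note, not part of the original upload) ---
# ChannelAttention and SpatialAttention return sigmoid gates rather than refined
# features. In CBAM the gates are typically applied multiplicatively to the input
# feature map, channel attention first, then spatial attention. The tensor sizes
# and wiring below are assumptions chosen only for illustration.
def _cbam_usage_example():
    feats = torch.rand(2, 64, 32, 32)           # B x C x H x W feature map
    ca = ChannelAttention(in_planes=64, ratio=8)
    sa = SpatialAttention(kernel_size=7)
    feats = ca(feats) * feats                   # channel gate, shape B x C x 1 x 1
    feats = sa(feats) * feats                   # spatial gate, shape B x 1 x H x W
    return feats                                # refined features, B x C x H x W
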
"""
The following modules are modified based on https://github.com/heykeetae/Self-Attention-GAN
"""


class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim, out_dim=None, add=False, ratio=8):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.add = add
        if out_dim is None:
            out_dim = in_dim
        self.out_dim = out_dim
        # self.activation = activation

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x W x H)
        returns :
            out : gamma-scaled self-attention features (+ input feature when add=True)
            attention : B x N x N map (N = W*H); currently commented out of the return
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(
            m_batchsize, -1, width*height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(
            m_batchsize, -1, width*height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N
        proj_value = self.value_conv(x).view(
            m_batchsize, -1, width*height)  # B x out_dim x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_dim, width, height)

        if self.add:
            out = self.gamma*out + x
        else:
            out = self.gamma*out
        return out  # , attention
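

# --- Illustrative usage sketch (added note, not part of the original upload) ---
# Self_Attn maps a B x C x H x W feature map to a B x out_dim x H x W output.
# With add=True the attended features are blended back into the input through
# the learnable gamma (initialised to zero); with add=False only the gamma-scaled
# attention output is returned. The sizes below are assumptions for illustration.
def _self_attn_usage_example():
    feats = torch.rand(2, 64, 16, 16)           # B x C x H x W feature map
    attn = Self_Attn(in_dim=64, add=True, ratio=8)
    out = attn(feats)                           # B x 64 x 16 x 16
    return out
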
class CrossModalAttention(nn.Module):
    """ CMA attention Layer"""

    def __init__(self, in_dim, activation=None, ratio=8, cross_value=True):
        super(CrossModalAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.cross_value = cross_value

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
        inputs :
            x, y : input feature maps (B x C x W x H); queries come from x, keys from y
        returns :
            out : cross-modal attention value + input feature x
        """
        B, C, H, W = x.size()

        proj_query = self.query_conv(x).view(
            B, -1, H*W).permute(0, 2, 1)  # B x HW x C'
        proj_key = self.key_conv(y).view(
            B, -1, H*W)  # B x C' x HW
        energy = torch.bmm(proj_query, proj_key)  # B x HW x HW
        attention = self.softmax(energy)  # B x HW x HW
        if self.cross_value:
            proj_value = self.value_conv(y).view(
                B, -1, H*W)  # B x C x HW
        else:
            proj_value = self.value_conv(x).view(
                B, -1, H*W)  # B x C x HW

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        out = self.gamma*out + x

        if self.activation is not None:
            out = self.activation(out)

        return out  # , attention
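

# --- Illustrative usage sketch (added note, not part of the original upload) ---
# CrossModalAttention attends between two same-shaped feature maps: queries come
# from x, keys from y, and values from y when cross_value=True (from x otherwise);
# the result is added back onto x through the learnable gamma. The two inputs
# below stand in for features of two modalities and are assumptions only.
def _cross_modal_attention_usage_example():
    x = torch.rand(2, 64, 16, 16)               # e.g. features of modality A
    y = torch.rand(2, 64, 16, 16)               # e.g. features of modality B
    cma = CrossModalAttention(in_dim=64, ratio=8, cross_value=True)
    out = cma(x, y)                             # B x 64 x 16 x 16, residual on x
    return out
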
class DualCrossModalAttention(nn.Module):
    """ Dual CMA attention Layer"""

    def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False):
        super(DualCrossModalAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.ret_att = ret_att

        # key/query projections (key_conv_share is applied on top of both)
        self.key_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv_share = nn.Conv2d(
            in_channels=in_dim//ratio, out_channels=in_dim//ratio, kernel_size=1)

        self.linear1 = nn.Linear(size*size, size*size)
        self.linear2 = nn.Linear(size*size, size*size)

        # separated value convs
        self.value_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma1 = nn.Parameter(torch.zeros(1))

        self.value_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma2 = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
        inputs :
            x, y : input feature maps (B x C x W x H), one per modality
        returns :
            out_x, out_y : cross-attended features with residual connections
            att_y_on_x, att_x_on_y : B x N x N attention maps (N = W*H), returned when ret_att=True
        """
        B, C, H, W = x.size()

        def _get_att(a, b):
            proj_key1 = self.key_conv_share(self.key_conv1(a)).view(
                B, -1, H*W).permute(0, 2, 1)  # B x HW x C'
            proj_key2 = self.key_conv_share(self.key_conv2(b)).view(
                B, -1, H*W)  # B x C' x HW
            energy = torch.bmm(proj_key1, proj_key2)  # B x HW x HW

            attention1 = self.softmax(self.linear1(energy))
            attention2 = self.softmax(self.linear2(
                energy.permute(0, 2, 1)))  # B x HW x HW

            return attention1, attention2

        att_y_on_x, att_x_on_y = _get_att(x, y)
        proj_value_y_on_x = self.value_conv2(y).view(
            B, -1, H*W)  # B x C x HW
        out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1))
        out_y_on_x = out_y_on_x.view(B, C, H, W)
        out_x = self.gamma1*out_y_on_x + x

        proj_value_x_on_y = self.value_conv1(x).view(
            B, -1, H*W)  # B x C x HW
        out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1))
        out_x_on_y = out_x_on_y.view(B, C, H, W)
        out_y = self.gamma2*out_x_on_y + y

        if self.ret_att:
            return out_x, out_y, att_y_on_x, att_x_on_y

        return out_x, out_y  # , attention

if __name__ == "__main__":
    x = torch.rand(10, 768, 16, 16)
    y = torch.rand(10, 768, 16, 16)
    dcma = DualCrossModalAttention(768, ret_att=True)
    out_x, out_y, att_y_on_x, att_x_on_y = dcma(x, y)
    print(out_y.size())
    print(att_x_on_y.size())
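    # Shape check for the demo above (added note, not part of the original upload):
    # DualCrossModalAttention preserves the input shape, so out_y is
    # B x C x H x W = (10, 768, 16, 16), while att_x_on_y is the
    # B x (H*W) x (H*W) = (10, 256, 256) attention over the 16*16 positions.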