Concise `TransformerBlock()` (#3821)
Browse files- models/common.py +2 -12
models/common.py
CHANGED
@@ -77,18 +77,8 @@ class TransformerBlock(nn.Module):
|
|
77 |
if self.conv is not None:
|
78 |
x = self.conv(x)
|
79 |
b, _, w, h = x.shape
|
80 |
-
p = x.flatten(2)
|
81 |
-
p
|
82 |
-
p = p.transpose(0, 3)
|
83 |
-
p = p.squeeze(3)
|
84 |
-
e = self.linear(p)
|
85 |
-
x = p + e
|
86 |
-
|
87 |
-
x = self.tr(x)
|
88 |
-
x = x.unsqueeze(3)
|
89 |
-
x = x.transpose(0, 3)
|
90 |
-
x = x.reshape(b, self.c2, w, h)
|
91 |
-
return x
|
92 |
|
93 |
|
94 |
class Bottleneck(nn.Module):
|
|
|
77 |
if self.conv is not None:
|
78 |
x = self.conv(x)
|
79 |
b, _, w, h = x.shape
|
80 |
+
p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
|
81 |
+
return self.tr(p + self.linear(p)).unsqueeze(3).transpose(0, 3).reshape(b, self.c2, w, h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
|
84 |
class Bottleneck(nn.Module):
|