Commit
·
5387d47
1
Parent(s):
37acbdc
update common.py add Classify()
Browse files- models/common.py +19 -6
models/common.py
CHANGED
@@ -76,12 +76,6 @@ class SPP(nn.Module):
|
|
76 |
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
77 |
|
78 |
|
79 |
-
class Flatten(nn.Module):
|
80 |
-
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
|
81 |
-
def forward(self, x):
|
82 |
-
return x.view(x.size(0), -1)
|
83 |
-
|
84 |
-
|
85 |
class Focus(nn.Module):
|
86 |
# Focus wh information into c-space
|
87 |
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
@@ -100,3 +94,22 @@ class Concat(nn.Module):
|
|
100 |
|
101 |
def forward(self, x):
|
102 |
return torch.cat(x, self.d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
class Focus(nn.Module):
|
80 |
# Focus wh information into c-space
|
81 |
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
|
|
94 |
|
95 |
def forward(self, x):
|
96 |
return torch.cat(x, self.d)
|
97 |
+
|
98 |
+
|
99 |
+
class Flatten(nn.Module):
|
100 |
+
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
|
101 |
+
@staticmethod
|
102 |
+
def forward(x):
|
103 |
+
return x.view(x.size(0), -1)
|
104 |
+
|
105 |
+
|
106 |
+
class Classify(nn.Module):
|
107 |
+
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
108 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
109 |
+
super(Classify, self).__init__()
|
110 |
+
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
|
111 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1)
|
112 |
+
self.flat = Flatten()
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2)
|