Commit
·
a97c3f9
1
Parent(s):
2efa01d
update common.py Classify()
Browse files- models/common.py +2 -1
models/common.py
CHANGED
@@ -112,4 +112,5 @@ class Classify(nn.Module):
|
|
112 |
self.flat = Flatten()
|
113 |
|
114 |
def forward(self, x):
|
115 |
-
|
|
|
|
112 |
self.flat = Flatten()
|
113 |
|
114 |
def forward(self, x):
|
115 |
+
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
|
116 |
+
return self.flat(self.conv(z)) # flatten to x(b,c2)
|