glenn-jocher commited on
Commit
a97c3f9
·
1 Parent(s): 2efa01d

update common.py Classify()

Browse files
Files changed (1) hide show
  1. models/common.py +2 -1
models/common.py CHANGED
@@ -112,4 +112,5 @@ class Classify(nn.Module):
112
  self.flat = Flatten()
113
 
114
  def forward(self, x):
115
- return self.flat(self.conv(self.aap(x))) # flatten to x(b,c2)
 
 
112
  self.flat = Flatten()
113
 
114
  def forward(self, x):
115
+ z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
116
+ return self.flat(self.conv(z)) # flatten to x(b,c2)