admin commited on
Commit
d613312
·
1 Parent(s): adac6eb

fix cuda err

Browse files
Files changed (1) hide show
  1. model.py +4 -7
model.py CHANGED
@@ -88,7 +88,7 @@ class EvalNet:
88
 
89
  def _create_classifier(self):
90
  original_T_size = self.ori_T
91
- upsample_module = nn.Sequential(
92
  nn.AdaptiveAvgPool2d((1, None)), # F -> 1
93
  nn.ConvTranspose2d(
94
  self.out_channel_before_classifier,
@@ -123,8 +123,6 @@ class EvalNet:
123
  nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
124
  )
125
 
126
- return upsample_module
127
-
128
  def _set_channel_outsize(self): #### get the output size before classifier ####
129
  conv2d_out_ch = []
130
  for name, module in self.model.named_modules():
@@ -245,7 +243,7 @@ class t_EvalNet:
245
  def _create_classifier(self):
246
  original_T_size = self.ori_T
247
  self.avgpool = nn.AdaptiveAvgPool2d((1, None)) # F -> 1
248
- upsample_module = nn.Sequential( # nn.AdaptiveAvgPool2d((1, None)), # F -> 1
249
  nn.ConvTranspose2d(
250
  self.hidden_dim, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
251
  ),
@@ -275,8 +273,6 @@ class t_EvalNet:
275
  nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
276
  )
277
 
278
- return upsample_module
279
-
280
  def _set_classifier(self): #### set custom classifier ####
281
  if self.type == "vit" or self.type == "swin_transformer":
282
  self.classifier = self._create_classifier()
@@ -287,10 +283,11 @@ class t_EvalNet:
287
  def forward(self, x: torch.Tensor):
288
  if torch.cuda.is_available():
289
  x = x.cuda()
 
290
 
291
  if self.type == "vit":
292
  x = self.model._process_input(x)
293
- batch_class_token = self.class_token.expand(x.size(0), -1, -1).cuda()
294
  x = torch.cat([batch_class_token, x], dim=1)
295
  x = self.model.encoder(x)
296
  x = x[:, 1:].permute(0, 2, 1)
 
88
 
89
  def _create_classifier(self):
90
  original_T_size = self.ori_T
91
+ return nn.Sequential(
92
  nn.AdaptiveAvgPool2d((1, None)), # F -> 1
93
  nn.ConvTranspose2d(
94
  self.out_channel_before_classifier,
 
123
  nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
124
  )
125
 
 
 
126
  def _set_channel_outsize(self): #### get the output size before classifier ####
127
  conv2d_out_ch = []
128
  for name, module in self.model.named_modules():
 
243
  def _create_classifier(self):
244
  original_T_size = self.ori_T
245
  self.avgpool = nn.AdaptiveAvgPool2d((1, None)) # F -> 1
246
+ return nn.Sequential( # nn.AdaptiveAvgPool2d((1, None)), # F -> 1
247
  nn.ConvTranspose2d(
248
  self.hidden_dim, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
249
  ),
 
273
  nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
274
  )
275
 
 
 
276
  def _set_classifier(self): #### set custom classifier ####
277
  if self.type == "vit" or self.type == "swin_transformer":
278
  self.classifier = self._create_classifier()
 
283
  def forward(self, x: torch.Tensor):
284
  if torch.cuda.is_available():
285
  x = x.cuda()
286
+ self.class_token = self.class_token.cuda()
287
 
288
  if self.type == "vit":
289
  x = self.model._process_input(x)
290
+ batch_class_token = self.class_token.expand(x.size(0), -1, -1)
291
  x = torch.cat([batch_class_token, x], dim=1)
292
  x = self.model.encoder(x)
293
  x = x[:, 1:].permute(0, 2, 1)