admin commited on
Commit
f6221dd
·
1 Parent(s): d08ce40

add trust_remote_code=True,

Browse files
Files changed (2) hide show
  1. model.py +4 -4
  2. t_model.py +3 -4
model.py CHANGED
@@ -62,6 +62,7 @@ class EvalNet:
62
  if torch.cuda.is_available():
63
  self.model = self.model.cuda()
64
  self.classifier = self.classifier.cuda()
 
65
  self.model.eval()
66
 
67
  def _get_backbone(self, backbone_ver, backbone_list):
@@ -76,6 +77,7 @@ class EvalNet:
76
  "monetjoe/cv_backbones",
77
  split=self.imgnet_ver,
78
  cache_dir="./__pycache__",
 
79
  )
80
  backbone_info = self._get_backbone(backbone, backbone_list)
81
  return (
@@ -171,13 +173,11 @@ class EvalNet:
171
 
172
  if self.type == "convnext":
173
  out = self.model(x)
174
- out = self.classifier(out).squeeze()
175
- return out
176
 
177
  else:
178
  out = self.model(x)
179
  out = out.view(
180
  out.size(0), self.out_channel_before_classifier, self.H, self.H
181
  )
182
- out = self.classifier(out).squeeze()
183
- return out
 
62
  if torch.cuda.is_available():
63
  self.model = self.model.cuda()
64
  self.classifier = self.classifier.cuda()
65
+
66
  self.model.eval()
67
 
68
  def _get_backbone(self, backbone_ver, backbone_list):
 
77
  "monetjoe/cv_backbones",
78
  split=self.imgnet_ver,
79
  cache_dir="./__pycache__",
80
+ trust_remote_code=True,
81
  )
82
  backbone_info = self._get_backbone(backbone, backbone_list)
83
  return (
 
173
 
174
  if self.type == "convnext":
175
  out = self.model(x)
176
+ return self.classifier(out).squeeze()
 
177
 
178
  else:
179
  out = self.model(x)
180
  out = out.view(
181
  out.size(0), self.out_channel_before_classifier, self.H, self.H
182
  )
183
+ return self.classifier(out).squeeze()
 
t_model.py CHANGED
@@ -79,6 +79,7 @@ class t_EvalNet:
79
  "monetjoe/cv_backbones",
80
  split=self.imgnet_ver,
81
  cache_dir="./__pycache__",
 
82
  )
83
  backbone_info = self._get_backbone(backbone, backbone_list)
84
  return (
@@ -140,14 +141,12 @@ class t_EvalNet:
140
  x = self.model.encoder(x)
141
  x = x[:, 1:].permute(0, 2, 1)
142
  x = x.unsqueeze(2)
143
- x = self.classifier(x).squeeze() # x shape: [bsz, hidden_dim, 1, seq_len]
144
- return x
145
 
146
  elif self.type == "swin_transformer":
147
  x = self.model.features(x) # [B, H, W, C]
148
  x = x.permute(0, 3, 1, 2)
149
  x = self.avgpool(x) # [B, C, 1, W]
150
- x = self.classifier(x).squeeze()
151
- return x
152
 
153
  return None
 
79
  "monetjoe/cv_backbones",
80
  split=self.imgnet_ver,
81
  cache_dir="./__pycache__",
82
+ trust_remote_code=True,
83
  )
84
  backbone_info = self._get_backbone(backbone, backbone_list)
85
  return (
 
141
  x = self.model.encoder(x)
142
  x = x[:, 1:].permute(0, 2, 1)
143
  x = x.unsqueeze(2)
144
+ return self.classifier(x).squeeze()
 
145
 
146
  elif self.type == "swin_transformer":
147
  x = self.model.features(x) # [B, H, W, C]
148
  x = x.permute(0, 3, 1, 2)
149
  x = self.avgpool(x) # [B, C, 1, W]
150
+ return self.classifier(x).squeeze()
 
151
 
152
  return None