admin committed on
Commit
adac6eb
·
1 Parent(s): f6221dd
Files changed (3) hide show
  1. app.py +1 -2
  2. model.py +125 -2
  3. t_model.py +0 -152
app.py CHANGED
@@ -6,8 +6,7 @@ import numpy as np
6
  import pandas as pd
7
  import gradio as gr
8
  import librosa.display
9
- from model import EvalNet
10
- from t_model import t_EvalNet
11
  from utils import get_modelist, find_files, embed, MODEL_DIR
12
 
13
 
 
6
  import pandas as pd
7
  import gradio as gr
8
  import librosa.display
9
+ from model import EvalNet, t_EvalNet
 
10
  from utils import get_modelist, find_files, embed, MODEL_DIR
11
 
12
 
model.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- import torch.nn.functional as F
3
  import torch.nn as nn
4
- import numpy as np
5
  import torchvision.models as models
 
6
  from modelscope.msdatasets import MsDataset
7
 
8
 
@@ -181,3 +181,126 @@ class EvalNet:
181
  out.size(0), self.out_channel_before_classifier, self.H, self.H
182
  )
183
  return self.classifier(out).squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
  import torchvision.models as models
5
+ import numpy as np
6
  from modelscope.msdatasets import MsDataset
7
 
8
 
 
181
  out.size(0), self.out_channel_before_classifier, self.H, self.H
182
  )
183
  return self.classifier(out).squeeze()
184
+
185
+
186
class t_EvalNet:
    """Inference wrapper pairing a torchvision backbone with an upsampling head.

    The backbone is looked up by name in ``torchvision.models``; its family
    ("vit" or "swin_transformer") is read from the ``monetjoe/cv_backbones``
    dataset.  A convolutional head upsamples the backbone features along the
    last axis, resizes them to ``ori_T`` and maps them to ``cls_num`` channels.
    Backbone and head weights are restored from the checkpoint at
    ``weight_path``.

    NOTE(review): only ``self.model`` is switched to eval mode; the classifier
    head (which contains ``BatchNorm2d`` layers) is left in the default module
    mode.  Confirm whether that is intended before changing it, since it
    affects the predictions.
    """

    def __init__(
        self,
        backbone: str,
        cls_num: int,
        ori_T: int,
        imgnet_ver="v1",
        weight_path="",
    ):
        """Build the backbone and head, then load the checkpoint.

        Args:
            backbone: Name of a constructor in ``torchvision.models``.
            cls_num: Number of output channels of the classifier head.
            ori_T: Size of the last output axis after interpolation.
            imgnet_ver: Split of the backbone-metadata dataset to query.
            weight_path: Path of a checkpoint holding ``"model"`` and
                ``"classifier"`` state dicts.

        Raises:
            ValueError: If ``backbone`` is not in ``torchvision.models``, is
                not listed in the metadata dataset, or has an unsupported type.
        """
        if not hasattr(models, backbone):
            raise ValueError(f"Unsupported model {backbone}.")

        self.imgnet_ver = imgnet_ver
        self.type, self.weight_url, self.input_size = self._model_info(backbone)
        # getattr is equivalent to the former eval("models.<name>()") once the
        # hasattr check above has passed, without evaluating a string as code.
        self.model: torch.nn.Module = getattr(models, backbone)()
        self.ori_T = ori_T
        if self.type == "vit":
            self.hidden_dim = self.model.hidden_dim
            # NOTE(review): this token is created here as zeros and is not
            # restored by either load_state_dict call below - confirm that
            # matches how the checkpoint was trained.
            self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))

        elif self.type == "swin_transformer":
            # NOTE(review): hard-coded feature width; presumably matches the
            # Swin variants used with this class - verify for other variants.
            self.hidden_dim = 768

        self.cls_num = cls_num
        self._set_classifier()
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(
            weight_path,
            map_location=None if torch.cuda.is_available() else "cpu",
        )
        # strict=False: tolerate missing / unexpected keys in the checkpoint.
        self.model.load_state_dict(checkpoint["model"], False)
        self.classifier.load_state_dict(checkpoint["classifier"], False)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()

        self.model.eval()

    def _get_backbone(self, backbone_ver, backbone_list):
        """Return the first entry of ``backbone_list`` whose "ver" matches.

        Raises:
            ValueError: If no entry has ``entry["ver"] == backbone_ver``.
        """
        for backbone_info in backbone_list:
            if backbone_ver == backbone_info["ver"]:
                return backbone_info

        raise ValueError("[Backbone not found] Please check if --model is correct!")

    def _model_info(self, backbone: str):
        """Look up ``backbone`` in the metadata dataset.

        Returns:
            A ``(type, url, input_size)`` tuple of ``(str, str, int)``.
        """
        backbone_list = MsDataset.load(
            "monetjoe/cv_backbones",
            split=self.imgnet_ver,
            cache_dir="./__pycache__",
            trust_remote_code=True,
        )
        backbone_info = self._get_backbone(backbone, backbone_list)
        return (
            str(backbone_info["type"]),
            str(backbone_info["url"]),
            int(backbone_info["input_size"]),
        )

    def _create_classifier(self):
        """Create the upsampling head and, as a side effect, ``self.avgpool``.

        The head is four stride-2 transposed convolutions along the last axis
        (``hidden_dim`` -> 256 -> 128 -> 64 -> 32 channels), a bilinear resize
        to ``(1, ori_T)``, and two 1x1 convolutions ending in ``cls_num``
        channels.
        """
        original_T_size = self.ori_T
        # Collapses the height axis to 1; applied in forward() on the Swin path.
        self.avgpool = nn.AdaptiveAvgPool2d((1, None))  # F -> 1
        upsample_module = nn.Sequential(
            nn.ConvTranspose2d(
                self.hidden_dim, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(
                256, 128, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(
                128, 64, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(
                64, 32, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),  # input for Interp: [bsz, C, 1, T]
            Interpolate(
                size=(1, original_T_size), mode="bilinear", align_corners=False
            ),
            # classifier
            nn.Conv2d(32, 32, kernel_size=(1, 1)),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
        )

        return upsample_module

    def _set_classifier(self):
        """Attach the custom head for supported backbone types.

        Raises:
            ValueError: If ``self.type`` is neither "vit" nor
                "swin_transformer".  Previously this case left
                ``self.classifier`` unset and failed later with an
                ``AttributeError``.
        """
        if self.type in ("vit", "swin_transformer"):
            self.classifier = self._create_classifier()
        else:
            raise ValueError(f"Unsupported backbone type {self.type}.")

    def get_input_size(self):
        """Return the input size recorded for the backbone."""
        return self.input_size

    def forward(self, x: torch.Tensor):
        """Run the backbone and head on ``x`` and return the squeezed output.

        ``x`` is moved to CUDA when available.  Returns ``None`` if
        ``self.type`` is not a supported backbone type.
        """
        if torch.cuda.is_available():
            x = x.cuda()

        if self.type == "vit":
            x = self.model._process_input(x)
            # Follow x's device instead of calling .cuda() unconditionally,
            # which raised on CPU-only hosts.
            batch_class_token = self.class_token.expand(x.size(0), -1, -1).to(
                x.device
            )
            x = torch.cat([batch_class_token, x], dim=1)
            x = self.model.encoder(x)
            # Drop the prepended class token and move channels to dim 1.
            x = x[:, 1:].permute(0, 2, 1)
            x = x.unsqueeze(2)
            return self.classifier(x).squeeze()

        elif self.type == "swin_transformer":
            x = self.model.features(x)  # [B, H, W, C]
            x = x.permute(0, 3, 1, 2)
            x = self.avgpool(x)  # [B, C, 1, W]
            return self.classifier(x).squeeze()

        return None
t_model.py DELETED
@@ -1,152 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision.models as models
5
- from modelscope.msdatasets import MsDataset
6
-
7
-
8
- class Interpolate(nn.Module):
9
- def __init__(
10
- self,
11
- size=None,
12
- scale_factor=None,
13
- mode="bilinear",
14
- align_corners=False,
15
- ):
16
- super(Interpolate, self).__init__()
17
- self.size = size
18
- self.scale_factor = scale_factor
19
- self.mode = mode
20
- self.align_corners = align_corners
21
-
22
- def forward(self, x):
23
- return F.interpolate(
24
- x,
25
- size=self.size,
26
- scale_factor=self.scale_factor,
27
- mode=self.mode,
28
- align_corners=self.align_corners,
29
- )
30
-
31
-
32
- class t_EvalNet:
33
- def __init__(
34
- self,
35
- backbone: str,
36
- cls_num: int,
37
- ori_T: int,
38
- imgnet_ver="v1",
39
- weight_path="",
40
- ):
41
- if not hasattr(models, backbone):
42
- raise ValueError(f"Unsupported model {backbone}.")
43
-
44
- self.imgnet_ver = imgnet_ver
45
- self.type, self.weight_url, self.input_size = self._model_info(backbone)
46
- self.model: torch.nn.Module = eval("models.%s()" % backbone)
47
- self.ori_T = ori_T
48
- if self.type == "vit":
49
- self.hidden_dim = self.model.hidden_dim
50
- self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
51
-
52
- elif self.type == "swin_transformer":
53
- self.hidden_dim = 768
54
-
55
- self.cls_num = cls_num
56
- self._set_classifier()
57
- checkpoint = (
58
- torch.load(weight_path)
59
- if torch.cuda.is_available()
60
- else torch.load(weight_path, map_location="cpu")
61
- )
62
- self.model.load_state_dict(checkpoint["model"], False)
63
- self.classifier.load_state_dict(checkpoint["classifier"], False)
64
- if torch.cuda.is_available():
65
- self.model = self.model.cuda()
66
- self.classifier = self.classifier.cuda()
67
-
68
- self.model.eval()
69
-
70
- def _get_backbone(self, backbone_ver, backbone_list):
71
- for backbone_info in backbone_list:
72
- if backbone_ver == backbone_info["ver"]:
73
- return backbone_info
74
-
75
- raise ValueError("[Backbone not found] Please check if --model is correct!")
76
-
77
- def _model_info(self, backbone: str):
78
- backbone_list = MsDataset.load(
79
- "monetjoe/cv_backbones",
80
- split=self.imgnet_ver,
81
- cache_dir="./__pycache__",
82
- trust_remote_code=True,
83
- )
84
- backbone_info = self._get_backbone(backbone, backbone_list)
85
- return (
86
- str(backbone_info["type"]),
87
- str(backbone_info["url"]),
88
- int(backbone_info["input_size"]),
89
- )
90
-
91
- def _create_classifier(self):
92
- original_T_size = self.ori_T
93
- self.avgpool = nn.AdaptiveAvgPool2d((1, None)) # F -> 1
94
- upsample_module = nn.Sequential( # nn.AdaptiveAvgPool2d((1, None)), # F -> 1
95
- nn.ConvTranspose2d(
96
- self.hidden_dim, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
97
- ),
98
- nn.ReLU(inplace=True),
99
- nn.BatchNorm2d(256),
100
- nn.ConvTranspose2d(
101
- 256, 128, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
102
- ),
103
- nn.ReLU(inplace=True),
104
- nn.BatchNorm2d(128),
105
- nn.ConvTranspose2d(
106
- 128, 64, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
107
- ),
108
- nn.ReLU(inplace=True),
109
- nn.BatchNorm2d(64),
110
- nn.ConvTranspose2d(
111
- 64, 32, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
112
- ),
113
- nn.ReLU(inplace=True),
114
- nn.BatchNorm2d(32), # input for Interp: [bsz, C, 1, T]
115
- Interpolate(
116
- size=(1, original_T_size), mode="bilinear", align_corners=False
117
- ), # classifier
118
- nn.Conv2d(32, 32, kernel_size=(1, 1)),
119
- nn.ReLU(inplace=True),
120
- nn.BatchNorm2d(32),
121
- nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
122
- )
123
-
124
- return upsample_module
125
-
126
- def _set_classifier(self): #### set custom classifier ####
127
- if self.type == "vit" or self.type == "swin_transformer":
128
- self.classifier = self._create_classifier()
129
-
130
- def get_input_size(self):
131
- return self.input_size
132
-
133
- def forward(self, x: torch.Tensor):
134
- if torch.cuda.is_available():
135
- x = x.cuda()
136
-
137
- if self.type == "vit":
138
- x = self.model._process_input(x)
139
- batch_class_token = self.class_token.expand(x.size(0), -1, -1).cuda()
140
- x = torch.cat([batch_class_token, x], dim=1)
141
- x = self.model.encoder(x)
142
- x = x[:, 1:].permute(0, 2, 1)
143
- x = x.unsqueeze(2)
144
- return self.classifier(x).squeeze()
145
-
146
- elif self.type == "swin_transformer":
147
- x = self.model.features(x) # [B, H, W, C]
148
- x = x.permute(0, 3, 1, 2)
149
- x = self.avgpool(x) # [B, C, 1, W]
150
- return self.classifier(x).squeeze()
151
-
152
- return None