admin committed on
Commit f1b22d5 · 1 Parent(s): dfb9456
Files changed (7)
  1. .gitattributes +22 -10
  2. .gitignore +6 -0
  3. app.py +195 -0
  4. model.py +183 -0
  5. requirements.txt +5 -0
  6. t_model.py +153 -0
  7. utils.py +59 -0
.gitattributes CHANGED
@@ -1,35 +1,47 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.gguf* filter=lfs diff=lfs merge=lfs -text
+ *.ggml filter=lfs diff=lfs merge=lfs -text
+ *.llamafile* filter=lfs diff=lfs merge=lfs -text
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.pt
+ *__pycache__*
+ tmp/*
+ flagged/*
+ test.py
+ rename.sh
app.py ADDED
@@ -0,0 +1,195 @@
+ import os
+ import torch
+ import librosa
+ import warnings
+ import numpy as np
+ import pandas as pd
+ import gradio as gr
+ import librosa.display
+ from model import EvalNet
+ from t_model import t_EvalNet
+ from utils import get_modelist, find_files, embed, MODEL_DIR
+
+
+ # Mapping from pinyin technique names to English labels
+ TRANSLATE = {
+     "chanyin": "Vibrato",  # 颤音
+     "boxian": "Plucks",  # 拨弦
+     "shanghua": "Upward Portamento",  # 上滑音
+     "xiahua": "Downward Portamento",  # 下滑音
+     "huazhi/guazou/lianmo/liantuo": "Glissando",  # 花指\刮奏\连抹\连托
+     "yaozhi": "Tremolo",  # 摇指
+     "dianyin": "Point Note",  # 点音
+ }
+ CLASSES = list(TRANSLATE.keys())
+ TEMP_DIR = "./__pycache__/tmp"
+ SAMPLE_RATE = 44100
+ HOP_LENGTH = 512
+ TIME_LENGTH = 3  # length of one chunk in seconds
+
+
+ def logMel(y, sr=SAMPLE_RATE):
+     """Log-scaled mel spectrogram."""
+     mel = librosa.feature.melspectrogram(
+         y=y,
+         sr=sr,
+         hop_length=HOP_LENGTH,
+         fmin=27.5,
+     )
+     return librosa.power_to_db(mel, ref=np.max)
+
+
+ def logCqt(y, sr=SAMPLE_RATE):
+     """Log-scaled constant-Q transform, rescaled to roughly [0, 1]."""
+     cqt = librosa.cqt(
+         y,
+         sr=sr,
+         hop_length=HOP_LENGTH,
+         fmin=27.5,
+         n_bins=88,
+         bins_per_octave=12,
+     )
+     return ((1.0 / 80.0) * librosa.core.amplitude_to_db(np.abs(cqt), ref=np.max)) + 1.0
+
+
+ def logChroma(y, sr=SAMPLE_RATE):
+     """Log-scaled chromagram, rescaled to roughly [0, 1]."""
+     chroma = librosa.feature.chroma_stft(
+         y=y,
+         sr=sr,
+         hop_length=HOP_LENGTH,
+     )
+     return (
+         (1.0 / 80.0) * librosa.core.amplitude_to_db(np.abs(chroma), ref=np.max)
+     ) + 1.0
+
+
+ def RoW_norm(data):
+     """Per-frequency-bin mean and std computed over all non-empty frames."""
+     common_sum = 0
+     square_sum = 0
+     tfle = 0
+     for i in range(len(data)):
+         tfle += (data[i].sum(-1).sum(0) != 0).astype("float").sum()
+         common_sum += data[i].sum(-1).sum(-1)
+         square_sum += (data[i] ** 2).sum(-1).sum(-1)
+
+     common_avg = common_sum / tfle
+     square_avg = square_sum / tfle
+     std = np.sqrt(square_avg - common_avg**2)
+     return common_avg, std
+
+
+ def norm(data):
+     """Standardize each frequency bin of the chunked spectrograms."""
+     size = data.shape
+     avg, std = RoW_norm(data)
+     avg = np.tile(avg.reshape((1, -1, 1, 1)), (size[0], 1, size[2], size[3]))
+     std = np.tile(std.reshape((1, -1, 1, 1)), (size[0], 1, size[2], size[3]))
+     return (data - avg) / std
+
+
+ def chunk_data(f):
+     """Split a (freq, time) spectrogram into fixed-length chunks, zero-padding the tail."""
+     x = []
+     xdata = np.transpose(f)
+     s = SAMPLE_RATE * TIME_LENGTH // HOP_LENGTH  # frames per chunk
+     length = int(np.ceil((int(len(xdata) / s) + 1) * s))
+     app = np.zeros((length - xdata.shape[0], xdata.shape[1]))
+     xdata = np.concatenate((xdata, app), 0)
+     for i in range(int(length / s)):
+         data = xdata[int(i * s) : int(i * s + s)]
+         x.append(np.transpose(data[:s, :]))
+
+     return np.array(x)
+
+
+ def load(audio_path: str, converto="mel"):
+     """Load audio and return normalized, chunked log-spectrogram features."""
+     y, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
+     spec = eval("log%s(y, sr)" % converto.capitalize())  # logMel / logCqt / logChroma
+     x_spec = chunk_data(spec)
+     Xtr_spec = np.expand_dims(x_spec, axis=3)
+     return list(norm(Xtr_spec))
+
+
+ def infer(audio_path: str, log_name: str):
+     if not audio_path:
+         return None, "Please input an audio!"
+
+     # The model name encodes backbone and spectrogram type, e.g. "VGG19_mel"
+     backbone = "_".join(log_name.split("_")[:-1])
+     spec = log_name.split("_")[-1]
+     try:
+         input = load(audio_path, converto=spec)
+         if "vit" in backbone or "swin" in backbone:
+             eval_net = t_EvalNet(
+                 backbone,
+                 len(TRANSLATE),
+                 input[0].shape[1],
+                 weight_path=f"{MODEL_DIR}/{log_name}.pt",
+             )
+
+         else:
+             eval_net = EvalNet(
+                 backbone,
+                 len(TRANSLATE),
+                 input[0].shape[1],
+                 weight_path=f"{MODEL_DIR}/{log_name}.pt",
+             )
+
+     except Exception as e:
+         return None, f"{e}"
+
+     input_size = eval_net.get_input_size()
+     embeded_input = embed(input, input_size)
+     output = list(eval_net.forward(embeded_input))
+     outputs = []
+     index = 0
+     for y in output:
+         preds = list(y.T)
+         for pred in preds:
+             outputs.append(
+                 {
+                     "Frame": index,
+                     "Tech": TRANSLATE[CLASSES[torch.argmax(pred).item()]],
+                 }
+             )
+             index += 1
+
+     return os.path.basename(audio_path), pd.DataFrame(outputs)
+
+
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore")
+     models = get_modelist(assign_model="VGG19_mel")
+     examples = []
+     example_wavs = find_files()
+     for wav in example_wavs:
+         examples.append([wav, models[0]])
+
+     with gr.Blocks() as demo:
+         gr.Interface(
+             fn=infer,
+             inputs=[
+                 gr.Audio(label="Upload audio", type="filepath"),
+                 gr.Dropdown(choices=models, label="Select a model", value=models[0]),
+             ],
+             outputs=[
+                 gr.Textbox(label="Audio filename", show_copy_button=True),
+                 gr.Dataframe(label="Frame-level guzheng playing technique detection"),
+             ],
+             examples=examples,
+             cache_examples=False,
+             flagging_mode="never",
+             title="It is suggested that the recording time should not be too long",
+         )
+
+         gr.Markdown(
+             """
+ # Cite
+ ```bibtex
+ @dataset{zhaorui_liu_2021_5676893,
+   author = {Monan Zhou, Shenyang Xu, Zhaorui Liu, Zhaowen Wang, Feng Yu, Wei Li and Baoqiang Han},
+   title = {CCMusic: an Open and Diverse Database for Chinese Music Information Retrieval Research},
+   month = {mar},
+   year = {2024},
+   publisher = {HuggingFace},
+   version = {1.2},
+   url = {https://huggingface.co/ccmusic-database}
+ }
+ ```"""
+         )
+
+     demo.launch()
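For reference, a minimal sketch of calling the new `infer` pipeline outside the Gradio UI. The checkpoint name `"VGG19_mel"` and the use of the first bundled example are assumptions; any `<backbone>_<spectrogram>` checkpoint listed by `get_modelist` should work the same way.

```python
# Minimal sketch: run the frame-level detector on one file without launching Gradio.
# Assumes the Guzheng_Tech99 snapshot has been downloaded by utils.py and that a
# "VGG19_mel" checkpoint exists in MODEL_DIR (assumptions, not guarantees).
from app import infer
from utils import find_files

wav = find_files()[0]                       # first bundled example .flac
filename, frames = infer(wav, "VGG19_mel")  # (audio filename, DataFrame of per-frame techniques)
print(filename)
print(frames.head())                        # columns: "Frame", "Tech"
```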
model.py ADDED
@@ -0,0 +1,183 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import numpy as np
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+
+
+ class Interpolate(nn.Module):
+     """Thin nn.Module wrapper around F.interpolate so it can sit in nn.Sequential."""
+
+     def __init__(
+         self,
+         size=None,
+         scale_factor=None,
+         mode="bilinear",
+         align_corners=False,
+     ):
+         super(Interpolate, self).__init__()
+         self.size = size
+         self.scale_factor = scale_factor
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         return F.interpolate(
+             x,
+             size=self.size,
+             scale_factor=self.scale_factor,
+             mode=self.mode,
+             align_corners=self.align_corners,
+         )
+
+
+ class EvalNet:
+     """CNN backbone with a transposed-conv head for frame-level technique prediction."""
+
+     def __init__(
+         self,
+         backbone: str,
+         cls_num: int,
+         ori_T: int,
+         imgnet_ver="v1",
+         weight_path="",
+     ):
+         if not hasattr(models, backbone):
+             raise ValueError(f"Unsupported model {backbone}.")
+
+         self.imgnet_ver = imgnet_ver
+         self.training = bool(weight_path == "")
+         self.type, self.weight_url, self.input_size = self._model_info(backbone)
+         self.model: torch.nn.Module = eval("models.%s()" % backbone)
+         self.ori_T = ori_T
+         self.out_channel_before_classifier = 0
+         self._set_channel_outsize()  # set out channel size
+         self.cls_num = cls_num
+         self._set_classifier()
+         self._pseudo_forward()
+         checkpoint = (
+             torch.load(weight_path)
+             if torch.cuda.is_available()
+             else torch.load(weight_path, map_location="cpu")
+         )  # self.model.load_state_dict(checkpoint, False)
+         self.model.load_state_dict(checkpoint["model"], False)
+         self.classifier.load_state_dict(checkpoint["classifier"], False)
+         if torch.cuda.is_available():
+             self.model = self.model.cuda()
+             self.classifier = self.classifier.cuda()
+         self.model.eval()
+
+     def _get_backbone(self, backbone_ver, backbone_list):
+         for backbone_info in backbone_list:
+             if backbone_ver == backbone_info["ver"]:
+                 return backbone_info
+
+         raise ValueError("[Backbone not found] Please check if --model is correct!")
+
+     def _model_info(self, backbone: str):
+         backbone_list = MsDataset.load(
+             "monetjoe/cv_backbones",
+             split=self.imgnet_ver,
+             cache_dir="./__pycache__",
+         )
+         backbone_info = self._get_backbone(backbone, backbone_list)
+         return (
+             str(backbone_info["type"]),
+             str(backbone_info["url"]),
+             int(backbone_info["input_size"]),
+         )
+
+     def _create_classifier(self):
+         original_T_size = self.ori_T
+         upsample_module = nn.Sequential(
+             nn.AdaptiveAvgPool2d((1, None)),  # F -> 1
+             nn.ConvTranspose2d(
+                 self.out_channel_before_classifier,
+                 256,
+                 kernel_size=(1, 4),
+                 stride=(1, 2),
+                 padding=(0, 1),
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(256),
+             nn.ConvTranspose2d(
+                 256, 128, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(128),
+             nn.ConvTranspose2d(
+                 128, 64, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(64),
+             nn.ConvTranspose2d(
+                 64, 32, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(32),  # input for Interp: [bsz, C, 1, T]
+             Interpolate(
+                 size=(1, original_T_size), mode="bilinear", align_corners=False
+             ),  # classifier
+             nn.Conv2d(32, 32, kernel_size=(1, 1)),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(32),
+             nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
+         )
+
+         return upsample_module
+
+     def _set_channel_outsize(self):  #### get the output size before classifier ####
+         conv2d_out_ch = []
+         for name, module in self.model.named_modules():
+             if isinstance(module, torch.nn.Conv2d):
+                 conv2d_out_ch.append(module.out_channels)
+
+             if (
+                 str(name).__contains__("classifier")
+                 or str(name).__eq__("fc")
+                 or str(name).__contains__("head")
+             ):
+                 if isinstance(module, torch.nn.Conv2d):
+                     conv2d_out_ch.append(module.in_channels)
+                 break
+
+         self.out_channel_before_classifier = conv2d_out_ch[-1]
+
+     def _set_classifier(self):  #### set custom classifier ####
+         if self.type == "resnet":
+             self.model.avgpool = nn.Identity()
+             self.model.fc = nn.Identity()
+             self.classifier = self._create_classifier()
+
+         elif (
+             self.type == "vgg" or self.type == "efficientnet" or self.type == "convnext"
+         ):
+             self.model.avgpool = nn.Identity()
+             self.model.classifier = nn.Identity()
+             self.classifier = self._create_classifier()
+
+         elif self.type == "squeezenet":
+             self.model.classifier = nn.Identity()
+             self.classifier = self._create_classifier()
+
+     def get_input_size(self):
+         return self.input_size
+
+     def _pseudo_forward(self):
+         # Run a dummy batch to recover the spatial size H of the backbone's flattened output
+         temp = torch.randn(4, 3, self.input_size, self.input_size)
+         out = self.model(temp)
+         self.H = int(np.sqrt(out.size(1) / self.out_channel_before_classifier))
+
+     def forward(self, x):
+         if torch.cuda.is_available():
+             x = x.cuda()
+
+         if self.type == "convnext":
+             out = self.model(x)
+             out = self.classifier(out).squeeze()
+             return out
+
+         else:
+             out = self.model(x)
+             out = out.view(
+                 out.size(0), self.out_channel_before_classifier, self.H, self.H
+             )
+             out = self.classifier(out).squeeze()
+             return out
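A quick way to see what the custom head in `_create_classifier` does is to push a dummy feature map through a condensed stack of the same building blocks: frequency is pooled to 1, transposed convolutions decode channels while upsampling time, and a final bilinear interpolation pins the time axis back to the original frame count `ori_T`. The sizes below (512 backbone channels, a 7×7 feature map, 258 output frames, 7 classes) are illustrative assumptions, and the real head has four ConvTranspose2d stages with BatchNorm/ReLU in between.

```python
# Illustrative sketch of the classifier head's shape flow; all sizes are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W = 2, 512, 7, 7  # dummy backbone feature map (e.g. a VGG-style output)
ori_T = 258                # number of spectrogram frames to recover (assumed)

x = torch.randn(B, C, H, W)
x = nn.AdaptiveAvgPool2d((1, None))(x)                     # collapse frequency: [B, C, 1, W]
x = nn.ConvTranspose2d(C, 32, (1, 4), (1, 2), (0, 1))(x)   # decode channels, upsample time
x = F.interpolate(x, size=(1, ori_T), mode="bilinear", align_corners=False)
x = nn.Conv2d(32, 7, kernel_size=(1, 1))(x)                # 7 technique classes, one score per frame
print(x.shape)                                             # torch.Size([2, 7, 1, 258])
```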
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ pillow
+ librosa
+ matplotlib
+ torchvision
t_model.py ADDED
@@ -0,0 +1,153 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+
+
+ class Interpolate(nn.Module):
+     """Thin nn.Module wrapper around F.interpolate so it can sit in nn.Sequential."""
+
+     def __init__(
+         self,
+         size=None,
+         scale_factor=None,
+         mode="bilinear",
+         align_corners=False,
+     ):
+         super(Interpolate, self).__init__()
+         self.size = size
+         self.scale_factor = scale_factor
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         return F.interpolate(
+             x,
+             size=self.size,
+             scale_factor=self.scale_factor,
+             mode=self.mode,
+             align_corners=self.align_corners,
+         )
+
+
+ class t_EvalNet:
+     """Transformer backbone (ViT / Swin) with a transposed-conv head for frame-level prediction."""
+
+     def __init__(
+         self,
+         backbone: str,
+         cls_num: int,
+         ori_T: int,
+         imgnet_ver="v1",
+         weight_path="",
+     ):
+         if not hasattr(models, backbone):
+             raise ValueError(f"Unsupported model {backbone}.")
+
+         self.imgnet_ver = imgnet_ver
+         self.type, self.weight_url, self.input_size = self._model_info(backbone)
+         self.model: torch.nn.Module = eval("models.%s()" % backbone)
+         self.ori_T = ori_T
+         if self.type == "vit":
+             self.hidden_dim = self.model.hidden_dim
+             self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+
+         elif self.type == "swin_transformer":
+             self.hidden_dim = 768
+
+         self.cls_num = cls_num
+         self._set_classifier()
+         checkpoint = (
+             torch.load(weight_path)
+             if torch.cuda.is_available()
+             else torch.load(weight_path, map_location="cpu")
+         )
+         self.model.load_state_dict(checkpoint["model"], False)
+         self.classifier.load_state_dict(checkpoint["classifier"], False)
+         if torch.cuda.is_available():
+             self.model = self.model.cuda()
+             self.classifier = self.classifier.cuda()
+
+         self.model.eval()
+
+     def _get_backbone(self, backbone_ver, backbone_list):
+         for backbone_info in backbone_list:
+             if backbone_ver == backbone_info["ver"]:
+                 return backbone_info
+
+         raise ValueError("[Backbone not found] Please check if --model is correct!")
+
+     def _model_info(self, backbone: str):
+         backbone_list = MsDataset.load(
+             "monetjoe/cv_backbones",
+             split=self.imgnet_ver,
+             cache_dir="./__pycache__",
+         )
+         backbone_info = self._get_backbone(backbone, backbone_list)
+         return (
+             str(backbone_info["type"]),
+             str(backbone_info["url"]),
+             int(backbone_info["input_size"]),
+         )
+
+     def _create_classifier(self):
+         original_T_size = self.ori_T
+         self.avgpool = nn.AdaptiveAvgPool2d((1, None))  # F -> 1
+         upsample_module = nn.Sequential(  # nn.AdaptiveAvgPool2d((1, None)),  # F -> 1
+             nn.ConvTranspose2d(
+                 self.hidden_dim, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(256),
+             nn.ConvTranspose2d(
+                 256, 128, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(128),
+             nn.ConvTranspose2d(
+                 128, 64, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(64),
+             nn.ConvTranspose2d(
+                 64, 32, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1)
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(32),  # input for Interp: [bsz, C, 1, T]
+             Interpolate(
+                 size=(1, original_T_size), mode="bilinear", align_corners=False
+             ),  # classifier
+             nn.Conv2d(32, 32, kernel_size=(1, 1)),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(32),
+             nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
+         )
+
+         return upsample_module
+
+     def _set_classifier(self):  #### set custom classifier ####
+         if self.type == "vit" or self.type == "swin_transformer":
+             self.classifier = self._create_classifier()
+
+     def get_input_size(self):
+         return self.input_size
+
+     def forward(self, x: torch.Tensor):
+         if torch.cuda.is_available():
+             x = x.cuda()
+
+         if self.type == "vit":
+             x = self.model._process_input(x)
+             # keep the class token on the same device as the input so CPU inference also works
+             batch_class_token = self.class_token.expand(x.size(0), -1, -1).to(x.device)
+             x = torch.cat([batch_class_token, x], dim=1)
+             x = self.model.encoder(x)
+             x = x[:, 1:].permute(0, 2, 1)
+             x = x.unsqueeze(2)  # x shape: [bsz, hidden_dim, 1, seq_len]
+             x = self.classifier(x).squeeze()
+             return x
+
+         elif self.type == "swin_transformer":
+             x = self.model.features(x)  # [B, H, W, C]
+             x = x.permute(0, 3, 1, 2)
+             x = self.avgpool(x)  # [B, C, 1, W]
+             x = self.classifier(x).squeeze()
+             return x
+
+         return None
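The less obvious part of `t_EvalNet.forward` is the ViT branch: the encoder's patch tokens (class token dropped) are treated as a time axis for the 2D conv head. A small standalone sketch of that reshaping, with assumed ViT-B sizes (196 patch tokens, hidden size 768):

```python
# Standalone sketch of the ViT-branch tensor handling; sizes are assumptions.
import torch

B, N, D = 2, 196, 768              # batch, patch tokens (14x14), hidden_dim (assumed ViT-B)
tokens = torch.randn(B, N + 1, D)  # encoder output including the class token at position 0

x = tokens[:, 1:]                  # drop the class token: [B, N, D]
x = x.permute(0, 2, 1)             # channels first: [B, D, N]
x = x.unsqueeze(2)                 # fake frequency axis for the 2D conv head: [B, D, 1, N]
print(x.shape)                     # torch.Size([2, 768, 1, 196])
```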
utils.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ import torch
+ import numpy as np
+ from torchvision.transforms import Compose, Resize, Normalize
+ from modelscope import snapshot_download
+
+ # Download the pretrained checkpoints and example audio into the local cache
+ MODEL_DIR = snapshot_download(
+     "ccmusic-database/Guzheng_Tech99",
+     cache_dir="./__pycache__",
+ )
+
+
+ def toCUDA(x):
+     if hasattr(x, "cuda"):
+         if torch.cuda.is_available():
+             return x.cuda()
+
+     return x
+
+
+ def find_files(folder_path=f"{MODEL_DIR}/examples", ext=".flac"):
+     audio_files = []
+     for root, _, files in os.walk(folder_path):
+         for file in files:
+             if file.endswith(ext):
+                 file_path = os.path.join(root, file)
+                 audio_files.append(file_path)
+
+     return audio_files
+
+
+ def get_modelist(model_dir=MODEL_DIR, assign_model=""):
+     pt_files = []
+     for _, _, files in os.walk(model_dir):
+         for file in files:
+             if file.endswith(".pt"):
+                 model = os.path.basename(file)[:-3]
+                 if assign_model and assign_model in model:
+                     pt_files.append(model)
+                 else:
+                     pt_files.insert(0, model)
+
+     return pt_files
+
+
+ def embed(input: list, img_size: int):
+     """Resize, replicate to 3 channels and normalize spectrogram chunks for the backbone."""
+     compose = Compose(
+         [
+             Resize([img_size, img_size]),
+             Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+     inputs = []
+     for x in input:
+         x = np.array(x).transpose(2, 0, 1)
+         x = torch.from_numpy(x).repeat(3, 1, 1)
+         inputs.append(compose(x).float())
+
+     return toCUDA(torch.tensor(np.array(inputs)))
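`embed` expects the list of `(freq, time, 1)` chunks produced by `app.load` and returns an `[N, 3, img_size, img_size]` batch ready for the backbone. A minimal shape check with random data is sketched below; the chunk dimensions and `img_size` are assumptions, and note that importing `utils` triggers the `Guzheng_Tech99` snapshot download.

```python
# Minimal shape check for utils.embed; chunk dimensions and img_size are assumptions.
import numpy as np
from utils import embed  # importing utils downloads the model snapshot on first run

chunks = [np.random.rand(88, 258, 1).astype("float32") for _ in range(4)]  # 4 fake CQT-like chunks
batch = embed(chunks, img_size=224)
print(batch.shape, batch.dtype)  # torch.Size([4, 3, 224, 224]) torch.float32
```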