Commit a0eed22 · committed by admin
Parent(s): f398992

Files changed (4):
  1. app.py +48 -42
  2. model.py +9 -4
  3. requirements.txt +5 -3
  4. utils.py +34 -10
app.py CHANGED
@@ -9,7 +9,7 @@ import librosa.display
 import matplotlib.pyplot as plt
 from collections import Counter
 from model import EvalNet
-from utils import get_modelist, find_files, embed_img
+from utils import get_modelist, find_files, embed_img, _L, EN_US
 
 
 TRANSLATE = {
@@ -344,33 +344,38 @@ def most_frequent_value(lst: list):
 
 
 def infer(wav_path: str, log_name: str, folder_path=TEMP_DIR):
-    if os.path.exists(folder_path):
-        shutil.rmtree(folder_path)
-
-    if not wav_path:
-        return None, "Please input an audio!"
-
-    spec = log_name.split("_")[-3]
-    os.makedirs(folder_path, exist_ok=True)
-    try:
+    status = "Success"
+    filename = result = None
+    try:
+        if os.path.exists(folder_path):
+            shutil.rmtree(folder_path)
+
+        if not wav_path:
+            return None, "请输入音频!"
+
+        spec = log_name.split("_")[-3]
+        os.makedirs(folder_path, exist_ok=True)
         model = EvalNet(log_name, len(TRANSLATE)).model
         eval("wav2%s" % spec)(wav_path)
-
-    except Exception as e:
-        return None, f"{e}"
-
-    jpgs = find_files(folder_path, ".jpg")
-    preds = []
-    for jpg in jpgs:
-        input = embed_img(jpg)
-        output: torch.Tensor = model(input)
-        preds.append(torch.max(output.data, 1)[1])
-
-    pred_id = most_frequent_value(preds)
-    return (
-        os.path.basename(wav_path),
-        f"{TRANSLATE[CLASSES[pred_id]][0]} ({TRANSLATE[CLASSES[pred_id]][1].capitalize()})",
-    )
+        jpgs = find_files(folder_path, ".jpg")
+        preds = []
+        for jpg in jpgs:
+            input = embed_img(jpg)
+            output: torch.Tensor = model(input)
+            preds.append(torch.max(output.data, 1)[1])
+
+        pred_id = most_frequent_value(preds)
+        filename = os.path.basename(wav_path)
+        result = (
+            TRANSLATE[CLASSES[pred_id]][1].capitalize()
+            if EN_US
+            else f"{TRANSLATE[CLASSES[pred_id]][0]} ({TRANSLATE[CLASSES[pred_id]][1].capitalize()})"
+        )
+
+    except Exception as e:
+        status = f"{e}"
+
+    return status, filename, result
 
 
 if __name__ == "__main__":
@@ -385,39 +390,40 @@ if __name__ == "__main__":
     gr.Interface(
         fn=infer,
         inputs=[
-            gr.Audio(label="Upload a recording", type="filepath"),
-            gr.Dropdown(choices=models, label="Select a model", value=models[0]),
+            gr.Audio(label=_L("上传录音"), type="filepath"),
+            gr.Dropdown(choices=models, label=_L("选择模型"), value=models[0]),
         ],
         outputs=[
-            gr.Textbox(label="Audio filename", show_copy_button=True),
+            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
+            gr.Textbox(label=_L("音频文件名"), show_copy_button=True),
             gr.Textbox(
-                label="Chinese instrument recognition",
+                label=_L("中国乐器识别"),
                 show_copy_button=True,
             ),
         ],
         examples=examples,
         cache_examples=False,
         flagging_mode="never",
-        title="It is recommended to keep the recording length around 3s.",
+        title=_L("建议录音时长保持在 3s 左右"),
     )
 
     gr.Markdown(
-        """
-# Cite
-```bibtex
-@article{Zhou-2025,
-    author = {Monan Zhou and Shenyang Xu and Zhaorui Liu and Zhaowen Wang and Feng Yu and Wei Li and Baoqiang Han},
-    title = {CCMusic: An Open and Diverse Database for Chinese Music Information Retrieval Research},
-    journal = {Transactions of the International Society for Music Information Retrieval},
-    volume = {8},
-    number = {1},
-    pages = {22--38},
-    month = {Mar},
-    year = {2025},
-    url = {https://doi.org/10.5334/tismir.194},
-    doi = {10.5334/tismir.194}
-}
-```"""
+        f"# {_L('引用')}"
+        + """
+```bibtex
+@article{Zhou-2025,
+    author = {Monan Zhou and Shenyang Xu and Zhaorui Liu and Zhaowen Wang and Feng Yu and Wei Li and Baoqiang Han},
+    title = {CCMusic: An Open and Diverse Database for Chinese Music Information Retrieval Research},
+    journal = {Transactions of the International Society for Music Information Retrieval},
+    volume = {8},
+    number = {1},
+    pages = {22--38},
+    month = {Mar},
+    year = {2025},
+    url = {https://doi.org/10.5334/tismir.194},
+    doi = {10.5334/tismir.194}
+}
+```"""
     )
 
     demo.launch()
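The reworked `infer` above funnels every failure into a status string and always hands Gradio a `(status, filename, result)` triple, one value per output Textbox. A minimal, self-contained sketch of that status-first contract (the `process` helper and its placeholder prediction are illustrative, not part of this commit):

```python
import os


def process(wav_path: str):
    """Hypothetical stand-in for the real spectrogram + model pipeline."""
    return os.path.basename(wav_path), "Erhu (placeholder prediction)"


def safe_infer(wav_path: str):
    # Status-first return contract: one value per Gradio output Textbox.
    status, filename, result = "Success", None, None
    try:
        if not wav_path:
            raise ValueError("Please input an audio!")
        filename, result = process(wav_path)
    except Exception as e:
        status = f"{e}"  # errors surface in the status box instead of crashing
    return status, filename, result


print(safe_infer(""))          # ('Please input an audio!', None, None)
print(safe_infer("demo.wav"))  # ('Success', 'demo.wav', 'Erhu (placeholder prediction)')
```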
model.py CHANGED
@@ -1,8 +1,9 @@
 import torch
 import torch.nn as nn
 import torchvision.models as models
+from modelscope.msdatasets import MsDataset
 from datasets import load_dataset
-from utils import MODEL_DIR
+from utils import MODEL_DIR, EN_US
 
 
 class EvalNet:
@@ -17,7 +18,7 @@ class EvalNet:
         self.m_type, self.input_size = self._model_info(m_ver)
 
         if not hasattr(models, m_ver):
-            raise Exception("Unsupported model.")
+            raise ValueError("不支持的模型")
 
         self.model = eval("models.%s()" % m_ver)
         linear_output = self._set_outsize()
@@ -34,11 +35,15 @@ class EvalNet:
             if ver == bb["ver"]:
                 return bb
 
-        print("Backbone name not found, using default option - alexnet.")
+        print("未找到骨干网络名称,使用默认选项 - alexnet")
         return backbone_list[0]
 
     def _model_info(self, m_ver: str):
-        backbone_list = load_dataset("monetjoe/cv_backbones", split="train")
+        backbone_list = (
+            load_dataset("monetjoe/cv_backbones", split="train")
+            if EN_US
+            else MsDataset.load("monetjoe/cv_backbones", split="v1")
+        )
         backbone = self._get_backbone(m_ver, backbone_list)
         m_type = str(backbone["type"])
         input_size = int(backbone["input_size"])
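`_model_info` now resolves the backbone metadata from one of two hubs: `datasets.load_dataset` against the Hugging Face Hub when `EN_US` holds, otherwise `MsDataset.load` against ModelScope. A sketch of the same switch; the deferred imports are an illustrative variation (so the unused client need not be installed), not what the commit itself does:

```python
import os

# Mirrors utils.EN_US: the Chinese locale selects the ModelScope mirror.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"


def load_backbones():
    # Both branches yield the same backbone records; only the hosting hub differs.
    if EN_US:
        from datasets import load_dataset

        return load_dataset("monetjoe/cv_backbones", split="train")
    from modelscope.msdatasets import MsDataset

    return MsDataset.load("monetjoe/cv_backbones", split="v1")
```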
requirements.txt CHANGED
@@ -1,5 +1,7 @@
-torch
-pillow
+torch==2.6.0+cu118
+-f https://download.pytorch.org/whl/torch
+torchvision==0.21.0+cu118
+-f https://download.pytorch.org/whl/torchvision
 librosa
 matplotlib
-torchvision
+modelscope[framework]==1.21.0
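The pinned `+cu118` builds are not hosted on PyPI, so each pin is paired with a `-f` (`--find-links`) line pointing pip at the corresponding PyTorch wheel index; a plain `pip install -r requirements.txt` can then resolve them. On a CPU-only or different-CUDA host, the pins would need to be swapped for matching builds.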
utils.py CHANGED
@@ -1,10 +1,37 @@
 import os
 import torch
 import torchvision.transforms as transforms
-from huggingface_hub import snapshot_download
+import huggingface_hub
+import modelscope
 from PIL import Image
 
-MODEL_DIR = snapshot_download("ccmusic-database/CTIS", cache_dir="./__pycache__")
+EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
+
+ZH2EN = {
+    "上传录音": "Upload a recording",
+    "选择模型": "Select a model",
+    "状态栏": "Status",
+    "音频文件名": "Audio filename",
+    "中国乐器识别": "Chinese instrument recognition",
+    "建议录音时长保持在 3s 左右": "It is recommended to keep the recording length around 3s.",
+    "引用": "Cite",
+}
+
+MODEL_DIR = (
+    huggingface_hub.snapshot_download(
+        "ccmusic-database/CTIS",
+        cache_dir="./__pycache__",
+    )
+    if EN_US
+    else modelscope.snapshot_download(
+        "ccmusic-database/CTIS",
+        cache_dir="./__pycache__",
+    )
+)
+
+
+def _L(zh_txt: str):
+    return ZH2EN[zh_txt] if EN_US else zh_txt
 
 
 def toCUDA(x):
@@ -27,19 +54,16 @@ def find_files(folder_path=f"{MODEL_DIR}/examples", ext=".wav"):
 
 
 def get_modelist(model_dir=MODEL_DIR, assign_model=""):
-    try:
-        entries = os.listdir(model_dir)
-    except OSError as e:
-        print(f"Cannot access {model_dir}: {e}")
-        return
-
     output = []
-    for entry in entries:
+    for entry in os.listdir(model_dir):
+        # get the full path
         full_path = os.path.join(model_dir, entry)
+        # skip the '.git' folder
         if entry == ".git" or entry == "examples":
-            print(f"Skip .git / examples dir: {full_path}")
+            print(f"跳过 .git examples 文件夹: {full_path}")
             continue
 
+        # check whether the entry is a file or a directory
         if os.path.isdir(full_path):
             model = os.path.basename(full_path)
             if assign_model and assign_model.lower() in model:
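All user-facing labels now route through `_L`, with the Chinese text doubling as the lookup key: the string passes through untouched in the Chinese locale and is mapped through `ZH2EN` otherwise. A minimal, self-contained sketch of the scheme (the two-entry table is a truncated illustration of the real one):

```python
import os

# Chinese is the default; any other LANG value selects English.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Chinese label -> English translation (illustrative two-entry table).
ZH2EN = {
    "上传录音": "Upload a recording",
    "选择模型": "Select a model",
}


def _L(zh_txt: str) -> str:
    # The Chinese text is the dictionary key, so an untranslated label
    # raises a KeyError in the English locale rather than failing silently.
    return ZH2EN[zh_txt] if EN_US else zh_txt


print(_L("上传录音"))  # "Upload a recording" unless LANG=zh_CN.UTF-8
```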