HoneyTian commited on
Commit
8ce0f99
·
1 Parent(s): 6713e7b
Files changed (1) hide show
  1. main.py +30 -0
main.py CHANGED
@@ -1,9 +1,11 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
 
4
  import platform
5
 
6
  import gradio as gr
 
7
  import numpy as np
8
  import torch
9
 
@@ -13,6 +15,22 @@ from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
13
 
14
  def get_args():
15
  parser = argparse.ArgumentParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  parser.add_argument(
17
  "--hf_token",
18
  default=environment.get("hf_token"),
@@ -58,6 +76,18 @@ def when_click_denoise_button(noisy_audio_t, engine: str):
58
  def main():
59
  args = get_args()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # choices
62
  denoise_engine_choices = list(denoise_engines.keys())
63
 
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
+ from pathlib import Path
5
  import platform
6
 
7
  import gradio as gr
8
+ from huggingface_hub import snapshot_download
9
  import numpy as np
10
  import torch
11
 
 
15
 
16
  def get_args():
17
  parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ "--examples_dir",
20
+ # default=(project_path / "data").as_posix(),
21
+ default=(project_path / "data/examples").as_posix(),
22
+ type=str
23
+ )
24
+ parser.add_argument(
25
+ "--models_repo_id",
26
+ default="qgyd2021/vm_sound_classification",
27
+ type=str
28
+ )
29
+ parser.add_argument(
30
+ "--trained_model_dir",
31
+ default=(project_path / "trained_models").as_posix(),
32
+ type=str
33
+ )
34
  parser.add_argument(
35
  "--hf_token",
36
  default=environment.get("hf_token"),
 
76
  def main():
77
  args = get_args()
78
 
79
+ examples_dir = Path(args.examples_dir)
80
+ trained_model_dir = Path(args.trained_model_dir)
81
+
82
+ # download models
83
+ if not trained_model_dir.exists():
84
+ trained_model_dir.mkdir(parents=True, exist_ok=True)
85
+ _ = snapshot_download(
86
+ repo_id=args.models_repo_id,
87
+ local_dir=trained_model_dir.as_posix(),
88
+ token=args.hf_token,
89
+ )
90
+
91
  # choices
92
  denoise_engine_choices = list(denoise_engines.keys())
93