HoneyTian commited on
Commit
7b2b795
1 Parent(s): 1d6c27e
examples/vm_sound_classification/run.sh CHANGED
@@ -18,6 +18,9 @@ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name f
18
  sh run.sh --stage 0 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification4-ch16 \
19
  --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" --label_plan 4
20
 
 
 
 
21
  "
22
 
23
  END
 
18
  sh run.sh --stage 0 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification4-ch16 \
19
  --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" --label_plan 4
20
 
21
+ sh run.sh --stage 0 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification2-ch32 \
22
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" --label_plan 2
23
+
24
  "
25
 
26
  END
main.py CHANGED
@@ -10,6 +10,8 @@ import tempfile
10
  import zipfile
11
 
12
  import gradio as gr
 
 
13
  import numpy as np
14
  import torch
15
 
@@ -34,6 +36,12 @@ def get_args():
34
  default=environment.get("server_port", 7860),
35
  type=int
36
  )
 
 
 
 
 
 
37
  args = parser.parse_args()
38
  return args
39
 
@@ -104,6 +112,13 @@ def main():
104
 
105
  examples_dir = Path(args.examples_dir)
106
  trained_model_dir = Path(args.trained_model_dir)
 
 
 
 
 
 
 
107
 
108
  # models
109
  model_choices = list()
 
10
  import zipfile
11
 
12
  import gradio as gr
13
+ from dill.pointers import parents
14
+ from huggingface_hub import snapshot_download
15
  import numpy as np
16
  import torch
17
 
 
36
  default=environment.get("server_port", 7860),
37
  type=int
38
  )
39
+
40
+ parser.add_argument(
41
+ "--models_repo_id",
42
+ default="qgyd2021/vm_sound_classification",
43
+ type=str
44
+ )
45
  args = parser.parse_args()
46
  return args
47
 
 
112
 
113
  examples_dir = Path(args.examples_dir)
114
  trained_model_dir = Path(args.trained_model_dir)
115
+ trained_model_dir.mkdir(parents=True, exist_ok=True)
116
+
117
+ # download models
118
+ _ = snapshot_download(
119
+ repo_id=args.models_repo_id,
120
+ local_dir=trained_model_dir.as_posix()
121
+ )
122
 
123
  # models
124
  model_choices = list()
requirements.txt CHANGED
@@ -9,5 +9,5 @@ tqdm==4.66.4
9
  overrides==1.9.0
10
  pyyaml==6.0.1
11
  evaluate==0.4.2
12
- gradio==4.37.1
13
  python-dotenv==1.0.1
 
9
  overrides==1.9.0
10
  pyyaml==6.0.1
11
  evaluate==0.4.2
12
+ gradio==4.44.1
13
  python-dotenv==1.0.1