update
Browse files- examples/vm_sound_classification/run.sh +3 -0
- main.py +15 -0
- requirements.txt +1 -1
examples/vm_sound_classification/run.sh
CHANGED
@@ -18,6 +18,9 @@ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name f
|
|
18 |
sh run.sh --stage 0 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification4-ch16 \
|
19 |
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" --label_plan 4
|
20 |
|
|
|
|
|
|
|
21 |
"
|
22 |
|
23 |
END
|
|
|
18 |
sh run.sh --stage 0 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification4-ch16 \
|
19 |
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" --label_plan 4
|
20 |
|
21 |
+
sh run.sh --stage 0 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification2-ch32 \
|
22 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" --label_plan 2
|
23 |
+
|
24 |
"
|
25 |
|
26 |
END
|
main.py
CHANGED
@@ -10,6 +10,8 @@ import tempfile
|
|
10 |
import zipfile
|
11 |
|
12 |
import gradio as gr
|
|
|
|
|
13 |
import numpy as np
|
14 |
import torch
|
15 |
|
@@ -34,6 +36,12 @@ def get_args():
|
|
34 |
default=environment.get("server_port", 7860),
|
35 |
type=int
|
36 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
args = parser.parse_args()
|
38 |
return args
|
39 |
|
@@ -104,6 +112,13 @@ def main():
|
|
104 |
|
105 |
examples_dir = Path(args.examples_dir)
|
106 |
trained_model_dir = Path(args.trained_model_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
# models
|
109 |
model_choices = list()
|
|
|
10 |
import zipfile
|
11 |
|
12 |
import gradio as gr
|
13 |
+
from dill.pointers import parents
|
14 |
+
from huggingface_hub import snapshot_download
|
15 |
import numpy as np
|
16 |
import torch
|
17 |
|
|
|
36 |
default=environment.get("server_port", 7860),
|
37 |
type=int
|
38 |
)
|
39 |
+
|
40 |
+
parser.add_argument(
|
41 |
+
"--models_repo_id",
|
42 |
+
default="qgyd2021/vm_sound_classification",
|
43 |
+
type=str
|
44 |
+
)
|
45 |
args = parser.parse_args()
|
46 |
return args
|
47 |
|
|
|
112 |
|
113 |
examples_dir = Path(args.examples_dir)
|
114 |
trained_model_dir = Path(args.trained_model_dir)
|
115 |
+
trained_model_dir.mkdir(parents=True, exist_ok=True)
|
116 |
+
|
117 |
+
# download models
|
118 |
+
_ = snapshot_download(
|
119 |
+
repo_id=args.models_repo_id,
|
120 |
+
local_dir=trained_model_dir.as_posix()
|
121 |
+
)
|
122 |
|
123 |
# models
|
124 |
model_choices = list()
|
requirements.txt
CHANGED
@@ -9,5 +9,5 @@ tqdm==4.66.4
|
|
9 |
overrides==1.9.0
|
10 |
pyyaml==6.0.1
|
11 |
evaluate==0.4.2
|
12 |
-
gradio==4.
|
13 |
python-dotenv==1.0.1
|
|
|
9 |
overrides==1.9.0
|
10 |
pyyaml==6.0.1
|
11 |
evaluate==0.4.2
|
12 |
+
gradio==4.44.1
|
13 |
python-dotenv==1.0.1
|