import argparse
from functools import lru_cache
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile
from typing import Tuple

import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
import torch

from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
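

# Command-line arguments: where the example audio and trained models live,
# which Hugging Face repo to pull models from, and the Gradio server port.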
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/vm_sound_classification",
        type=str
    )
    args = parser.parse_args()
    return args
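

# load_model is cached, so each model archive is unzipped and deserialized at
# most once per process; the cache key is the (hashable) Path argument.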
@lru_cache(maxsize=100)
def load_model(model_file: Path):
    # Extract the model archive into a scratch directory under the system temp dir.
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    tgt_path = out_root / model_file.stem
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    with open(jit_model_file.as_posix(), "rb") as f:
        model = torch.jit.load(f)
    model.eval()

    # The extracted files are no longer needed once the model is in memory.
    shutil.rmtree(tgt_path)

    d = {
        "model": model,
        "vocabulary": vocabulary
    }
    return d
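

# Gradio passes uploaded/recorded audio as a (sample_rate, int16 ndarray) tuple;
# the samples are scaled to [-1, 1] before being fed to the traced model.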
def click_button(audio: Tuple[int, np.ndarray],
                 model_name: str,
                 ground_truth: str) -> Tuple[str, float]:
    sample_rate, signal = audio

    model_file = "trained_models/{}.zip".format(model_name)
    model_file = Path(model_file)
    d = load_model(model_file)

    model = d["model"]
    vocabulary = d["vocabulary"]

    # Normalize int16 samples to [-1, 1] and add a batch dimension.
    inputs = signal / (1 << 15)
    inputs = torch.tensor(inputs, dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)

    with torch.no_grad():
        logits = model.forward(inputs)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    label_idx = label_idx.cpu()
    probs = probs.cpu()

    label_idx = label_idx.numpy()[0]
    prob = probs.numpy()[0][label_idx]

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")

    return label_str, round(float(prob), 4)
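

# Download the model repo, unpack the bundled examples, and serve the demo UI.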
def main():
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)
    trained_model_dir.mkdir(parents=True, exist_ok=True)

    _ = snapshot_download(
        repo_id=args.models_repo_id,
        local_dir=trained_model_dir.as_posix()
    )
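
    # The snapshot ships an examples.zip alongside the model archives; unpack it
    # into examples_dir, replacing any stale copy.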
    example_zip_file = trained_model_dir / "examples.zip"
    with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
        out_root = examples_dir
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)
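
    # Every *.zip in trained_model_dir except examples.zip is a selectable model.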
    model_choices = list()
    for filename in trained_model_dir.glob("*.zip"):
        model_name = filename.stem
        if model_name == "examples":
            continue
        model_choices.append(model_name)

    # Example wav files sit in per-label subdirectories, so the parent
    # directory name serves as the ground-truth label.
    examples = list()
    for filename in examples_dir.glob("**/*/*.wav"):
        label = filename.parts[-2]
        examples.append([
            filename.as_posix(),
            model_choices[0],
            label
        ])

    brief_description = """
    International intelligent voice outbound-call system: telephone audio classification.
    Expected input: 8000 Hz sample rate, int16 samples.
    """
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=3):
                c_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    with gr.Column(scale=3):
                        c_ground_truth = gr.Textbox(label="ground_truth")

                c_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                c_label = gr.Textbox(label="label")
                c_probability = gr.Number(label="probability")

        gr.Examples(
            examples,
            inputs=[c_audio, c_model_name, c_ground_truth],
            outputs=[c_label, c_probability],
            fn=click_button,
            examples_per_page=5,
        )

        c_button.click(
            click_button,
            inputs=[c_audio, c_model_name, c_ground_truth],
            outputs=[c_label, c_probability],
        )

    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return


if __name__ == "__main__":
    main()