#!/usr/bin/python3
# -*- coding: utf-8 -*-
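"""
Gradio demo: speech denoising (enhancement) with an MPNet-based model.

Example invocation (the script name is illustrative; the flags are the ones
defined in get_args() below):

    python main.py --server_port 7860
"""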
import argparse
from pathlib import Path
import platform
import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
from project_settings import environment, project_path
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--examples_dir",
# default=(project_path / "data").as_posix(),
default=(project_path / "data/examples").as_posix(),
type=str
)
parser.add_argument(
"--models_repo_id",
default="qgyd2021/vm_sound_classification",
type=str
)
parser.add_argument(
"--trained_model_dir",
default=(project_path / "trained_models").as_posix(),
type=str
)
parser.add_argument(
"--hf_token",
default=environment.get("hf_token"),
type=str,
)
parser.add_argument(
"--server_port",
default=environment.get("server_port", 7860),
type=int
)
args = parser.parse_args()
return args
# Registry of denoise engines, keyed by engine name; populated in main().
denoise_engines = dict()
def when_click_denoise_button(noisy_audio_t, engine: str):
    sample_rate, signal = noisy_audio_t
    # gr.Audio delivers 16-bit PCM; rescale to float32 in [-1.0, 1.0).
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine = denoise_engines.get(engine)
    if infer_engine is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        # Rescale back to 16-bit PCM for playback in the browser.
        enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
    except Exception as e:
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")

    enhanced_audio_t = (sample_rate, enhanced_audio)
    # The second output (clean_audio) has no counterpart here; clear it.
    return enhanced_audio_t, None
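
# A minimal sketch of the PCM scaling used above, assuming int16 input from
# gr.Audio (values in [-32768, 32767]); not part of the runtime path:
#
#   signal = np.array([16384, -16384], dtype=np.int16)
#   noisy = np.array(signal / (1 << 15), dtype=np.float32)  # -> [0.5, -0.5]
#   back = np.array(noisy * (1 << 15), dtype=np.int16)      # -> [16384, -16384]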
def main():
args = get_args()
examples_dir = Path(args.examples_dir)
trained_model_dir = Path(args.trained_model_dir)
    # Download the pretrained models from the Hugging Face Hub on first run.
if not trained_model_dir.exists():
trained_model_dir.mkdir(parents=True, exist_ok=True)
_ = snapshot_download(
repo_id=args.models_repo_id,
local_dir=trained_model_dir.as_posix(),
token=args.hf_token,
)
    # Instantiate the available denoise engines from the downloaded checkpoints.
    global denoise_engines
    denoise_engines = {
        "mpnet": InferenceMPNet(
            pretrained_model_path_or_zip_file=(trained_model_dir / "mpnet_aishell_20250221.zip").as_posix(),
        ),
    }
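    # Additional engines could be registered the same way; the checkpoint name
    # below is hypothetical, for illustration only:
    #
    #   denoise_engines["mpnet_v2"] = InferenceMPNet(
    #       pretrained_model_path_or_zip_file=(trained_model_dir / "mpnet_v2.zip").as_posix(),
    #   )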
# choices
denoise_engine_choices = list(denoise_engines.keys())
# ui
with gr.Blocks() as blocks:
gr.Markdown(value="nx denoise.")
with gr.Tabs():
with gr.TabItem("denoise"):
with gr.Row():
with gr.Column(variant="panel", scale=5):
dn_noisy_audio = gr.Audio(label="noisy_audio")
dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
dn_button = gr.Button(variant="primary")
with gr.Column(variant="panel", scale=5):
dn_enhanced_audio = gr.Audio(label="enhanced_audio")
dn_clean_audio = gr.Audio(label="clean_audio")
dn_button.click(
when_click_denoise_button,
inputs=[dn_noisy_audio, dn_engine],
outputs=[dn_enhanced_audio, dn_clean_audio]
)
    # e.g. http://127.0.0.1:7860/ with the default --server_port.
    # The original share flag was `False if ... Windows else False`, which is
    # always False, so it is written plainly here.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
    )
return
if __name__ == "__main__":
main()