#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
import platform
import shutil
import zipfile

import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
import torch

from project_settings import environment, project_path
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
def get_args():
    """Parse and return the command-line arguments for this demo app.

    Defaults for the HF token and server port are read from the project
    ``environment`` settings when present.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/nx_denoise",
        type=str,
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int,
    )
    return parser.parse_args()
denoise_engines = dict() | |
def when_click_denoise_button(noisy_audio_t, engine: str):
    """Gradio click handler: denoise the uploaded audio with the chosen engine.

    :param noisy_audio_t: ``(sample_rate, signal)`` tuple as produced by a
        ``gr.Audio`` input (int16 PCM samples).
    :param engine: key into the module-level ``denoise_engines`` registry.
    :return: ``(sample_rate, enhanced_signal)`` tuple with int16 samples.
    :raises gr.Error: when the engine name is unknown or enhancement fails.
    """
    sample_rate, pcm_int16 = noisy_audio_t

    # Scale int16 PCM into float32 in roughly [-1, 1) for the model.
    pcm_float = np.array(pcm_int16 / (1 << 15), dtype=np.float32)

    selected = denoise_engines.get(engine)
    if selected is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        denoised = selected.enhancement_by_ndarray(pcm_float)
        # Scale back to int16 PCM for playback in the UI.
        denoised = np.array(denoised * (1 << 15), dtype=np.int16)
    except Exception as e:
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")

    return (sample_rate, denoised)
def main():
    """Launch the Gradio denoise demo.

    Downloads the pretrained models from the Hugging Face Hub on first run,
    unpacks the bundled example audio, builds the denoise engine registry,
    and serves the web UI.
    """
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # Download models only on first run; the directory's existence is the marker.
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # Engines, keyed by the name shown in the UI dropdown.
    global denoise_engines
    denoise_engines = {
        "mpnet_aishell_20250221": InferenceMPNet(
            pretrained_model_path_or_zip_file=(project_path / "trained_models/mpnet_aishell_20250221.zip").as_posix(),
        ),
    }

    # Dropdown choices mirror the registry keys.
    denoise_engine_choices = list(denoise_engines.keys())

    # Unpack example audio, replacing any previous extraction.
    example_zip_file = trained_model_dir / "examples.zip"
    with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
        out_root = examples_dir
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # Collect (wav_path, default_engine) pairs.
    # NOTE(review): `examples` is built but never attached to a gr.Examples
    # component below -- confirm whether the UI is meant to expose them.
    examples = [
        [filename.as_posix(), denoise_engine_choices[0]]
        for filename in examples_dir.glob("**/*/*.wav")
    ]

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="nx denoise.")

        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        dn_noisy_audio = gr.Audio(label="noisy_audio")
                        dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
                        dn_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")
                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio],
                )

    # http://127.0.0.1:7864/
    blocks.queue().launch(
        # The original `False if platform.system() == "Windows" else False`
        # was a tautology -- it always evaluated to False.
        share=False,
        # Bind to localhost on Windows, all interfaces elsewhere (e.g. Docker).
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
    )
    return
# Script entry point: only launch the app when run directly, not on import.
if __name__ == "__main__":
    main()