Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import json | |
from functools import lru_cache | |
import logging | |
from pathlib import Path | |
import platform | |
import shutil | |
from typing import Tuple | |
import zipfile | |
import time | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import numpy as np | |
import log | |
from project_settings import environment, project_path, log_directory | |
from toolbox.os.command import Command | |
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet | |
from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN | |
from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet | |
# Configure logging through the project's ``log`` helper, pointing it at
# ``log_directory``. This runs at import time, before any callback logs.
log.setup_size_rotating(log_directory=log_directory)

# Module-wide logger used by the callbacks below.
logger = logging.getLogger("main")
def get_args():
    """Parse the command-line options of the demo server.

    Every option is optional and falls back to a project default:

    * ``--examples_dir``: directory searched for example ``.wav`` files.
    * ``--models_repo_id``: Hugging Face repository to download from.
    * ``--trained_model_dir``: local directory the repository is stored in.
    * ``--hf_token``: access token; defaults to ``environment["hf_token"]``.
    * ``--server_port``: port for the web server; defaults to
      ``environment["server_port"]`` or 7860.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    option_table = (
        # (flag, default value, converter applied to the raw string)
        ("--examples_dir", (project_path / "data/examples").as_posix(), str),
        ("--models_repo_id", "qgyd2021/nx_denoise", str),
        ("--trained_model_dir", (project_path / "trained_models").as_posix(), str),
        ("--hf_token", environment.get("hf_token"), str),
        ("--server_port", environment.get("server_port", 7860), int),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback, converter in option_table:
        cli.add_argument(flag, default=fallback, type=converter)
    return cli.parse_args()
def shell(cmd: str):
    """Hand ``cmd`` to ``Command.popen`` and return that call's result.

    NOTE(review): the string is forwarded unmodified and this function is
    bound to a textbox in the web UI, so anyone who can reach the page can
    run arbitrary commands on the host. Confirm this exposure is intended.
    """
    result = Command.popen(cmd)
    return result
# Registry of the selectable denoise engines. Each entry maps the engine name
# shown in the UI to the inference class to instantiate ("infer_cls") and the
# keyword arguments passed to its constructor ("kwargs").
denoise_engines = {
    engine_name: {
        "infer_cls": engine_cls,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (project_path / archive_path).as_posix(),
        },
    }
    for engine_name, engine_cls, archive_path in (
        ("dfnet-nx-dns3", InferenceDfNet, "trained_models/dfnet-nx-dns3.zip"),
        ("frcrn-dns3", InferenceFRCRN, "trained_models/frcrn-dns3.zip"),
        ("mpnet-nx-speech", InferenceMPNet, "trained_models/mpnet-nx-speech.zip"),
    )
}
def load_denoise_model(infer_cls, **kwargs):
    """Build an inference engine.

    Args:
        infer_cls: callable (normally an inference class) to instantiate.
        **kwargs: keyword arguments forwarded unchanged to ``infer_cls``.

    Returns:
        Whatever ``infer_cls(**kwargs)`` returns. A new object is created on
        every call; nothing is cached here.
    """
    return infer_cls(**kwargs)
def when_click_denoise_button(noisy_audio_file_t=None, noisy_audio_microphone_t=None, engine: str = None):
    """Denoise the submitted audio with the selected engine.

    Args:
        noisy_audio_file_t: ``(sample_rate, signal)`` pair from the file
            input, or ``None`` when no file was supplied.
        noisy_audio_microphone_t: ``(sample_rate, signal)`` pair from the
            microphone input, or ``None``. Used only when no file is given.
        engine: key into ``denoise_engines`` selecting the model.

    Returns:
        ``(enhanced_audio_t, message)``: ``enhanced_audio_t`` is
        ``(sample_rate, int16 array)`` and ``message`` is a JSON string with
        ``time_cost``, ``audio_duration`` and ``fpr`` (processing time
        divided by audio duration; ``null`` for an empty signal).

    Raises:
        gr.Error: if neither input is provided, if ``engine`` is unknown, or
            if loading the model / running the enhancement fails.
    """
    if noisy_audio_file_t is None and noisy_audio_microphone_t is None:
        raise gr.Error("audio file and microphone is null.")
    if noisy_audio_file_t is not None and noisy_audio_microphone_t is not None:
        gr.Warning("both audio file and microphone file is provided, audio file taking priority.")

    # Select explicitly on ``None`` rather than relying on tuple truthiness.
    noisy_audio_t: Tuple = (
        noisy_audio_file_t if noisy_audio_file_t is not None else noisy_audio_microphone_t
    )
    sample_rate, signal = noisy_audio_t

    # Test note from the author: when using the microphone, the reported
    # sample rate is 44100, but the signal is actually sampled at 8000.
    # The duration is therefore derived from a fixed 8000 Hz rate. True
    # division keeps sub-second clips from collapsing to a duration of 0.
    # NOTE(review): assumes the last axis of ``signal`` is time - confirm for
    # multi-channel input.
    audio_duration = signal.shape[-1] / 8000

    logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")

    # Scale to [-1, 1); presumably the incoming samples are int16 - confirm.
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = denoise_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)

        # Only the enhancement itself is timed, not the model loading.
        begin = time.time()
        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        time_cost = time.time() - begin

        # Clip before the int16 cast so that samples at or beyond full scale
        # saturate instead of wrapping around.
        int16_max = (1 << 15) - 1
        int16_min = -(1 << 15)
        enhanced_audio = np.array(
            np.clip(enhanced_audio * (1 << 15), int16_min, int16_max),
            dtype=np.int16,
        )
    except Exception as e:
        logger.exception("enhancement failed; engine: %s", engine)
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.") from e

    # An empty signal has no duration; report no ratio instead of dividing by 0.
    fpr = time_cost / audio_duration if audio_duration > 0 else None
    info = {
        "time_cost": round(time_cost, 4),
        "audio_duration": round(audio_duration, 4),
        "fpr": None if fpr is None else round(fpr, 4),
    }
    message = json.dumps(info, ensure_ascii=False, indent=4)

    enhanced_audio_t = (sample_rate, enhanced_audio)
    return enhanced_audio_t, message
def main():
    """Prepare models and examples, build the Gradio UI and start the server.

    Steps:
        1. Parse the command-line options.
        2. Download the model repository into ``trained_model_dir`` if that
           directory does not exist yet.
        3. Extract ``examples.zip`` (expected inside ``trained_model_dir``)
           into ``examples_dir`` if that directory does not exist yet.
        4. Build the "denoise" and "shell" tabs and launch the server.
    """
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download models
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # choices
    denoise_engine_choices = list(denoise_engines.keys())

    # examples: unpack the bundled archive on first run. The directory is
    # known not to exist here, so no clean-up of a previous copy is needed.
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            examples_dir.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=examples_dir)

    # One example row per wav file: [file input, microphone input, engine].
    examples = [
        [wav_path.as_posix(), None, denoise_engine_choices[0]]
        for wav_path in examples_dir.glob("**/*.wav")
    ]

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="nx denoise.")
        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        with gr.Tabs():
                            with gr.TabItem("file"):
                                dn_noisy_audio_file = gr.Audio(label="noisy_audio")
                            with gr.TabItem("microphone"):
                                dn_noisy_audio_microphone = gr.Audio(sources="microphone", label="noisy_audio")

                        dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
                        dn_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")
                        dn_message = gr.Textbox(lines=1, max_lines=20, label="message")

                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
                    outputs=[dn_enhanced_audio, dn_message]
                )
                gr.Examples(
                    examples=examples,
                    inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
                    outputs=[dn_enhanced_audio, dn_message],
                    fn=when_click_denoise_button,
                )
            # NOTE(review): this tab runs whatever command is typed into it,
            # and on non-Windows hosts the server listens on 0.0.0.0. Confirm
            # that exposing it is intended.
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )

    # Loopback only on Windows (local development); all interfaces elsewhere.
    # Public sharing is disabled on every platform.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()