UTMOSv2 / app.py
Wataru's picture
Update app.py
2ade89d verified
raw
history blame
2.33 kB
import importlib
from types import SimpleNamespace
import gradio as gr
import pandas as pd
import spaces
import torch
from utmosv2.utils import get_dataset, get_model
# Markdown rendered at the top of the demo page.
# The heading emoji is written as a Unicode escape (U+1F680, rocket) so that it
# survives a source file being re-saved or served with the wrong encoding.
description = (
    "# \U0001F680 UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Import the fusion_stage3 configuration module and copy every attribute whose
# name does not start with a double underscore into a mutable namespace.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{name: value for name, value in vars(config).items() if not name.startswith("__")}
)

# Values assigned on top of whatever the configuration module provides.
vars(cfg).update(
    reproduce=False,
    config="fusion_stage3",
    print_config=False,
    data_config=None,
    phase="inference",
    weight=None,
    num_workers=1,
)
@spaces.GPU
def predict_mos(audio_path: str, domain: str) -> float:
    """Predict the mean opinion score (MOS) of a single audio file.

    The result is the average of the model outputs over five folds, with five
    forward passes per fold; the dataset is rebuilt before every pass.

    Args:
        audio_path: Path of the audio file to score.
        domain: Data-domain ID, stored in the ``dataset`` column of the
            one-row frame handed to ``get_dataset``.

    Returns:
        The averaged prediction as a plain Python float.

    Raises:
        gr.Error: If no audio file or no data-domain ID was supplied.
    """
    # The UI passes None for an empty audio widget or an unselected dropdown;
    # report that clearly instead of failing somewhere inside the data loader.
    if not audio_path:
        raise gr.Error("Please provide an audio file.")
    if not domain:
        raise gr.Error("Please select a data-domain ID.")

    num_folds = 5
    num_repetitions = 5

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Placeholder value; presumably the dataset expects a label column even at
    # inference time -- TODO confirm against get_dataset.
    data["mos"] = 0

    total = 0.0
    # Inference only: skip autograd bookkeeping so the result carries no graph.
    with torch.no_grad():
        for fold in range(num_folds):
            cfg.now_fold = fold
            model = get_model(cfg, device)
            # Assumes get_model returns a torch.nn.Module; eval() switches
            # layers such as dropout to their inference behaviour.
            model.eval()
            for _ in range(num_repetitions):
                test_dataset = get_dataset(cfg, data, "test")
                # Every field of the first item except the last one is fed to
                # the model, with a leading batch dimension of size one and on
                # the same device the model was created for.
                inputs = [
                    torch.tensor(t).unsqueeze(0).to(device)
                    for t in test_dataset[0][:-1]
                ]
                p = model(*inputs)
                # Assumes one score per sample, so p[0] holds a single element.
                total += float(p[0])
    return total / (num_folds * num_repetitions)
# Data-domain IDs offered in the dropdown.
domain_choices = [
    "sarulab",
    "bvcc",
    "somos",
    "blizzard2008",
    "blizzard2009",
    "blizzard2010-EH1",
    "blizzard2010-EH2",
    "blizzard2010-ES1",
    "blizzard2010-ES3",
    "blizzard2011",
]

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        # First column: the audio input, the domain selector and the button.
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                domain_choices,
                label="Data-domain ID for the MOS prediction",
            )
            submit = gr.Button(value="Submit")
        # Second column: the text box that receives the prediction.
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    # Clicking the button calls predict_mos with the audio path and domain.
    submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])

demo.queue().launch()