# UTMOSv2 Gradio demo for Hugging Face Spaces (page-chrome from the web scrape removed).
import importlib
from types import SimpleNamespace
import gradio as gr
import pandas as pd
import spaces
import torch
from utmosv2.utils import get_dataset, get_model
# Markdown shown at the top of the Gradio page.
description = (
    "# π UTMOSv2 demo\n\n"
    "[](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)

# Prefer the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import the fusion_stage3 config module and copy its public attributes into a
# mutable namespace so inference-time overrides below don't touch the module.
_config_module = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{name: value for name, value in _config_module.__dict__.items() if not name.startswith("__")}
)

# Inference-time overrides of the training configuration.
cfg.reproduce = False
cfg.config = "fusion_stage3"
cfg.print_config = False
cfg.data_config = None
cfg.phase = "inference"
cfg.weight = None
cfg.num_workers = 1
@spaces.GPU
def predict_mos(audio_path: str, domain: str) -> float:
    """Predict the MOS of one ``.wav`` file with the UTMOSv2 ensemble.

    Averages 25 forward passes (5 folds x 5 dataset re-creations).

    Args:
        audio_path: Path to the input ``.wav`` file.
        domain: Data-domain ID used as the ``dataset`` column for prediction.

    Returns:
        The averaged predicted MOS as a plain Python float.
    """
    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    data["mos"] = 0  # placeholder label column; not used for inference

    preds = 0.0
    # Gradients are never needed here; no_grad() saves memory and time
    # without changing the predicted values.
    with torch.no_grad():
        for fold in range(5):
            cfg.now_fold = fold
            model = get_model(cfg, device)
            # NOTE(review): the dataset is rebuilt on every pass — presumably
            # feature extraction is stochastic and averaging smooths the
            # prediction; confirm against utmosv2.utils.get_dataset.
            for _ in range(5):
                test_dataset = get_dataset(cfg, data, "test")
                p = model(*[torch.tensor(t).unsqueeze(0) for t in test_dataset[0][:-1]])
                preds += p[0]
    # Convert to float so the Gradio Textbox shows a clean number rather than
    # a tensor repr, and so the annotated return type actually holds.
    return float(preds / 25.0)
# Build the Gradio UI: audio upload and domain selector on the left,
# predicted MOS readout on the right.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])

# Queue requests so concurrent users share the GPU sequentially.
# (Removed a stray "|" artifact that made this line a syntax error.)
demo.queue().launch()