File size: 2,330 Bytes
b55d767
 
 
 
 
 
2ade89d
b55d767
 
 
 
 
 
dbb3e47
b55d767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ade89d
b55d767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import importlib
from types import SimpleNamespace

import gradio as gr
import pandas as pd

import spaces
import torch

from utmosv2.utils import get_dataset, get_model

# Markdown header rendered at the top of the demo page.
description = (
    "# 🚀 UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from every non-dunder attribute of the fusion_stage3 config module,
# then overlay the settings this demo needs for inference. Later keys win,
# so these overrides replace any same-named values from the module.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{
        **{key: value for key, value in vars(config).items() if not key.startswith("__")},
        "reproduce": False,
        "config": "fusion_stage3",
        "print_config": False,
        "data_config": None,
        "phase": "inference",
        "weight": None,
        "num_workers": 1,
    }
)

@spaces.GPU
def predict_mos(audio_path: str, domain: str) -> float:
    """Predict the mean opinion score (MOS) of a single audio file.

    Each of the fold models (selected through ``cfg.now_fold``) is loaded with
    ``get_model`` and evaluated several times on a freshly built ``"test"``
    dataset; the returned score is the mean of all those predictions
    (5 folds x 5 repeats = 25 forward passes).

    Args:
        audio_path: Path to the uploaded audio file (the Gradio ``Audio``
            component uses ``type="filepath"``).
        domain: Data-domain ID; written to the ``"dataset"`` column of the
            frame passed to ``get_dataset``.

    Returns:
        The averaged predicted MOS as a Python ``float``.

    Raises:
        gr.Error: If no audio file was uploaded or no domain was selected.
    """
    # Gradio passes None when the user submits without filling an input.
    if audio_path is None:
        raise gr.Error("Please upload a `.wav` file.")
    if not domain:
        raise gr.Error("Please select a data-domain ID.")

    num_folds = 5
    # NOTE(review): the dataset is rebuilt on every repeat, presumably so each
    # pass sees a different random crop of the audio (test-time augmentation)
    # — confirm against utmosv2's dataset implementation.
    num_repeats = 5

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Placeholder target column; the last item of each dataset sample
    # (presumably this target) is dropped before the forward pass below.
    data["mos"] = 0

    preds = 0.0
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        for fold in range(num_folds):
            cfg.now_fold = fold
            model = get_model(cfg, device)
            # Make sure dropout/normalization layers run in inference mode;
            # harmless if get_model already did this.
            model.eval()
            for _ in range(num_repeats):
                test_dataset = get_dataset(cfg, data, "test")
                # Add a batch dimension and move each input onto the same
                # device as the model.
                inputs = [
                    torch.as_tensor(t).unsqueeze(0).to(device)
                    for t in test_dataset[0][:-1]
                ]
                preds += model(*inputs)[0]
    # Convert to a plain float so the UI shows a number, not a tensor repr.
    return float(preds) / (num_folds * num_repeats)


# Two-column layout: inputs (audio + domain selector + button) on the left,
# the predicted score on the right.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        # Input column.
        with gr.Column():
            audio = gr.Audio(label="Audio", type="filepath")
            domain = gr.Dropdown(
                choices=[
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
            )
            submit = gr.Button("Submit")

        # Output column.
        with gr.Column():
            output = gr.Textbox(type="text", label="Predicted MOS")
    # Clicking the button runs the prediction on the two inputs.
    submit.click(predict_mos, inputs=[audio, domain], outputs=[output])

# Enable request queuing, then start the server.
demo.queue()
demo.launch()