File size: 684 Bytes
e8bade5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import gradio as gr

from cnn import CNNetwork
from server.preprocess import process_raw_wav, _wav_to_spec

# Build the network and restore its trained weights once, at import time, so
# every request handled by this process reuses the same in-memory model.
model = CNNetwork()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; the weights end up in the (CPU) model either way.
state_dict = torch.load('models/void_20230522_223553.pth', map_location="cpu")
model.load_state_dict(state_dict)
# This script only runs inference: put the module in eval mode so layers that
# behave differently during training (dropout, batch norm, if present) are
# deterministic.
model.eval()

# Label names, looked up by the index of the highest-scoring model output.
# NOTE(review): the order presumably mirrors the class order used in training
# -- confirm against the training code.
LABELS = ["shafqat", "aman", "jake"]


def greet(input):
    """Classify a microphone recording and return the predicted label.

    Args:
        input: The value Gradio supplies for the ``"mic"`` input: a
            ``(sample_rate, samples)`` pair, or ``None`` when nothing was
            recorded. The name shadows the builtin ``input`` but is kept so
            the function's signature stays unchanged.

    Returns:
        The entry of ``LABELS`` selected by the model's highest-scoring
        output, or a short prompt string when ``input`` is ``None``.
    """
    # Without this guard an empty submission fails on the unpacking below
    # with an opaque "cannot unpack non-iterable NoneType" error.
    if input is None:
        return "No audio received. Please record something and try again."

    sr, wav = input

    # Convert directly to float32 and add the leading dimension afterwards.
    # This yields the same tensor as torch.tensor([wav]).float() but avoids
    # building a tensor from a Python list of arrays, which is slow.
    wav = torch.tensor(wav, dtype=torch.float32).unsqueeze(0)

    # NOTE(review): presumably resamples to 48 kHz and fixes the clip length
    # to 3 seconds before computing a spectrogram -- confirm against
    # server.preprocess.
    wav = process_raw_wav(wav, sr, 48000, 3)
    spec = _wav_to_spec(wav, 48000)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(spec.unsqueeze(0))
    print(output)

    prediction_index = torch.argmax(output, 1).item()
    return LABELS[prediction_index]

# Gradio UI: microphone audio in, predicted label text out.
demo = gr.Interface(fn=greet, inputs="mic", outputs="text")

# Only start the web server when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or `greet`) does not block on
# launching a server.
if __name__ == "__main__":
    demo.launch()