# NOTE(review): the following header was non-Python page-scrape residue
# (Hugging Face Space status "Runtime error", file size 684 Bytes,
# commit e8bade5, plus a run of line numbers); converted to a comment
# so the module is syntactically valid.
import torch
import gradio as gr
from cnn import CNNetwork
from server.preprocess import process_raw_wav, _wav_to_spec
# Load the trained speaker-identification CNN once at startup.
model = CNNetwork()
# map_location='cpu' lets a GPU-trained checkpoint load on CPU-only hosts
# (e.g. a Hugging Face Space); without it torch.load raises a RuntimeError
# when CUDA is unavailable.
state_dict = torch.load('models/void_20230522_223553.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout / batch-norm training behavior

# Class index -> speaker name; order must match the labels used at training time.
LABELS = ["shafqat", "aman", "jake"]
def greet(audio):
    """Identify which known speaker a recorded audio clip belongs to.

    Args:
        audio: Gradio microphone payload — a ``(sample_rate, samples)`` tuple
            where ``samples`` is a 1-D numpy array of raw audio.

    Returns:
        The predicted speaker name from ``LABELS``.
    """
    sr, wav = audio  # renamed from `input`, which shadowed the builtin
    # Wrap in a list to add a channel dimension: (n,) -> (1, n), as float.
    wav = torch.tensor([wav]).float()
    # Presumably resamples to 48 kHz and fixes length to 3 s — confirm
    # against server.preprocess.process_raw_wav.
    wav = process_raw_wav(wav, sr, 48000, 3)
    wav = _wav_to_spec(wav, 48000)
    # Add a batch dimension for the model: (C, F, T) -> (1, C, F, T).
    model_input = wav.unsqueeze(0)
    # Inference only — no_grad skips autograd bookkeeping; also dropped the
    # leftover debug print of the raw logits.
    with torch.no_grad():
        output = model(model_input)
    prediction_index = torch.argmax(output, 1).item()
    return LABELS[prediction_index]
# Minimal Gradio UI: microphone input -> predicted speaker name.
# (Removed the stray trailing " |" extraction residue, which was a syntax error.)
demo = gr.Interface(fn=greet, inputs="mic", outputs="text")
demo.launch()