Spaces:
Runtime error
Runtime error
import gradio as gr | |
import time | |
from faster_whisper import WhisperModel | |
from utils import ffmpeg_read, stt, greeting_list | |
from sentence_transformers import SentenceTransformer, util | |
import torch | |
# Names of Whisper model sizes; not referenced by the code visible in this file.
whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]

# Speech-to-text backend: the "base" Whisper model with int8 compute on the CPU.
audio_model = WhisperModel("base", compute_type="int8", device="cpu")

# Sentence encoder used by voice_detect() to embed the transcript.
text_model = SentenceTransformer('all-MiniLM-L6-v2')

# Precomputed corpus embeddings searched by voice_detect().
# map_location="cpu" keeps the load working on CPU-only hosts even when the
# file was saved from a GPU process (torch.load would otherwise try to restore
# the tensors onto the device they were saved from).
corpus_embeddings = torch.load('corpus_embeddings.pt', map_location="cpu")

# Backend selector read by speech_to_text(): "whisper" uses audio_model,
# any other value delegates to utils.stt.
model_type = "whisper"
def speech_to_text(upload_audio):
    """
    Return the transcript of ``upload_audio`` as a single string.

    When the module-level ``model_type`` is "whisper", the audio is
    transcribed with ``audio_model`` (language fixed to "ja") and the
    segment texts are joined with single spaces. For any other value the
    call is delegated to ``stt``.
    """
    # Every non-whisper backend is handled entirely by the utils helper.
    if model_type != "whisper":
        return stt(upload_audio)

    segments_raw, _info = audio_model.transcribe(
        upload_audio,
        task="transcribe",
        language="ja",
        beam_size=5,
        best_of=5,
        vad_filter=True,
    )
    # Collect the text of each returned segment, in order.
    return ' '.join(segment.text for segment in segments_raw)
def _split_state(state):
    """Split a ``"<count><text>"`` state string into ``(count, text)``."""
    if not state:
        # Covers both the initial empty string and a ``None`` state.
        return 0, ""
    # The count is the run of leading ASCII digits. Reading only the first
    # character (as before) corrupted the state once the count reached 10.
    text = state.lstrip("0123456789")
    digits = state[:len(state) - len(text)]
    return (int(digits) if digits else 0), text


def _is_greeting(text, threshold=0.8):
    """Return True if ``text`` contains an entry of ``greeting_list`` or its
    best ``corpus_embeddings`` match scores above ``threshold``."""
    if any(greeting in text for greeting in greeting_list):
        return True
    # Only pay for the embedding when the cheap substring check failed.
    query_embedding = text_model.encode(text, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
    # Guard against an empty result list before reading the top hit.
    return bool(hits) and hits[0]['score'] > threshold


def voice_detect(audio, recongnize_text=""):
    """
    Transcribe one audio chunk and update the running greeting count.

    The state string passed between calls is the greeting count (decimal
    digits) immediately followed by the accumulated transcript.

    Args:
        audio: Audio input forwarded unchanged to ``speech_to_text``.
        recongnize_text: State string from the previous call; empty or
            ``None`` on the first call.

    Returns:
        A tuple ``(transcript, state, count)``: the accumulated transcript
        including this chunk, the new state string, and the greeting count
        including this chunk.
    """
    # Pause before processing; presumably throttles the live stream.
    # NOTE(review): confirm the intent before removing.
    time.sleep(2)

    count_state, recongnize_text = _split_state(recongnize_text)

    text = speech_to_text(audio)
    recongnize_text = recongnize_text + " " + text

    greeting_count = count_state + int(_is_greeting(text))
    recongnize_state = str(greeting_count) + recongnize_text

    # Return the updated count (previously the pre-update value was returned,
    # so the displayed number lagged one call behind the stored state).
    return recongnize_text, recongnize_state, greeting_count
# Live, streaming microphone UI around voice_detect. The two inputs are the
# recorded audio (as a file path) and the state; the three outputs map, in
# order, to the three values voice_detect returns.
demo = gr.Interface(
    fn=voice_detect,
    title="Greeting detection demo app",
    inputs=[
        gr.Audio(source="microphone", type="filepath", streaming=True),
        "state",
    ],
    outputs=[
        gr.Textbox(label="Predicted"),
        "state",
        gr.Number(label="Greeting count"),
    ],
    live=True,
)

# Start the app with debug mode enabled.
demo.launch(debug=True)