File size: 2,250 Bytes
f5bdd75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import time
from faster_whisper import WhisperModel
from utils import ffmpeg_read, stt, greeting_list
from sentence_transformers import SentenceTransformer, util
import torch

# Whisper checkpoint names; not referenced anywhere in the visible code
# (presumably kept for reference or a future model selector).
whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]
# Speech-to-text model: the "base" checkpoint, int8 compute type, on CPU.
audio_model = WhisperModel("base", compute_type="int8", device="cpu")
# Sentence-embedding model used by voice_detect() to embed each transcript.
text_model = SentenceTransformer('all-MiniLM-L6-v2')
# Reference embeddings that voice_detect() semantic-searches to decide whether a
# transcript is a greeting. map_location="cpu" lets a tensor that was saved
# from a GPU session load on a CPU-only host (this app runs its models on CPU).
# NOTE: torch.load unpickles the file, so it must come from a trusted source.
corpus_embeddings = torch.load('corpus_embeddings.pt', map_location="cpu")
# Transcription backend switch read by speech_to_text(): "whisper" uses
# audio_model, any other value falls back to utils.stt.
model_type = "whisper"

def speech_to_text(upload_audio):
    """
    Convert an audio input to text using the configured backend.

    When the module-level ``model_type`` is "whisper", ``audio_model`` transcribes
    the audio as Japanese (beam search, VAD filtering) and the text of every
    segment is joined with single spaces. Any other ``model_type`` delegates to
    ``stt`` from ``utils``.

    Args:
        upload_audio: Audio handed to the backend (a file path in this app).

    Returns:
        The transcribed text.
    """
    if model_type != "whisper":
        return stt(upload_audio)

    options = {
        "task": "transcribe",
        "language": "ja",
        "beam_size": 5,
        "best_of": 5,
        "vad_filter": True,
    }
    segment_iter, _info = audio_model.transcribe(upload_audio, **options)
    # Collect each segment's text and join the pieces with single spaces.
    return ' '.join(seg.text for seg in segment_iter)

def voice_detect(audio, recongnize_text=""):
    """
    Transcribe one streamed audio chunk and count it if it contains a greeting.

    The session state is a single string: the greeting count written as leading
    decimal digits, immediately followed by the transcript accumulated so far
    (each new transcription is appended after a space), e.g. "2 hello good morning".

    A chunk counts as a greeting when any entry of ``greeting_list`` occurs in
    its transcript or, failing that, when its best semantic-search hit against
    ``corpus_embeddings`` scores above 0.8.

    Args:
        audio: Audio chunk from the microphone input (a file path, since the
            input is configured with type="filepath").
        recongnize_text: State string returned by the previous call. None
            (which a fresh gr.State may hand over) is treated like "".

    Returns:
        (transcript, new_state, greeting_count): the accumulated transcript,
        the encoded state for the next call, and the greeting count including
        the current chunk.
    """
    # Deliberate pause kept from the original; presumably throttles the live stream.
    time.sleep(2)

    state = recongnize_text or ""
    # Read every leading digit: reading only state[0] mis-parsed the state once
    # the count reached 10 ("10 ..." became count 1 and text "0 ...").
    digits = state[:len(state) - len(state.lstrip("0123456789"))]
    count_state = int(digits) if digits else 0
    transcript = state[len(digits):]

    threshold = 0.8
    text = speech_to_text(audio)
    transcript = transcript + " " + text

    detect_greeting = any(greeting in text for greeting in greeting_list)
    # Fall back to the (costlier) embedding search only when the keyword check
    # failed and there is actual speech to compare; silence is not a greeting.
    if not detect_greeting and text.strip():
        query_embedding = text_model.encode(text, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
        detect_greeting = bool(hits) and hits[0]['score'] > threshold

    greeting_count = count_state + int(detect_greeting)
    # Return the updated count so the "Greeting count" output reflects this chunk.
    return transcript, str(greeting_count) + transcript, greeting_count

demo = gr.Interface(
    fn=voice_detect,
    inputs=[
        # Streamed microphone audio, delivered to voice_detect as a file path.
        gr.Audio(source="microphone", type="filepath", streaming=True),
        # Session state string produced by the previous voice_detect call.
        "state",
    ],
    outputs=[
        gr.Textbox(label="Predicted"),
        # Updated session state, fed back in on the next call.
        "state",
        gr.Number(label="Greeting count"),
    ],
    # live=True: Gradio re-runs fn as the inputs change (no Submit button).
    live=True,
    title="Greeting detection demo app",
)

demo.launch(debug=True)