vumichien committed on
Commit f5bdd75 · 1 Parent(s): 8148b06

Create app.py

Files changed (1)
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
import gradio as gr
import time
from faster_whisper import WhisperModel
from utils import ffmpeg_read, stt, greeting_list
from sentence_transformers import SentenceTransformer, util
import torch

# Available faster-whisper sizes; the "base" model is loaded below for int8 CPU inference.
whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]
audio_model = WhisperModel("base", compute_type="int8", device="cpu")
text_model = SentenceTransformer('all-MiniLM-L6-v2')
corpus_embeddings = torch.load('corpus_embeddings.pt')
model_type = "whisper"


def speech_to_text(upload_audio):
    """Transcribe audio with the faster-whisper model (or the stt fallback)."""
    if model_type == "whisper":
        transcribe_options = dict(task="transcribe", language="ja", beam_size=5, best_of=5, vad_filter=True)
        segments_raw, info = audio_model.transcribe(upload_audio, **transcribe_options)
        segments = [segment.text for segment in segments_raw]
        return ' '.join(segments)
    else:
        return stt(upload_audio)


def voice_detect(audio, recognize_text=""):
    """Transcribe the latest audio chunk and detect whether it contains a greeting.

    The running greeting count is packed into the first character of the state
    string; the remainder holds the transcript accumulated so far.
    """
    time.sleep(2)
    if len(recognize_text) != 0:
        count_state = int(recognize_text[0])
        recognize_text = recognize_text[1:]
    else:
        count_state = 0

    threshold = 0.8
    detect_greeting = 0
    text = speech_to_text(audio)
    recognize_text = recognize_text + " " + text

    # Exact substring match against the known greeting phrases first.
    for greeting in greeting_list:
        if greeting in text:
            detect_greeting = 1
            break

    # Otherwise fall back to semantic similarity against the pre-computed greeting corpus.
    if detect_greeting == 0:
        query_embedding = text_model.encode(text, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
        if hits[0]['score'] > threshold:
            detect_greeting = 1

    # The updated count is carried forward in the state string; the number shown
    # in the UI (count_state) is the total before this chunk's detection.
    recognize_state = str(count_state + detect_greeting) + recognize_text
    return recognize_text, recognize_state, count_state


demo = gr.Interface(
    title="Greeting detection demo app",
    fn=voice_detect,
    inputs=[
        gr.Audio(source="microphone", type="filepath", streaming=True),
        "state",
    ],
    outputs=[
        gr.Textbox(label="Predicted"),
        "state",
        gr.Number(label="Greeting count"),
    ],
    live=True,
)

demo.launch(debug=True)
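
app.py loads a pre-computed corpus_embeddings.pt, but no script for producing that file is part of this commit. Below is a minimal sketch of how it could be generated with the same all-MiniLM-L6-v2 model; the greeting_corpus phrases are hypothetical stand-ins, since the actual corpus (and the utils module providing ffmpeg_read, stt, and greeting_list) is not shown here.

# build_corpus_embeddings.py - hypothetical helper, not part of this commit
import torch
from sentence_transformers import SentenceTransformer

# Example greeting phrases; placeholders for whatever corpus the app actually uses.
greeting_corpus = [
    "こんにちは",
    "おはようございます",
    "こんばんは",
    "お疲れ様です",
]

model = SentenceTransformer('all-MiniLM-L6-v2')

# Encode the corpus once and save the tensor that app.py loads at startup.
corpus_embeddings = model.encode(greeting_corpus, convert_to_tensor=True)
torch.save(corpus_embeddings, 'corpus_embeddings.pt')

Pre-computing the embeddings this way means the semantic fallback in voice_detect only has to run util.semantic_search against the saved tensor at request time, which keeps per-chunk latency low on CPU.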