zeimoto commited on
Commit
2906c35
·
verified ·
1 Parent(s): 3316aad
Files changed (1) hide show
  1. app.py +42 -18
app.py CHANGED
@@ -2,36 +2,43 @@ import streamlit as st
2
  from st_audiorec import st_audiorec
3
 
4
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
5
- from datasets import load_dataset
6
  import torch
 
7
 
8
- pipe = None
9
- audio_sample: bytes = None
10
- audio_transcription: str = None
11
 
12
  def main ():
13
 
14
- print("Run init model")
15
- pipe = init_model()
16
- # x = st.slider('Select a value')
17
- # st.write(x, 'squared is', x * x)
 
 
 
 
 
18
 
19
  print("Render UI")
20
  wav_audio_data = st_audiorec()
21
 
22
  if wav_audio_data is not None:
23
  print("Loading data...")
 
 
 
24
  st.audio(wav_audio_data, format='audio/wav')
25
- transcribe(wav_audio_data, pipe)
26
-
 
27
 
28
-
29
- # dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
30
- # sample = dataset[0]["audio"]
31
-
32
- # audio_file_path = "data/audio1.wav"
33
 
34
- def init_model ():
 
 
 
 
35
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
36
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
37
 
@@ -57,8 +64,16 @@ def init_model ():
57
  device=device,
58
  )
59
  print(f'Init model successful: {model}' )
 
60
  return pipe
61
-
 
 
 
 
 
 
 
62
  def transcribe (audio_sample: bytes, pipe) -> str:
63
 
64
  # dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
@@ -66,7 +81,16 @@ def transcribe (audio_sample: bytes, pipe) -> str:
66
  result = pipe(audio_sample)
67
  print(result)
68
 
69
- st.write('Result', result["text"])
 
 
 
 
 
 
 
 
 
70
 
71
  if __name__ == "__main__":
72
  main()
 
2
  from st_audiorec import st_audiorec
3
 
4
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
5
+ #from datasets import load_dataset
6
  import torch
7
+ from gliner import GLiNER
8
 
9
+ from resources import Lead_Labels, entity_labels, set_start, audit_elapsedtime
 
 
10
 
11
  def main ():
12
 
13
+
14
+ rec = init_model_trans()
15
+ ner = init_model_ner() #async
16
+
17
+ labels = entity_labels
18
+
19
+ text = "I have a proposal from cgd where they want one outsystems junior developers and one senior for an estimate of three hundred euros a day, for six months."
20
+ print(f"get entities from sample text: {text}")
21
+ get_entity_labels(model=ner, text=text, labels=labels)
22
 
23
  print("Render UI")
24
  wav_audio_data = st_audiorec()
25
 
26
  if wav_audio_data is not None:
27
  print("Loading data...")
28
+
29
+ if wav_audio_data is not None and rec is not None:
30
+ print("Loading data...")
31
  st.audio(wav_audio_data, format='audio/wav')
32
+ text = transcribe(wav_audio_data, rec)
33
+ if text is not None:
34
+ get_entity_labels(labels=labels, model=ner, text=text)
35
 
 
 
 
 
 
36
 
37
+ def init_model_trans ():
38
+ print("Initiating transcription model...")
39
+ func_name = "init_model_trans"
40
+ start = set_start()
41
+
42
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
43
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
44
 
 
64
  device=device,
65
  )
66
  print(f'Init model successful: {model}' )
67
+ audit_elapsedtime(function=func_name, start=start)
68
  return pipe
69
+
70
+ async def init_model_ner():
71
+ print("Initiating NER model...")
72
+ start = set_start()
73
+ model = GLiNER.from_pretrained("urchade/gliner_multi")
74
+ audit_elapsedtime(function="init_model_ner", start=start)
75
+ return model
76
+
77
  def transcribe (audio_sample: bytes, pipe) -> str:
78
 
79
  # dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
 
81
  result = pipe(audio_sample)
82
  print(result)
83
 
84
+ st.write('trancription: ', result["text"])
85
+ return result["text"]
86
+
87
+ def get_entity_labels(model: GLiNER, text: str, labels: list): #-> Lead_labels:
88
+ entities = model.predict_entities(text, labels)
89
+
90
+ for entity in entities:
91
+ print(entity["text"], "=>", entity["label"])
92
+ st.write('Entities: ', entities)
93
+ # return Lead_Labels()
94
 
95
  if __name__ == "__main__":
96
  main()