mohan007 committed
Commit f77db10 · verified · 1 Parent(s): 52be56c

Create app.py

Files changed (1)
  1. app.py +321 -0
app.py ADDED
@@ -0,0 +1,321 @@
+ import base64
+ import logging
+ import os
+ import shlex
+ import shutil
+ import time
+ from typing import Literal, Optional
+
+ import gradio as gr
+ import torch
+ import wget
+ # from omegaconf import OmegaConf
+ from diarization_utils import diarize
+ from huggingface_hub import HfApi
+ from pyannote.audio import Pipeline
+ from pydantic import BaseModel, ValidationError
+ from pydantic_settings import BaseSettings
+ from starlette.exceptions import HTTPException
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ # from config import model_settings, InferenceConfig
+
+ # Shared state between the Gradio callbacks
+ variable = []   # paths of the uploaded files
+ speech = ""     # full ASR transcript, reused by the summarization step
+ # context_2 = ""
+
+ logger = logging.getLogger(__name__)
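+
+ # Overall flow of this app:
+ #   1. upload_file()     - stores the uploaded audio/video paths in `variable`
+ #   2. audio_function()  - converts the upload to WAV with ffmpeg, runs Whisper ASR
+ #                          and pyannote speaker diarization
+ #   3. audio_function2() - summarizes the stored transcript with ChatGLM3-6B-32K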
+
+
+ class ModelSettings(BaseSettings):
+     asr_model: str
+     assistant_model: Optional[str]
+     diarization_model: Optional[str]
+     hf_token: Optional[str]
+
+
+ class InferenceConfig(BaseModel):
+     task: Literal["transcribe", "translate"] = "transcribe"
+     batch_size: int = 24
+     assisted: bool = False
+     chunk_length_s: int = 30
+     sampling_rate: int = 16000
+     language: Optional[str] = None
+     num_speakers: Optional[int] = None
+     min_speakers: Optional[int] = None
+     max_speakers: Optional[int] = None
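+
+ # ModelSettings is never instantiated in this file; only the InferenceConfig
+ # defaults above are used (see audio_function below).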
+
+ # from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR
+ # from nemo.collections.asr.parts.utils.decoder_timestamps_utils import ASRDecoderTimeStamps
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ # logger.info(f"Using device: {device.type}")
+ torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
+
+ # ChatGLM3-6B-32K is used for the summarization step
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True)
+ model = AutoModel.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True, device_map="auto")
+ # base_model = "lyogavin/Anima-7B-100K"
+ # tokenizer = AutoTokenizer.from_pretrained(base_model)
+ # model = AutoModelForCausalLM.from_pretrained(
+ #     base_model,
+ #     bnb_4bit_compute_dtype=torch.float16,
+ #     # torch_dtype=torch.float16,
+ #     trust_remote_code=True,
+ #     device_map="auto",
+ #     load_in_4bit=True
+ # )
+ # model.eval()
+
+ # Optional assistant model for assisted generation (speculative decoding) with Whisper
+ assistant_model = AutoModelForCausalLM.from_pretrained(
+     "distil-whisper/distil-large-v3",
+     torch_dtype=torch_dtype,
+     low_cpu_mem_usage=True,
+     use_safetensors=True
+ )
+
+ assistant_model.to(device)
+
+ asr_pipeline = pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-large-v3",
+     torch_dtype=torch_dtype,
+     device=device
+ )
+
+
+ HfApi().whoami(os.getenv('HF_TOKEN'))
+ diarization_pipeline = Pipeline.from_pretrained(
+     checkpoint_path="pyannote/speaker-diarization-3.1",
+     use_auth_token=os.getenv('HF_TOKEN'),
+ )
+ diarization_pipeline.to(device)
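+ # NOTE: pyannote/speaker-diarization-3.1 is a gated model, so the HF_TOKEN used above
+ # must belong to an account that has accepted the model's terms on the Hugging Face Hub.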
+
+
+ def upload_file(files):
+     # Remember the uploaded file paths so audio_function() can pick them up.
+     file_paths = [file.name for file in files]
+
+     global variable
+     variable = file_paths
+
+     return file_paths
+
+
+
+ def audio_function():
+     # Convert the uploaded file to WAV, transcribe it with Whisper, attach speaker
+     # labels with pyannote, and return the diarized transcript, the timestamped
+     # chunks and the plain ASR text.
+     global speech  # the transcript is stored for the summarization step (audio_function2)
+
+     time_1 = time.time()
+     paths = variable
+
+     # NOTE: the paths are concatenated into a single string, so this effectively
+     # assumes that exactly one file was uploaded.
+     str1 = "".join(paths)
+     print("before processing ffmpeg!")
+
+     command_to_mp4_to_wav = "ffmpeg -i {} current_out.wav -y"
+     # -acodec pcm_s16le -ar 16000 -ac 1
+     # shlex.quote() protects against spaces and shell metacharacters in the path
+     os.system(command_to_mp4_to_wav.format(shlex.quote(str1)))
+
+     print("after ffmpeg")
+
+     # os.system("insanely-fast-whisper --file-name {}_new.wav --task transcribe --hf_token <HF_TOKEN>".format(str1.replace("mp3","")))
+
+     parameters = InferenceConfig()
+
+     generate_kwargs = {
+         "task": parameters.task,
+         "language": parameters.language,
+         "assistant_model": assistant_model if parameters.assisted else None
+     }
+
+     asr_outputs = asr_pipeline(
+         "current_out.wav",
+         chunk_length_s=parameters.chunk_length_s,
+         batch_size=parameters.batch_size,
+         generate_kwargs=generate_kwargs,
+         return_timestamps=True,
+     )
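+     # asr_outputs contains the full "text" plus timestamped "chunks"
+     # (because return_timestamps=True); both are shown in the UI.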
+
+     transcript = diarize(diarization_pipeline, "current_out.wav", parameters, asr_outputs)
+
+     # Keep the full transcript around for the "Summarize" button (audio_function2).
+     speech = asr_outputs["text"]
+
+     return transcript, asr_outputs["chunks"], asr_outputs["text"]
+     # Unreachable duplicate of the return above, kept for reference:
+     # return {
+     #     "speakers": transcript,
+     #     "chunks": asr_outputs["chunks"],
+     #     "text": asr_outputs["text"],
+     # }
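+
+     # -----------------------------------------------------------------------
+     # Everything below this return is unreachable: it is an earlier
+     # implementation based on NVIDIA NeMo (neural VAD, Conformer CTC ASR and
+     # TitaNet speaker embeddings). Re-enabling it requires the commented-out
+     # OmegaConf and NeMo imports at the top of the file.
+     # -----------------------------------------------------------------------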
+     a = time.time()
+     DOMAIN_TYPE = "meeting"  # can be "meeting" or "telephonic", depending on the domain of the audio file
+     CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
+
+     CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
+
+     CONFIG = wget.download(CONFIG_URL, "./")
+     cfg = OmegaConf.load(CONFIG)
+     # print(OmegaConf.to_yaml(cfg))
+
+     # Create a manifest file for the input with the format below.
+     # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-",
+     #  "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/filepath"}
+     import json
+     meta = {
+         'audio_filepath': "current_out.wav",
+         'offset': 0,
+         'duration': None,
+         'label': 'infer',
+         'text': '-',
+         'num_speakers': None,
+         'rttm_filepath': None,
+         'uem_filepath': None
+     }
+     with open(os.path.join('input_manifest.json'), 'w') as fp:
+         json.dump(meta, fp)
+         fp.write('\n')
+
+     cfg.diarizer.manifest_filepath = 'input_manifest.json'
+     cfg.diarizer.out_dir = "./"  # directory to store intermediate files and prediction outputs
+     pretrained_speaker_model = 'titanet_large'
+     cfg.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
+     cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec = [1.5, 1.25, 1.0, 0.75, 0.5]
+     cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec = [0.75, 0.625, 0.5, 0.375, 0.1]
+     cfg.diarizer.speaker_embeddings.parameters.multiscale_weights = [1, 1, 1, 1, 1]
+     cfg.diarizer.oracle_vad = True  # ----> oracle VAD
+     cfg.diarizer.clustering.parameters.oracle_num_speakers = False
+     # cfg.diarizer.manifest_filepath = 'input_manifest.json'
+     # # !cat {cfg.diarizer.manifest_filepath}
+     # pretrained_speaker_model = 'titanet_large'
+     # cfg.diarizer.manifest_filepath = cfg.diarizer.manifest_filepath
+     # cfg.diarizer.out_dir = "./"  # directory to store intermediate files and prediction outputs
+     # cfg.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
+     # cfg.diarizer.clustering.parameters.oracle_num_speakers = False
+
+     # Using neural VAD and Conformer ASR
+     cfg.diarizer.vad.model_path = 'vad_multilingual_marblenet'
+     cfg.diarizer.asr.model_path = 'stt_en_conformer_ctc_large'
+     cfg.diarizer.oracle_vad = False  # ----> not using oracle VAD
+     cfg.diarizer.asr.parameters.asr_based_vad = False
+
+     asr_decoder_ts = ASRDecoderTimeStamps(cfg.diarizer)
+     asr_model = asr_decoder_ts.set_asr_model()
+     print(asr_model)
+     word_hyp, word_ts_hyp = asr_decoder_ts.run_ASR(asr_model)
+
+     print("Decoded word output dictionary: \n", word_hyp)
+     print("Word-level timestamps dictionary: \n", word_ts_hyp)
+
+     asr_diar_offline = OfflineDiarWithASR(cfg.diarizer)
+     asr_diar_offline.word_ts_anchor_offset = asr_decoder_ts.word_ts_anchor_offset
+
+     diar_hyp, diar_score = asr_diar_offline.run_diarization(cfg, word_ts_hyp)
+     print("Diarization hypothesis output: \n", diar_hyp)
+     trans_info_dict = asr_diar_offline.get_transcript_with_speaker_labels(diar_hyp, word_hyp, word_ts_hyp)
+     # print(trans_info_dict)
+
+     # with open(os.path.join('output_diarization.json'), 'w') as fp1:
+     #     json.dump(trans_info_dict, fp1)
+     #     fp1.write('\n')
+     # b = time.time()
+     # print(b - a, "seconds diarization time for 50 min audio")
+
+     import json
+     context = ""
+     context_2 = ""
+     # global context_2
+     # with open("output.json","r") as fli:
+     #     json_dict = json.load(fli)
+     # for lst in sorted(json_dict["speakers"], key=lambda x: x['timestamp'][0], reverse=False):
+     #     context = context + str(lst["timestamp"][0]) + " : " + str(lst["timestamp"][1]) + " = " + lst["text"] + "\n"
+     #     context = context + str(lst["timestamp"][0]) + " : " + str(lst["timestamp"][1]) + " = " + lst["speaker"] + " ; " + lst["text"] + "\n"
+     for dct in trans_info_dict["current_out"]["sentences"]:
+         # context = context + "start_time : {} ".format(dct["start_time"]) + "end_time : {} ".format(dct["end_time"]) + "speaker : {} ".format(dct["speaker"]) + "\n"
+         context = context + str(dct["start_time"]) + " : " + str(dct["end_time"]) + " = " + dct["speaker"] + " ; " + dct["text"] + "\n"
+         context_2 = context_2 + str(dct["start_time"]) + " : " + str(dct["end_time"]) + " = " + dct["text"] + "\n"
+     # `global speech` is declared at the top of audio_function()
+     speech = trans_info_dict["current_out"]["transcription"]
+
+     time_2 = time.time()
+
+     return context, context_2, str(int(time_2 - time_1)) + " seconds"
+
+ def audio_function2():
+     # Summarize the transcript produced by audio_function() with ChatGLM3.
+
+     # global speech  (read-only access, no declaration needed)
+     str2 = speech
+     time_3 = time.time()
+
+     # prompt = "{} generate medical subjective objective assessment plan (SOAP) notes?".format(str2)
+     prompt = "{} Summarize this sales call. Did the agent qualify the lead properly?".format(str2)
+
+     # model = model.eval()
+     response, history = model.chat(tokenizer, prompt, history=[])
+     print(response)
+     # del model
+     # del tokenizer
+     # torch.cuda.empty_cache()
+     time_4 = time.time()
+     # response, history = model.chat(tokenizer, "What should I do if I cannot sleep at night?", history=history)
+     # print(response)
+
+     # inputs = tokenizer(prompt, return_tensors="pt")
+
+     # inputs['input_ids'] = inputs['input_ids'].cuda()
+     # inputs['attention_mask'] = inputs['attention_mask'].cuda()
+
+     # generate_ids = model.generate(**inputs, max_new_tokens=4096,
+     #                               only_last_logit=True,  # to save memory
+     #                               use_cache=False,       # enable this to save memory when running into OOM
+     #                               xentropy=True)
+     # output = tokenizer.batch_decode(generate_ids,
+     #                                 skip_special_tokens=True,
+     #                                 clean_up_tokenization_spaces=False)
+
+     # tokenizer = AutoTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K")
+     # model = AutoModelForCausalLM.from_pretrained("togethercomputer/LLaMA-2-7B-32K", trust_remote_code=True, torch_dtype=torch.float16, device_map="auto", bnb_4bit_compute_dtype=torch.float16, load_in_4bit=True)
+
+     # input_context = "summarize " + " the following {}".format(str2)
+     # input_ids = tokenizer.encode(input_context, return_tensors="pt").cuda()
+     # output = model.generate(input_ids, max_new_tokens=512, temperature=0.7)
+     # output_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     # print(output_text, "wow what happened")
+     # return output
+     return response, str(int(time_4 - time_3)) + " seconds"
+
+
+ with gr.Blocks() as demo:
+     file_output = gr.File()
+     upload_button = gr.UploadButton("Click to Upload a File", file_types=["audio", "video"], file_count="multiple")
+     upload_button.upload(upload_file, upload_button, file_output)
+
+     gr.Markdown("## Click Process Audio to display the text from the audio file")
+     process_button = gr.Button("Process Audio")
+     diarization_text = gr.Textbox(label="Speech Diarization")
+     chunks_text = gr.Textbox(label="Speech chunks")
+     asr_text = gr.Textbox(label="ASR text")
+     process_button.click(audio_function, outputs=[diarization_text, chunks_text, asr_text])
+
+     gr.Markdown("## Click Summarize to display the call summary")
+     summarize_button = gr.Button("Summarize")
+     summary_text = gr.Textbox(label="Call Summary")
+     time_text = gr.Textbox(label="Time Taken")
+     summarize_button.click(audio_function2, outputs=[summary_text, time_text])
+
+ # The login credentials were hardcoded; read them from the environment instead
+ # (APP_USERNAME / APP_PASSWORD are names chosen here, defaulting to the original values).
+ demo.launch(
+     server_name="0.0.0.0",
+     auth=(os.getenv("APP_USERNAME", "manish"), os.getenv("APP_PASSWORD", "openrainbow")),
+     auth_message="Enter your credentials",
+ )