Kr08 committed
Commit 6db9237 · verified
1 Parent(s): 51a5dfa

Optimized app.py with on-demand model loading and lighter models

Files changed (1):
  app.py +43 -29
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 from audio_processing import process_audio, load_models
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, pipeline
 import spaces
 import torch
 import logging
@@ -13,44 +13,54 @@ cuda_available = torch.cuda.is_available()
 device = "cuda" if cuda_available else "cpu"
 logger.info(f"Using device: {device}")

-# Load models globally
-print("Loading models...")
-try:
-    load_models()  # Load Whisper and diarization models
-except Exception as e:
-    logger.error(f"Error loading Whisper and diarization models: {str(e)}")
-    raise
+# Initialize model variables
+summarizer_model = None
+summarizer_tokenizer = None
+qa_model = None
+qa_tokenizer = None

+# Load Whisper model
+print("Loading Whisper model...")
 try:
-    summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)
-    summarizer_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
+    load_models()  # Load Whisper model
 except Exception as e:
-    logger.error(f"Error loading summarization model: {str(e)}")
+    logger.error(f"Error loading Whisper model: {str(e)}")
     raise

-try:
-    qa_model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad").to(device)
-    qa_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
-except Exception as e:
-    logger.error(f"Error loading QA model: {str(e)}")
-    raise
+print("Whisper model loaded successfully.")
+
+def load_summarization_model():
+    global summarizer_model, summarizer_tokenizer
+    if summarizer_model is None:
+        logger.info("Loading summarization model...")
+        summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(device)
+        summarizer_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
+        logger.info("Summarization model loaded.")

-print("Models loaded successfully.")
+def load_qa_model():
+    global qa_model, qa_tokenizer
+    if qa_model is None:
+        logger.info("Loading QA model...")
+        qa_model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad").to(device)
+        qa_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
+        logger.info("QA model loaded.")

 @spaces.GPU
-def transcribe_audio(audio_file, translate, model_size):
-    language_segments, final_segments = process_audio(audio_file, translate=translate, model_size=model_size)
+def transcribe_audio(audio_file, translate, model_size, use_diarization):
+    language_segments, final_segments = process_audio(audio_file, translate=translate, model_size=model_size, use_diarization=use_diarization)

     output = "Detected language changes:\n\n"
     for segment in language_segments:
         output += f"Language: {segment['language']}\n"
         output += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"

-    output += f"Transcription with language detection and speaker diarization (using {model_size} model):\n\n"
+    output += f"Transcription with language detection {f'and speaker diarization' if use_diarization else ''} (using {model_size} model):\n\n"
     full_text = ""
     for segment in final_segments:
-        output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:\n"
-        output += f"Original: {segment['text']}\n"
+        output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']})"
+        if use_diarization:
+            output += f" {segment['speaker']}:"
+        output += f"\nOriginal: {segment['text']}\n"
         if translate:
             output += f"Translated: {segment['translated']}\n"
             full_text += segment['translated'] + " "
@@ -62,6 +72,7 @@ def transcribe_audio(audio_file, translate, model_size):

 @spaces.GPU
 def summarize_text(text):
+    load_summarization_model()
     inputs = summarizer_tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
     summary_ids = summarizer_model.generate(inputs["input_ids"], max_length=150, min_length=50, do_sample=False)
     summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
@@ -69,6 +80,7 @@ def summarize_text(text):

 @spaces.GPU
 def answer_question(context, question):
+    load_qa_model()
     inputs = qa_tokenizer(question, context, return_tensors="pt").to(device)
     outputs = qa_model(**inputs)
     answer_start = torch.argmax(outputs.start_logits)
@@ -77,14 +89,14 @@ def answer_question(context, question):
     return answer

 @spaces.GPU
-def process_and_summarize(audio_file, translate, model_size):
-    transcription, full_text = transcribe_audio(audio_file, translate, model_size)
+def process_and_summarize(audio_file, translate, model_size, use_diarization):
+    transcription, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
     summary = summarize_text(full_text)
     return transcription, summary

 @spaces.GPU
-def qa_interface(audio_file, translate, model_size, question):
-    _, full_text = transcribe_audio(audio_file, translate, model_size)
+def qa_interface(audio_file, translate, model_size, use_diarization, question):
+    _, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
     answer = answer_question(full_text, question)
     return answer

@@ -96,13 +108,14 @@ with gr.Blocks() as iface:
         audio_input = gr.Audio(type="filepath")
         translate_checkbox = gr.Checkbox(label="Enable Translation")
         model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
+        diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
         transcribe_button = gr.Button("Transcribe and Summarize")
         transcription_output = gr.Textbox(label="Transcription")
         summary_output = gr.Textbox(label="Summary")

         transcribe_button.click(
             process_and_summarize,
-            inputs=[audio_input, translate_checkbox, model_dropdown],
+            inputs=[audio_input, translate_checkbox, model_dropdown, diarization_checkbox],
             outputs=[transcription_output, summary_output]
         )

@@ -110,13 +123,14 @@ with gr.Blocks() as iface:
         qa_audio_input = gr.Audio(type="filepath")
         qa_translate_checkbox = gr.Checkbox(label="Enable Translation")
         qa_model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
+        qa_diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
         question_input = gr.Textbox(label="Ask a question about the audio")
         qa_button = gr.Button("Get Answer")
         answer_output = gr.Textbox(label="Answer")

         qa_button.click(
             qa_interface,
-            inputs=[qa_audio_input, qa_translate_checkbox, qa_model_dropdown, question_input],
+            inputs=[qa_audio_input, qa_translate_checkbox, qa_model_dropdown, qa_diarization_checkbox, question_input],
             outputs=answer_output
         )