Kr08 committed (verified)
Commit 51a5dfa · Parent(s): 3a346c4

Update audio_processing.py

Files changed (1)
  1. audio_processing.py +34 -5
audio_processing.py CHANGED
@@ -21,20 +21,27 @@ logger = logging.getLogger(__name__)
 
 # Global variables for models
 device = "cuda" if torch.cuda.is_available() else "cpu"
-compute_type = "float16" if device == "cuda" else "float32"
+compute_type = "float16" if device == "cuda" else "int8"
 whisper_model = None
 diarization_pipeline = None
 
 def load_models(model_size="small"):
-    global whisper_model, diarization_pipeline
+    global whisper_model, diarization_pipeline, device, compute_type
 
     # Load Whisper model
-    whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
+    try:
+        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
+    except RuntimeError as e:
+        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
+        device = "cpu"
+        compute_type = "int8"
+        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
 
     # Try to initialize diarization pipeline
     try:
         diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
-        diarization_pipeline = diarization_pipeline.to(torch.device(device))
+        if device == "cuda":
+            diarization_pipeline = diarization_pipeline.to(torch.device(device))
     except Exception as e:
         logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
         diarization_pipeline = None
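
Below is a minimal, hypothetical sketch of how a caller sees the new fallback behaviour. It assumes audio_processing.py imports whisperx, torch, and pyannote's Pipeline and defines hf_token earlier in the file, as the surrounding code implies; none of the calls here are part of the commit itself.

import audio_processing

# First attempt uses the module-level device/compute_type ("cuda"/"float16" when a
# GPU is visible). On a RuntimeError the except branch rewrites those globals and
# retries on CPU with compute_type "int8".
audio_processing.load_models(model_size="small")

# Whichever configuration actually loaded is now reflected in the globals.
print(audio_processing.device, audio_processing.compute_type)

# Diarization stays optional: the pipeline is None if pyannote initialization failed,
# and it is only moved to CUDA when that device is still in use.
if audio_processing.diarization_pipeline is None:
    print("Speaker diarization unavailable; transcription still works.")
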
@@ -136,4 +143,26 @@ def process_audio(audio_file, translate=False, model_size="small"):
         logger.error(f"An error occurred during audio processing: {str(e)}")
         raise
 
-# The merge_nearby_segments and print_results functions remain unchanged
+def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
+    merged = []
+    for segment in segments:
+        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
+            merged.append(segment)
+        else:
+            # Find the overlap
+            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
+            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
+
+            if match.size / len(segment['text']) > similarity_threshold:
+                # Merge the segments
+                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
+                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
+
+                merged[-1]['end'] = segment['end']
+                merged[-1]['text'] = merged_text
+                if 'translated' in segment:
+                    merged[-1]['translated'] = merged_translated
+            else:
+                # If no significant overlap, append as a new segment
+                merged.append(segment)
+    return merged
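
And a small, hypothetical usage example for the newly added merge_nearby_segments. It assumes the function is importable from audio_processing and that the module imports SequenceMatcher from difflib (the import is not visible in this hunk); the segment dicts follow the 'start'/'end'/'text' shape the function expects.

from audio_processing import merge_nearby_segments

segments = [
    {"start": 0.0, "end": 3.2, "text": "hello world how are you"},
    # Starts 0.2 s after the previous segment ends and mostly repeats its tail:
    # the longest common run is "how are you" (11 of 12 characters > 0.7), so it merges.
    {"start": 3.4, "end": 4.1, "text": "how are you?"},
    # More than 0.5 s of silence before this one, so it is kept as a separate segment.
    {"start": 5.0, "end": 6.0, "text": "thanks, I am fine"},
]

merged = merge_nearby_segments(segments)
# -> [{"start": 0.0, "end": 4.1, "text": "hello world how are you?"},
#     {"start": 5.0, "end": 6.0, "text": "thanks, I am fine"}]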