Kr08 commited on
Commit
6e73abb
·
verified ·
1 Parent(s): 64f2bf5

Update audio_processing.py

Browse files
Files changed (1) hide show
  1. audio_processing.py +37 -46
audio_processing.py CHANGED
@@ -19,6 +19,26 @@ OVERLAP = 2
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
23
  chunks = []
24
  for i in range(0, len(audio), chunk_size - overlap):
@@ -28,24 +48,25 @@ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000
28
  chunks.append(chunk)
29
  return chunks
30
 
31
- @spaces.GPU(duration=120)
32
  def process_audio(audio_file, translate=False, model_size="small"):
 
 
 
 
 
33
  start_time = time.time()
34
 
35
  try:
36
- device = "cuda" if torch.cuda.is_available() else "cpu"
37
- compute_type = "float16" if device == "cuda" else "float32"
38
  audio = whisperx.load_audio(audio_file)
39
- model = whisperx.load_model(model_size, device, compute_type=compute_type)
40
 
41
- # Try to initialize diarization pipeline, but proceed without it if there's an error
42
- try:
43
- diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
44
- diarization_pipeline = diarization_pipeline.to(torch.device(device))
45
- diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
46
- except Exception as e:
47
- logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Proceeding without diarization.")
48
- diarization_result = None
49
 
50
  chunks = preprocess_audio(audio)
51
 
@@ -57,10 +78,10 @@ def process_audio(audio_file, translate=False, model_size="small"):
57
  chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
58
  chunk_end_time = chunk_start_time + CHUNK_LENGTH
59
  logger.info(f"Processing chunk {i+1}/{len(chunks)}")
60
- lang = model.detect_language(chunk)
61
- result_transcribe = model.transcribe(chunk, language=lang)
62
  if translate:
63
- result_translate = model.transcribe(chunk, task="translate")
64
  chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
65
  for j, t_seg in enumerate(result_transcribe["segments"]):
66
  segment_start = chunk_start_time + t_seg["start"]
@@ -115,34 +136,4 @@ def process_audio(audio_file, translate=False, model_size="small"):
115
  logger.error(f"An error occurred during audio processing: {str(e)}")
116
  raise
117
 
118
- def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
119
- merged = []
120
- for segment in segments:
121
- if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
122
- merged.append(segment)
123
- else:
124
- # Find the overlap
125
- matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
126
- match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
127
-
128
- if match.size / len(segment['text']) > similarity_threshold:
129
- # Merge the segments
130
- merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
131
- merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
132
-
133
- merged[-1]['end'] = segment['end']
134
- merged[-1]['text'] = merged_text
135
- if 'translated' in segment:
136
- merged[-1]['translated'] = merged_translated
137
- else:
138
- # If no significant overlap, append as a new segment
139
- merged.append(segment)
140
- return merged
141
-
142
- def print_results(segments):
143
- for segment in segments:
144
- print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
145
- print(f"Original: {segment['text']}")
146
- if 'translated' in segment:
147
- print(f"Translated: {segment['translated']}")
148
- print()
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
+ # Global variables for models
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ compute_type = "float16" if device == "cuda" else "float32"
25
+ whisper_model = None
26
+ diarization_pipeline = None
27
+
28
+ def load_models(model_size="small"):
29
+ global whisper_model, diarization_pipeline
30
+
31
+ # Load Whisper model
32
+ whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
33
+
34
+ # Try to initialize diarization pipeline
35
+ try:
36
+ diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
37
+ diarization_pipeline = diarization_pipeline.to(torch.device(device))
38
+ except Exception as e:
39
+ logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
40
+ diarization_pipeline = None
41
+
42
  def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
43
  chunks = []
44
  for i in range(0, len(audio), chunk_size - overlap):
 
48
  chunks.append(chunk)
49
  return chunks
50
 
51
+ @spaces.GPU
52
  def process_audio(audio_file, translate=False, model_size="small"):
53
+ global whisper_model, diarization_pipeline
54
+
55
+ if whisper_model is None:
56
+ load_models(model_size)
57
+
58
  start_time = time.time()
59
 
60
  try:
 
 
61
  audio = whisperx.load_audio(audio_file)
 
62
 
63
+ # Perform diarization if pipeline is available
64
+ diarization_result = None
65
+ if diarization_pipeline is not None:
66
+ try:
67
+ diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
68
+ except Exception as e:
69
+ logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
 
70
 
71
  chunks = preprocess_audio(audio)
72
 
 
78
  chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
79
  chunk_end_time = chunk_start_time + CHUNK_LENGTH
80
  logger.info(f"Processing chunk {i+1}/{len(chunks)}")
81
+ lang = whisper_model.detect_language(chunk)
82
+ result_transcribe = whisper_model.transcribe(chunk, language=lang)
83
  if translate:
84
+ result_translate = whisper_model.transcribe(chunk, task="translate")
85
  chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
86
  for j, t_seg in enumerate(result_transcribe["segments"]):
87
  segment_start = chunk_start_time + t_seg["start"]
 
136
  logger.error(f"An error occurred during audio processing: {str(e)}")
137
  raise
138
 
139
+ # The merge_nearby_segments and print_results functions remain unchanged