AbirMessaoudi and ssolito committed
Commit c2fb837 · verified · 1 Parent(s): e43c4b1

Update whisper.py (#19)


- Update whisper.py (4b3ffb498e507ff494dfeb10319aa5991b873e02)


Co-authored-by: Sarah Solito <[email protected]>

Files changed (1)
  1. whisper.py +251 -90
whisper.py CHANGED
@@ -1,60 +1,108 @@
  from pydub import AudioSegment
  import os
- from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
  import torchaudio
  import torch
  import re
- from transformers import pipeline
- from peft import PeftModel, PeftConfig
  import spaces

  device = 0 if torch.cuda.is_available() else "cpu"
  torch_dtype = torch.float32

- ### Configuration
- MODEL_NAME_V2 = "./whisper-large-v3-catalan"
- MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
- CHUNK_LENGTH = 30
- BATCH_SIZE = 1

- pipe = pipeline(
-     task="automatic-speech-recognition",
-     model=MODEL_NAME_V1,
-     chunk_length_s=30,
-     device=device,
-     token=os.getenv("HF_TOKEN")
- )

- peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
- model = WhisperForConditionalGeneration.from_pretrained(
-     peft_config.base_model_name_or_path,
-     device_map="auto"
- )

- task = "transcribe"
-
- model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
- model.config.use_cache = True
-
- tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
- processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
- feature_extractor = processor.feature_extractor
- forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)

- asr_pipe = pipeline(
-     task="automatic-speech-recognition",
-     model=model,
-     tokenizer=tokenizer,
-     feature_extractor=feature_extractor,
-     chunk_length_s=30)

- def asr(audio_path, task):
-     asr_result = asr_pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task":task}, return_timestamps=True)
-     base_model = asr_pipe.model.base_model if hasattr(asr_pipe.model, "base_model") else asr_pipe.model
-     return asr_result

- def post_process_transcription(transcription, max_repeats=2):
      tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)

      cleaned_tokens = []
@@ -79,57 +127,167 @@ def post_process_transcription(transcription, max_repeats=2):
      return cleaned_transcription


- def format_audio(audio_path):
-     input_audio, sample_rate = torchaudio.load(audio_path)

-     if input_audio.shape[0] == 2: #stereo2mono
-         input_audio = torch.mean(input_audio, dim=0, keepdim=True)
-
-     resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-     input_audio = resampler(input_audio)
-     input_audio = input_audio.squeeze().numpy()
-     return(input_audio)

- def split_stereo_channels(audio_path):

-     audio = AudioSegment.from_wav(audio_path)

-     channels = audio.split_to_mono()
-     if len(channels) != 2:
-         raise ValueError(f"Audio {audio_path} does not have 2 channels.")

-     channels[0].export(f"temp_mono_speaker1.wav", format="wav") # Right
-     channels[1].export(f"temp_mono_speaker2.wav", format="wav") # Left

- def transcribe_pipeline(audio, task):
-     text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
-     return text

  def generate(audio_path, use_v2):
-     task = "transcribe"
-     temp_mono_path = None

      if use_v2:
          split_stereo_channels(audio_path)

          audio_id = os.path.splitext(os.path.basename(audio_path))[0]

          left_channel_path = "temp_mono_speaker2.wav"
          right_channel_path = "temp_mono_speaker1.wav"
-
-         left_audio = format_audio(left_channel_path)
-         right_audio = format_audio(right_channel_path)
-
-         left_result = asr(left_audio, task)
-         right_result = asr(right_audio, task)
-
-         left_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"])) for seg in left_result["chunks"]]
-         right_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"])) for seg in right_result["chunks"]]
-
-         #merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0])
-         merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0] if x[0] is not None else 0.0)
-

          output = ""
          for start, end, speaker, text in merged_transcript:
@@ -138,21 +296,24 @@ def generate(audio_path, use_v2):
          clean_output = output.strip()

      else:
-         audio = AudioSegment.from_wav(audio_path)
-
-         if audio.channels != 1: #stereo2mono
-             audio = audio.set_channels(1)
-             temp_mono_path = "temp_mono.wav"
-             audio.export(temp_mono_path, format="wav")
-             audio_path = temp_mono_path
-         output = transcribe_pipeline(format_audio(audio_path), task)
-         clean_output = post_process_transcription(output)
-
-     if temp_mono_path and os.path.exists(temp_mono_path):
-         os.remove(temp_mono_path)
-
-     for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]:
-         if os.path.exists(temp_file):
-             os.remove(temp_file)

      return clean_output
  from pydub import AudioSegment
  import os
  import torchaudio
  import torch
  import re
+ from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
+ from pyannote.audio import Pipeline as DiarizationPipeline
+ import whisperx
+ import whisper_timestamped as whisper_ts
  import spaces
+ from typing import Dict

  device = 0 if torch.cuda.is_available() else "cpu"
  torch_dtype = torch.float32

+ MODEL_PATH_1 = "./whisper-large-v3-tiny-caesar"
+ MODEL_PATH_2 = "langtech-veu/whisper-timestamped-cs"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+ def clean_text(input_text):
+     remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
+                     '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
+     output_text = ''.join(char if char not in remove_chars else ' ' for char in input_text)
+     return ' '.join(output_text.split()).lower()


+ def split_stereo_channels(audio_path):

+     audio = AudioSegment.from_wav(audio_path)

+     channels = audio.split_to_mono()
+     if len(channels) != 2:
+         raise ValueError(f"Audio {audio_path} does not have 2 channels.")
+
+     channels[0].export(f"temp_mono_speaker1.wav", format="wav") # Right
+     channels[1].export(f"temp_mono_speaker2.wav", format="wav") # Left


+ def convert_to_mono(input_path):
+     audio = AudioSegment.from_file(input_path)
+     base, ext = os.path.splitext(input_path)
+     output_path = f"{base}_merged.wav"
+     mono = audio.set_channels(1)
+     mono.export(output_path, format="wav")
+     return output_path
+
+ def save_temp_audio(waveform, sample_rate, path):
+     waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
+     torchaudio.save(path, waveform, sample_rate)
+
+ def format_audio(audio_path):
+     input_audio, sample_rate = torchaudio.load(audio_path)
+     if input_audio.shape[0] == 2:
+         input_audio = torch.mean(input_audio, dim=0, keepdim=True)
+     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+     input_audio = resampler(input_audio)
+     return input_audio.squeeze(), 16000
+
+ def assign_timestamps(asr_segments, audio_path):
+     waveform, sr = format_audio(audio_path)
+     total_duration = waveform.shape[-1] / sr
+
+     total_words = sum(len(seg["text"].split()) for seg in asr_segments)
+     if total_words == 0:
+         raise ValueError("Total number of words in ASR segments is zero. Cannot assign timestamps.")
+
+     avg_word_duration = total_duration / total_words
+
+     current_time = 0.0
+     for segment in asr_segments:
+         word_count = len(segment["text"].split())
+         segment_duration = word_count * avg_word_duration
+         segment["start"] = round(current_time, 3)
+         segment["end"] = round(current_time + segment_duration, 3)
+         current_time += segment_duration
+
+     return asr_segments
+
+ def hf_chunks_to_whisperx_segments(chunks):
+     return [
+         {
+             "text": chunk["text"],
+             "start": chunk["timestamp"][0],
+             "end": chunk["timestamp"][1],
+         }
+         for chunk in chunks
+         if chunk["timestamp"] and isinstance(chunk["timestamp"], (list, tuple))
+     ]
+
+ def align_words_to_segments(words, segments, window=5.0):
+     aligned = []
+     seg_idx = 0
+     for word in words:
+         while seg_idx < len(segments) and segments[seg_idx]["end"] < word["start"] - window:
+             seg_idx += 1
+         for j in range(seg_idx, len(segments)):
+             seg = segments[j]
+             if seg["start"] > word["end"] + window:
+                 break
+             if seg["start"] <= word["start"] < seg["end"]:
+                 aligned.append((word, seg))
+                 break
+     return aligned
+
+ def post_process_transcription(transcription, max_repeats=2):
      tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)

      cleaned_tokens = []

      return cleaned_transcription

+ def post_merge_consecutive_segments(input_file, output_file): #check
+     with open(input_file, "r") as f:
+         transcription_text = f.read()

+     segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
+     merged_transcription = ''
+     current_speaker = None
+     current_segment = []

+     for i in range(1, len(segments) - 1, 2):
+         speaker_tag = segments[i]
+         text = segments[i + 1].strip()

+         speaker = re.search(r'\d{2}', speaker_tag).group()

+         if speaker == current_speaker:
+             current_segment.append(text)
+         else:
+             if current_speaker is not None:
+                 merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
+             current_speaker = speaker
+             current_segment = [text]
+
+     if current_speaker is not None:
+         merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
+
+     with open(output_file, "w") as f:
+         f.write(merged_transcription.strip())
+
+ def cleanup_temp_files(*file_paths):
+     for path in file_paths:
+         if path and os.path.exists(path):
+             os.remove(path)
+
+
+
+ def load_whisper_model(model_path: str):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = whisper_ts.load_model(model_path, device=device)
+     return model
+
+ def transcribe_audio(model, audio_path: str) -> Dict:
+     try:
+         result = whisper_ts.transcribe(
+             model,
+             audio_path,
+             beam_size=5,
+             best_of=5,
+             temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+             vad=False,
+             detect_disfluencies=True,
+         )
+
+         words = []
+         for segment in result.get('segments', []):
+             for word in segment.get('words', []):
+                 word_text = word.get('word', '').strip()
+                 if word_text.startswith(' '):
+                     word_text = word_text[1:]
+
+                 words.append({
+                     'word': word_text,
+                     'start': word.get('start', 0),
+                     'end': word.get('end', 0),
+                     'confidence': word.get('confidence', 0)
+                 })
+
+         return {
+             'audio_path': audio_path,
+             'text': result['text'].strip(),
+             'segments': result.get('segments', []),
+             'words': words,
+             'duration': result.get('duration', 0),
+             'success': True
+         }
+
+     except Exception as e:
+         return {
+             'audio_path': audio_path,
+             'error': str(e),
+             'success': False
+         }
+
+
+
+ diarization_pipeline = DiarizationPipeline.from_pretrained("pyannote/diarization_config.yaml")
+ align_model, metadata = whisperx.load_align_model(language_code="en", device=DEVICE)
+
+ asr_pipe = pipeline(
+     task="automatic-speech-recognition",
+     model=MODEL_PATH_1,
+     chunk_length_s=30,
+     device=DEVICE,
+     return_timestamps=True)
+
+ def diarization(audio_path):
+     diarization_result = diarization_pipeline(audio_path)
+     diarized_segments = list(diarization_result.itertracks(yield_label=True))
+     return diarized_segments
+
+ def asr(audio_path):
+     asr_result = asr_pipe(audio_path, return_timestamps=True)
+     asr_segments = hf_chunks_to_whisperx_segments(asr_result['chunks'])
+     asr_segments = assign_timestamps(asr_segments, audio_path)
+     return asr_segments
+
+ def align_asr_to_diarization(asr_segments, diarized_segments, audio_path):
+     waveform, sample_rate = format_audio(audio_path)

+     word_segments = whisperx.align(asr_segments, align_model, metadata, waveform, DEVICE)
+     words = word_segments['word_segments']

+     diarized = [{"start": segment.start,"end": segment.end,"speaker": speaker} for segment, _, speaker in diarized_segments]
+
+     aligned_pairs = align_words_to_segments(words, diarized)
+
+     output = []
+     segment_map = {}
+     for word, segment in aligned_pairs:
+         key = (segment["start"], segment["end"], segment["speaker"])
+         if key not in segment_map:
+             segment_map[key] = []
+         segment_map[key].append(word["word"])
+
+     for (start, end, speaker), words in sorted(segment_map.items()):
+         output.append(f"[{speaker}] {' '.join(words)}")

+     return output

  def generate(audio_path, use_v2):

      if use_v2:
+         model = load_whisper_model(MODEL_PATH_2)
          split_stereo_channels(audio_path)

          audio_id = os.path.splitext(os.path.basename(audio_path))[0]

          left_channel_path = "temp_mono_speaker2.wav"
          right_channel_path = "temp_mono_speaker1.wav"
+
+         left_waveform, left_sr = format_audio(left_channel_path)
+         right_waveform, right_sr = format_audio(right_channel_path)
+         left_result = transcribe_audio(model, left_waveform)
+         right_result = transcribe_audio(model, right_waveform)
+
+         def get_segments(result, speaker_label):
+             segments = result.get("segments", [])
+             if not segments:
+                 return []
+             return [
+                 (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label, post_process_transcription(seg.get("text", "").strip()))
+                 for seg in segments if seg.get("text")
+             ]
+
+         left_segs = get_segments(left_result, "Speaker 1")
+         right_segs = get_segments(right_result, "Speaker 2")
+
+         merged_transcript = sorted(
+             left_segs + right_segs,
+             key=lambda x: float(x[0]) if x[0] is not None else float("inf")
+         )

          output = ""
          for start, end, speaker, text in merged_transcript:

          clean_output = output.strip()

      else:
+         mono_audio_path = convert_to_mono(audio_path)
+         waveform, sr = format_audio(mono_audio_path)
+         tmp_full_path = "tmp_full.wav"
+         save_temp_audio(waveform, sr, tmp_full_path)
+         diarized_segments = diarization(tmp_full_path)
+         asr_segments = asr(tmp_full_path)
+         for segment in asr_segments:
+             segment["text"] = post_process_transcription(segment["text"])
+         aligned_text = align_asr_to_diarization(asr_segments, diarized_segments, tmp_full_path)

+         clean_output = ""
+         for line in aligned_text:
+             clean_output += f"{line}\n"
+         cleanup_temp_files(mono_audio_path,tmp_full_path)
+
+     cleanup_temp_files(
+         "temp_mono_speaker1.wav",
+         "temp_mono_speaker2.wav"
+     )
+
      return clean_output
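Where assign_timestamps is used in the mono path above, the heuristic is simply to spread the clip's duration uniformly over all transcribed words, so each segment's start and end grow in proportion to its word count. A toy, self-contained illustration of that arithmetic (invented numbers, not code from the commit):

# Toy re-implementation of the even word-duration heuristic in assign_timestamps.
# 10 s of audio and 5 words total gives 2 s per word.
segments = [{"text": "hello there"}, {"text": "how are you"}]
total_duration = 10.0
total_words = sum(len(s["text"].split()) for s in segments)  # 2 + 3 = 5
avg_word_duration = total_duration / total_words             # 2.0 s per word

current = 0.0
for seg in segments:
    span = len(seg["text"].split()) * avg_word_duration
    seg["start"], seg["end"] = round(current, 3), round(current + span, 3)
    current += span

print(segments)
# [{'text': 'hello there', 'start': 0.0, 'end': 4.0},
#  {'text': 'how are you', 'start': 4.0, 'end': 10.0}]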
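For the file as committed, the public entry point is generate(audio_path, use_v2): with use_v2=True the stereo channels are split and each side is transcribed with the whisper-timestamped model, otherwise the audio is mixed down to mono and run through the transformers ASR pipeline plus pyannote diarization and whisperx word alignment. A minimal driver sketch, assuming the module is importable by its file name and that a stereo WAV with one speaker per channel exists at the hypothetical path shown; loading the module also instantiates the diarization and alignment models, which may require Hugging Face authentication:

from whisper import generate  # this repository's whisper.py, not the openai-whisper package

audio_file = "example_call_stereo.wav"  # hypothetical recording, one speaker per channel

# v2 path: per-channel transcription with the whisper-timestamped checkpoint (MODEL_PATH_2)
print(generate(audio_file, use_v2=True))

# Fallback path: mono mixdown + ASR pipeline (MODEL_PATH_1) + pyannote diarization + whisperx alignment
print(generate(audio_file, use_v2=False))

Both paths write and then delete temporary mono WAV files in the working directory, so the caller needs write access there.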