ssolito committed on
Commit
5669eef
·
verified ·
1 Parent(s): ab6d2b5

Update whisper.py

Browse files
Files changed (1) hide show
  1. whisper.py +17 -2
whisper.py CHANGED
@@ -203,18 +203,30 @@ def processing_vad_threshold(audio, output_vad, threshold, max_duration, concate
203
 
204
  def format_audio(audio_path):
205
  input_audio, sample_rate = torchaudio.load(audio_path)
 
 
 
 
206
  resampler = torchaudio.transforms.Resample(sample_rate, 16000)
207
  input_audio = resampler(input_audio)
208
  input_audio = input_audio.squeeze().numpy()
209
  return(input_audio)
210
 
 
211
  def transcribe_pipeline(audio, task):
212
  text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
213
  return text
214
 
215
  def generate(audio_path, use_v5):
216
  audio = AudioSegment.from_wav(audio_path)
217
-
 
 
 
 
 
 
 
218
  output_vad = pipeline_vad(audio_path)
219
  concatenated_segment = AudioSegment.empty()
220
  max_duration = 0
@@ -226,5 +238,8 @@ def generate(audio_path, use_v5):
226
  output = transcribe_pipeline(format_audio(audio_path), task)
227
 
228
  clean_output = post_process_transcription(output)
229
-
 
 
 
230
  return clean_output
 
203
 
204
  def format_audio(audio_path):
205
  input_audio, sample_rate = torchaudio.load(audio_path)
206
+
207
+ if input_audio.shape[0] == 2: #stereo2mono
208
+ input_audio = torch.mean(input_audio, dim=0, keepdim=True)
209
+
210
  resampler = torchaudio.transforms.Resample(sample_rate, 16000)
211
  input_audio = resampler(input_audio)
212
  input_audio = input_audio.squeeze().numpy()
213
  return(input_audio)
214
 
215
+
216
  def transcribe_pipeline(audio, task):
217
  text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
218
  return text
219
 
220
  def generate(audio_path, use_v5):
221
  audio = AudioSegment.from_wav(audio_path)
222
+
223
+ temp_mono_path = None
224
+ if audio.channels != 1: #stereo2mono
225
+ audio = audio.set_channels(1)
226
+ temp_mono_path = "temp_mono.wav"
227
+ audio.export(temp_mono_path, format="wav")
228
+ audio_path = temp_mono_path
229
+
230
  output_vad = pipeline_vad(audio_path)
231
  concatenated_segment = AudioSegment.empty()
232
  max_duration = 0
 
238
  output = transcribe_pipeline(format_audio(audio_path), task)
239
 
240
  clean_output = post_process_transcription(output)
241
+
242
+ if temp_mono_path and os.path.exists(temp_mono_path):
243
+ os.remove(temp_mono_path)
244
+
245
  return clean_output