oza75 commited on
Commit
66b8805
·
verified ·
1 Parent(s): 46f81ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import spaces
4
  import torch
5
  from transformers import pipeline, WhisperTokenizer
 
6
  import gradio as gr
7
  # Please note that the below import will override whisper LANGUAGES to add bambara
8
  # this is not the best way to do it but at least it works. for more info check the bambara_utils code
@@ -26,6 +27,25 @@ tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language=language
26
  pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @spaces.GPU()
30
  def transcribe(audio):
31
  """
@@ -37,8 +57,12 @@ def transcribe(audio):
37
  Returns:
38
  A string representing the transcribed text.
39
  """
 
 
 
40
  # Use the pipeline to perform transcription
41
- text = pipe(audio)["text"]
 
42
  return text
43
 
44
  def get_wav_files(directory):
 
3
  import spaces
4
  import torch
5
  from transformers import pipeline, WhisperTokenizer
6
+ import torchaudio
7
  import gradio as gr
8
  # Please note that the below import will override whisper LANGUAGES to add bambara
9
  # this is not the best way to do it but at least it works. for more info check the bambara_utils code
 
27
  pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)
28
 
29
 
30
+ def resample_audio(audio_path, target_sample_rate=16000):
31
+ """
32
+ Converts the audio file to the target sampling rate (16000 Hz).
33
+
34
+ Args:
35
+ audio_path (str): Path to the audio file.
36
+ target_sample_rate (int): The desired sample rate.
37
+
38
+ Returns:
39
+ A tensor containing the resampled audio data and the target sample rate.
40
+ """
41
+ waveform, original_sample_rate = torchaudio.load(audio_path)
42
+
43
+ if original_sample_rate != target_sample_rate:
44
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
45
+ waveform = resampler(waveform)
46
+
47
+ return waveform, target_sample_rate
48
+
49
  @spaces.GPU()
50
  def transcribe(audio):
51
  """
 
57
  Returns:
58
  A string representing the transcribed text.
59
  """
60
+ # Convert the audio to 16000 Hz
61
+ waveform, sample_rate = resample_audio(audio)
62
+
63
  # Use the pipeline to perform transcription
64
+ text = pipe({"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate})["text"]
65
+
66
  return text
67
 
68
  def get_wav_files(directory):