oza75 committed on
Commit
3b2d585
·
1 Parent(s): dbf668e

add multiple language choices

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -38,13 +38,14 @@ revision = None
38
  #language = "swahili"
39
 
40
  model_checkpoint = "oza75/bm-whisper-large-turbo-v4"
41
- language = "sundanese"
42
 
43
  # Load the custom tokenizer designed for Bambara and the ASR model
44
  #tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
45
- tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
46
  pipe = pipeline("automatic-speech-recognition", model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)
47
 
 
48
 
49
  def resample_audio(audio_path, target_sample_rate=16000):
50
  """
@@ -66,7 +67,7 @@ def resample_audio(audio_path, target_sample_rate=16000):
66
  return waveform, target_sample_rate
67
 
68
  @spaces.GPU()
69
- def transcribe(audio, task_type):
70
  """
71
  Transcribes the provided audio file into text using the configured ASR pipeline.
72
 
@@ -79,6 +80,8 @@ def transcribe(audio, task_type):
79
  # Convert the audio to 16000 Hz
80
  waveform, sample_rate = resample_audio(audio)
81
 
 
 
82
  # Use the pipeline to perform transcription
83
  sample = {"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate}
84
  text = pipe(sample, generate_kwargs={"task": task_type, "language": language})["text"]
@@ -99,7 +102,7 @@ def get_wav_files(directory):
99
  files = os.listdir(directory)
100
  # Filter for .wav files and create absolute paths
101
  wav_files = [os.path.abspath(os.path.join(directory, file)) for file in files if file.endswith('.wav')]
102
- wav_files = [[f, "transcribe"] for f in wav_files]
103
 
104
  return wav_files
105
 
@@ -112,6 +115,7 @@ def main():
112
  fn=transcribe,
113
  inputs=[
114
  gr.Audio(type="filepath", value=example_files[0][0]),
 
115
  gr.Radio(choices=["transcribe"], label="Task Type", value="transcribe")
116
  ],
117
  outputs="text",
 
38
  #language = "swahili"
39
 
40
  model_checkpoint = "oza75/bm-whisper-large-turbo-v4"
41
+ # language = "sundanese"
42
 
43
  # Load the custom tokenizer designed for Bambara and the ASR model
44
  #tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
45
+ tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, device=device)
46
  pipe = pipeline("automatic-speech-recognition", model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)
47
 
48
+ LANGUAGES = {"bambara": "sundanese", "french": "french", "english": "english"}
49
 
50
  def resample_audio(audio_path, target_sample_rate=16000):
51
  """
 
67
  return waveform, target_sample_rate
68
 
69
  @spaces.GPU()
70
+ def transcribe(audio, task_type, language):
71
  """
72
  Transcribes the provided audio file into text using the configured ASR pipeline.
73
 
 
80
  # Convert the audio to 16000 Hz
81
  waveform, sample_rate = resample_audio(audio)
82
 
83
+ language = LANGUAGES[language]
84
+
85
  # Use the pipeline to perform transcription
86
  sample = {"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate}
87
  text = pipe(sample, generate_kwargs={"task": task_type, "language": language})["text"]
 
102
  files = os.listdir(directory)
103
  # Filter for .wav files and create absolute paths
104
  wav_files = [os.path.abspath(os.path.join(directory, file)) for file in files if file.endswith('.wav')]
105
+ wav_files = [[f, "transcribe", "bambara"] for f in wav_files]
106
 
107
  return wav_files
108
 
 
115
  fn=transcribe,
116
  inputs=[
117
  gr.Audio(type="filepath", value=example_files[0][0]),
118
+ gr.Dropdown(choices=LANGUAGES.keys(), label="Language", value="bambara"),
119
  gr.Radio(choices=["transcribe"], label="Task Type", value="transcribe")
120
  ],
121
  outputs="text",