import os

import spaces
import torch
from transformers import pipeline
import gradio as gr

# NOTE: importing bambara_utils overrides Whisper's LANGUAGES table to add
# Bambara. This is not the cleanest approach, but it works; see the
# bambara_utils module for details.
from bambara_utils import BambaraWhisperTokenizer

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model checkpoint and target language for transcription.
model_checkpoint = "oza75/whisper-bambara-asr-001"
language = "bambara"

# Load the Bambara-specific tokenizer and build the ASR pipeline around it.
tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device)


@spaces.GPU()
def transcribe(audio):
    """
    Transcribes the provided audio file into text using the configured ASR pipeline.

    Args:
        audio: The path to the audio file to transcribe, or None when no audio
            was provided.

    Returns:
        A string representing the transcribed text (empty when no audio was given).
    """
    # The Gradio Audio input can deliver None when the user submits without
    # audio (e.g. after clearing it); don't hand that to the pipeline.
    if audio is None:
        return ""
    # Use the pipeline to perform transcription
    return pipe(audio)["text"]


def get_wav_files(directory):
    """
    Returns a sorted list of absolute paths to all .wav files in the specified directory.

    The extension check is case-insensitive (".wav", ".WAV", ...), only regular
    files are returned, and the result is sorted by file name so that callers
    relying on its order (e.g. picking the first file as a default) behave
    deterministically regardless of the filesystem's listing order.

    Args:
        directory (str): The directory to search for .wav files (not recursive).

    Returns:
        list: A list of absolute paths to the .wav files, sorted by file name.

    Raises:
        FileNotFoundError: If ``directory`` does not exist (raised by os.listdir).
    """
    wav_files = []
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if name.lower().endswith(".wav") and os.path.isfile(path):
            wav_files.append(os.path.abspath(path))
    return wav_files


def main():
    """Build and launch the Gradio interface for Bambara speech recognition."""
    # Examples are optional: a missing or empty examples directory should not
    # prevent the demo from starting.
    try:
        example_files = get_wav_files("./examples")
    except FileNotFoundError:
        example_files = []

    # Setup Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        # Pre-fill the input with the first example only when one exists.
        inputs=gr.Audio(type="filepath", value=example_files[0] if example_files else None),
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
        examples=example_files or None,
        # Caching only makes sense when there are examples to cache.
        cache_examples="lazy" if example_files else False,
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()