File size: 2,323 Bytes
4273f74
 
54440ac
 
 
 
4273f74
 
54440ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4273f74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54440ac
 
4273f74
0b8d1b9
4273f74
54440ac
 
 
0b8d1b9
54440ac
 
4273f74
0b8d1b9
 
54440ac
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

import spaces
import torch
from transformers import pipeline
import gradio as gr
# NOTE: the import below overrides Whisper's LANGUAGES mapping to add Bambara.
# This is not the best way to do it, but it works. For more info, see the bambara_utils code.
from bambara_utils import BambaraWhisperTokenizer

# Run on the GPU when CUDA is usable, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint to load and the language handed to the tokenizer
model_checkpoint = "oza75/whisper-bambara-asr-001"
language = "bambara"

# Build the Bambara tokenizer from the checkpoint first, then the ASR
# pipeline that uses it; both receive the device selected above
tokenizer = BambaraWhisperTokenizer.from_pretrained(
    model_checkpoint,
    language=language,
    device=device,
)
pipe = pipeline(
    model=model_checkpoint,
    tokenizer=tokenizer,
    device=device,
)


@spaces.GPU()
def transcribe(audio):
    """
    Transcribes the provided audio file into text using the configured ASR pipeline.

    Args:
        audio: The path to the audio file to transcribe, or None when the
            audio input was submitted without any recording or upload.

    Returns:
        A string representing the transcribed text (the "text" entry of the
        pipeline output).

    Raises:
        gr.Error: If no audio was provided.
    """
    # None is what the callback receives when nothing was recorded or
    # uploaded; fail with a readable message instead of letting the
    # pipeline raise an opaque internal error on a None input.
    if audio is None:
        raise gr.Error("No audio provided. Please record or upload an audio file.")

    # Use the pipeline to perform transcription
    return pipe(audio)["text"]

def get_wav_files(directory):
    """
    Returns a sorted list of absolute paths to the .wav files in the specified directory.

    Only regular files located directly inside the directory are returned
    (no recursion, and sub-directories whose name happens to end in ".wav"
    are skipped). The extension match is case-insensitive, so ".WAV" files
    are included as well.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: Absolute paths to the .wav files, ordered by file name.

    Raises:
        FileNotFoundError: If the directory does not exist.
    """
    wav_files = []
    # os.listdir returns entries in arbitrary order; sorting keeps the result
    # (and therefore "the first example") stable across runs and machines.
    for name in sorted(os.listdir(directory)):
        if not name.lower().endswith(".wav"):
            continue
        path = os.path.join(directory, name)
        # Skip directories and other non-file entries that merely look like audio
        if os.path.isfile(path):
            wav_files.append(os.path.abspath(path))
    return wav_files

def main():
    """
    Builds and launches the Gradio demo for Bambara speech recognition.

    The .wav files found in ./examples are offered as examples and the first
    one pre-fills the audio input. When that directory is missing or holds no
    .wav files, the demo still starts, just without examples and without a
    pre-filled input.
    """
    # Get a list of all .wav files in the examples directory; a missing
    # directory is treated the same as an empty one
    try:
        example_files = get_wav_files("./examples")
    except FileNotFoundError:
        example_files = []

    # Indexing [0] on an empty list would raise IndexError, so only pre-fill
    # the audio input when at least one example exists
    default_audio = example_files[0] if example_files else None

    # Setup Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        inputs=gr.Audio(type="filepath", value=default_audio),
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
        examples=example_files or None,
        # Example caching only makes sense when there are examples to cache
        cache_examples="lazy" if example_files else False,
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()