File size: 3,500 Bytes
4273f74
 
54440ac
 
6673c70
66b8805
54440ac
4273f74
 
564acd4
54440ac
 
 
 
 
df2df62
 
 
46f81ad
df2df62
 
564acd4
54440ac
 
df2df62
 
010ad6d
54440ac
 
66b8805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54440ac
 
 
 
 
 
 
 
 
 
 
66b8805
 
 
54440ac
66b8805
 
54440ac
 
4273f74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54440ac
 
4273f74
0b8d1b9
4273f74
54440ac
 
 
a71de9c
54440ac
 
4273f74
0b8d1b9
 
54440ac
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os

import spaces
import torch
from transformers import pipeline, WhisperTokenizer
import torchaudio
import gradio as gr
# Please note that the below import will override whisper LANGUAGES to add bambara
# this is not the best way to do it but at least it works. for more info check the bambara_utils code
from bambara_utils import BambaraWhisperTokenizer

# Determine the appropriate device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the model checkpoint and language
#model_checkpoint = "oza75/whisper-bambara-asr-002"
#revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
model_checkpoint = "oza75/whisper-bambara-asr-004"
# Pinning an exact revision keeps the demo stable if the model repo is updated.
revision = "bd0d0e8951879eb873d2f1ef278a61f7cb25d4a1"
# language = "bambara"
# NOTE(review): language is set to "icelandic" although the checkpoint name and the
# UI describe Bambara — presumably the fine-tune reused Whisper's "icelandic"
# language token instead of the custom Bambara one; confirm against training code.
language = "icelandic"


# Load the custom tokenizer designed for Bambara and the ASR model
#tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)


def resample_audio(audio_path, target_sample_rate=16000):
    """
    Load an audio file from disk and resample it to the target rate if needed.

    Args:
        audio_path (str): Path to the audio file.
        target_sample_rate (int): Desired output sample rate in Hz (default 16000).

    Returns:
        tuple: (waveform tensor at the target rate, target sample rate).
    """
    signal, source_rate = torchaudio.load(audio_path)

    # Already at the requested rate — nothing to do.
    if source_rate == target_sample_rate:
        return signal, target_sample_rate

    transform = torchaudio.transforms.Resample(
        orig_freq=source_rate, new_freq=target_sample_rate
    )
    return transform(signal), target_sample_rate

@spaces.GPU()
def transcribe(audio):
    """
    Run the configured ASR pipeline on an audio file and return the text.

    Args:
        audio: Filesystem path of the audio clip to transcribe.

    Returns:
        str: The transcribed text.
    """
    # The model expects 16 kHz input, so resample before inference.
    waveform, sample_rate = resample_audio(audio)

    sample = {"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate}
    return pipe(sample)["text"]

def get_wav_files(directory):
    """
    Return the absolute paths of all .wav files directly inside a directory.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: Sorted list of absolute paths to the .wav files. The list is
        sorted because os.listdir order is arbitrary; without sorting, the
        "first example" shown as the UI default would change between runs.
    """
    wav_files = [
        os.path.abspath(os.path.join(directory, name))
        # sorted() makes the result deterministic across filesystems/runs
        for name in sorted(os.listdir(directory))
        if name.endswith('.wav')
    ]
    return wav_files

def main():
    """Build and launch the Gradio demo for Bambara speech recognition."""
    # Collect example clips shipped with the app; this list may be empty
    # if ./examples is missing .wav files.
    example_files = get_wav_files("./examples")

    # Setup Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        # Guard against an empty examples directory: indexing [0] on an
        # empty list would raise IndexError and crash the app at startup.
        inputs=gr.Audio(type="filepath", value=example_files[0] if example_files else None),
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
        # Gradio expects None (not []) when there are no examples to show.
        examples=example_files or None,
        cache_examples="lazy",
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()