File size: 5,193 Bytes
4273f74
 
54440ac
 
6673c70
66b8805
54440ac
4273f74
 
564acd4
54440ac
 
 
 
 
964c1a5
4adc8b9
3bff29a
 
 
2968589
4b7fd50
7c907b8
b13531b
84d46a9
6dcf9ee
de130d1
6dcf9ee
0484c17
964c1a5
 
3bff29a
df2df62
964c1a5
564acd4
b1b5d1d
2cacb94
b1b5d1d
 
688f0fe
25354af
9c1d9a6
 
3b2d585
54440ac
 
df2df62
3b2d585
812710c
54440ac
ad4bd3d
54440ac
66b8805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54440ac
3b2d585
54440ac
 
 
 
 
 
 
 
 
66b8805
 
 
3b2d585
 
54440ac
d9fff62
4c75eaa
f49dcf2
e355c84
f49dcf2
 
4c75eaa
 
 
54440ac
4273f74
 
 
 
 
 
 
 
 
 
 
 
 
d9fff62
3b2d585
d9fff62
4273f74
54440ac
 
4273f74
0b8d1b9
4273f74
54440ac
 
 
2968589
d9fff62
25354af
3b2d585
402c03f
2968589
54440ac
 
4273f74
0b8d1b9
 
54440ac
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os

import spaces
import torch
from transformers import pipeline, WhisperTokenizer
import torchaudio
import gradio as gr
# Please note that the below import will override whisper LANGUAGES to add bambara
# this is not the best way to do it but at least it works. for more info check the bambara_utils code
from bambara_utils import BambaraWhisperTokenizer

# Determine the appropriate device (GPU or CPU) for model inference.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the model checkpoint and language.
# NOTE(review): the commented lines below are kept as a changelog of previously
# tried checkpoints/revisions and what was observed with each.
# model_checkpoint = "oza75/whisper-bambara-asr-002" # first model
# revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
# model_checkpoint = "oza75/whisper-bambara-asr-005"
# revision = "6a92cd0f19985d12739c2f6864607627115e015d" # first good checkpoint for bambara

#revision = "fb69a5750182933868397543366dbb63747cf40c" # this only translate in english
#revision = "129f9e68ead6cc854e7754b737b93aa78e0e61e1" # support transcription and translation
#revision = "cb8e351b35d6dc524066679d9646f4a947300b27"
#revision = "5f143f6070b64412a44fea08e912e1b7312e9ae9" # this checkpoint support both task without overfitting

#model_checkpoint = "oza75/whisper-bambara-asr-006"
#revision = "96535debb4ce0b7af7c9c186d09d088825f63840"
#revision = "4549778c08f29ed2e033cc9a497a187488b6bf56"

# model_checkpoint = "oza75/bm-whisper-02"
# revision = "06e81aa0214f6d07d3d787b367e3e8357b171549"

# language = "bambara"
# language = "icelandic" # we use icelandic as the model was trained to replace the icelandic with bambara.

#model_checkpoint = "oza75/bm-whisper-from-swa-02"
# None -> load the default (latest) revision of the checkpoint below.
revision = None
#language = "swahili"

#model_checkpoint = "oza75/bm-whisper-large-turbo-v4"
#model_checkpoint = "oza75/bm-whisper-large-v3-base"
# model_checkpoint = "oza75/bm-whisper-large-v3-sft-3"
# Currently active checkpoint.
model_checkpoint = "djelia/bm-whisper-large-v2-lora-merged"
# language = "sundanese"

# Load the custom tokenizer designed for Bambara and the ASR model.
# The plain WhisperTokenizer is used here (the Bambara-specific one is kept
# commented for reference).
#tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, device=device)
pipe = pipeline("automatic-speech-recognition", model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)

# Maps the UI language choice to the language token passed to the model.
# NOTE(review): "bambara" deliberately maps to "french" — presumably the model
# was trained reusing the French language token for Bambara; confirm against
# the training setup.
LANGUAGES = {"bambara": "french", "french": "french", "english": "english"}

def resample_audio(audio_path, target_sample_rate=16000):
    """
    Load an audio file and resample it to the target sampling rate (16000 Hz).

    Args:
        audio_path (str): Path to the audio file on disk.
        target_sample_rate (int): Sample rate the returned waveform should have.

    Returns:
        A (waveform, sample_rate) tuple; sample_rate equals target_sample_rate.
    """
    waveform, source_rate = torchaudio.load(audio_path)

    # Fast path: nothing to do when the file is already at the target rate.
    if source_rate == target_sample_rate:
        return waveform, target_sample_rate

    resampler = torchaudio.transforms.Resample(
        orig_freq=source_rate, new_freq=target_sample_rate
    )
    return resampler(waveform), target_sample_rate

@spaces.GPU()
def transcribe(audio, task_type, language):
    """
    Transcribe (or translate) the given audio file using the ASR pipeline.

    Args:
        audio: Path to the audio file to transcribe.
        task_type: Generation task, "transcribe" or "translate".
        language: UI language key, mapped through LANGUAGES to the model language.

    Returns:
        A string with the transcribed (or translated) text.
    """
    # The Whisper pipeline expects 16 kHz audio, so normalize the rate first.
    waveform, sample_rate = resample_audio(audio)

    # Translate the UI language choice into the token the model understands.
    model_language = LANGUAGES[language]

    # Feed the raw samples plus their rate to the pipeline.
    payload = {"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate}
    generation_options = {
        "task": task_type,
        "num_beams": 1,
        "early_stopping": True,
        "language": model_language,
    }
    return pipe(payload, generate_kwargs=generation_options)["text"]

def get_wav_files(directory):
    """
    Build Gradio example rows for every .wav file in the given directory.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: One [absolute_path, "transcribe", "bambara"] row per .wav file.
    """
    # Keep only .wav entries, resolved to absolute paths.
    wav_paths = (
        os.path.abspath(os.path.join(directory, entry))
        for entry in os.listdir(directory)
        if entry.endswith('.wav')
    )
    # Pair each file with the default task and language used by the demo UI.
    return [[path, "transcribe", "bambara"] for path in wav_paths]

def main():
    """Build and launch the Gradio demo for Bambara speech recognition."""
    # Get a list of [path, task, language] example rows from ./examples.
    example_files = get_wav_files("./examples")

    # Guard against an empty examples directory: the original indexed
    # example_files[0][0] unconditionally, which raises IndexError when no
    # .wav files are present.
    first_audio = example_files[0][0] if example_files else None

    # Setup Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(type="filepath", value=first_audio),
            gr.Radio(choices=["transcribe", "translate"], label="Task Type", value="transcribe"),
            # Pass a concrete list: gr.Dropdown expects an indexable sequence
            # of choices, not a dict_keys view.
            gr.Dropdown(choices=list(LANGUAGES.keys()), label="Language", value="bambara"),
        ],
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
        # Gradio treats None as "no examples"; an empty list would render an
        # empty examples widget.
        examples=example_files or None,
        cache_examples="lazy",
    )

    # Launch the interface
    iface.launch(share=False)


# Launch the demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()