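"""Gradio Space for Bambara automatic speech recognition (ASR).

Loads a Whisper checkpoint fine-tuned for Bambara and exposes it through a
Gradio interface: uploaded audio is resampled to 16 kHz and passed to a
transformers ASR pipeline.
"""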
import os
import spaces
import torch
from transformers import pipeline, WhisperTokenizer
import torchaudio
import gradio as gr
# Note: the import below overrides Whisper's LANGUAGES mapping to add Bambara.
# This is not the cleanest approach, but it works; see bambara_utils for details.
from bambara_utils import BambaraWhisperTokenizer
# Determine the appropriate device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Define the model checkpoint and language
# model_checkpoint = "oza75/whisper-bambara-asr-002"
# revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
# model_checkpoint = "oza75/whisper-bambara-asr-005"
# revision = "6a92cd0f19985d12739c2f6864607627115e015d"  # first good checkpoint for Bambara
# revision = "fb69a5750182933868397543366dbb63747cf40c"  # only translates to English
# revision = "129f9e68ead6cc854e7754b737b93aa78e0e61e1"  # supports transcription and translation
# revision = "cb8e351b35d6dc524066679d9646f4a947300b27"
# revision = "5f143f6070b64412a44fea08e912e1b7312e9ae9"  # supports both tasks without overfitting
# model_checkpoint = "oza75/whisper-bambara-asr-006"
# revision = "96535debb4ce0b7af7c9c186d09d088825f63840"
# revision = "4549778c08f29ed2e033cc9a497a187488b6bf56"
model_checkpoint = "oza75/bm-whisper-02"
revision = "06e81aa0214f6d07d3d787b367e3e8357b171549"
# language = "bambara"
language = "icelandic"  # the model was trained with the Icelandic language token repurposed for Bambara
# Load the custom tokenizer designed for Bambara and the ASR model
# tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language)
# Tokenizers run on the CPU, so no device argument is needed here
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language=language)
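# Build the ASR pipeline with the custom tokenizer, pinned to a specific model revision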
pipe = pipeline("automatic-speech-recognition", model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)
def resample_audio(audio_path, target_sample_rate=16000):
    """
    Resamples an audio file to the target sample rate (16000 Hz by default).

    Args:
        audio_path (str): Path to the audio file.
        target_sample_rate (int): The desired sample rate.

    Returns:
        A tuple of the resampled waveform tensor and the target sample rate.
    """
    waveform, original_sample_rate = torchaudio.load(audio_path)
    if original_sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    return waveform, target_sample_rate
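# Usage sketch (hypothetical path): resample_audio returns a (channels, frames)
# tensor at 16 kHz, e.g. waveform, sr = resample_audio("./examples/clip.wav")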
@spaces.GPU()
def transcribe(audio, task_type):
    """
    Transcribes the provided audio file into text using the configured ASR pipeline.

    Args:
        audio: The path to the audio file to transcribe.
        task_type: The generation task, e.g. "transcribe".

    Returns:
        A string containing the transcribed text.
    """
    # Resample the audio to 16000 Hz
    waveform, sample_rate = resample_audio(audio)
    # Mix multi-channel audio down to the mono 1-D array the pipeline expects
    waveform = waveform.mean(dim=0)
    # Use the pipeline to perform transcription
    sample = {"array": waveform.numpy(), "sampling_rate": sample_rate}
    text = pipe(sample, generate_kwargs={"task": task_type, "language": language})["text"]
    return text
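# Example call outside Gradio (hypothetical file):
#   text = transcribe("./examples/sample.wav", "transcribe")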
def get_wav_files(directory):
    """
    Returns Gradio example entries for all .wav files in the specified directory.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: A list of [absolute_path, "transcribe"] pairs, one per .wav file.
    """
    # List all files in the directory
    files = os.listdir(directory)
    # Filter for .wav files and create absolute paths
    wav_files = [os.path.abspath(os.path.join(directory, file)) for file in files if file.endswith('.wav')]
    # Pair each file with the default task so it can be used as a Gradio example
    wav_files = [[f, "transcribe"] for f in wav_files]
    return wav_files
def main():
    # Get a list of all .wav files in the examples directory
    example_files = get_wav_files("./examples")

    # Set up the Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(type="filepath", value=example_files[0][0]),
            gr.Radio(choices=["transcribe"], label="Task Type", value="transcribe")
        ],
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Real-time demo for Bambara speech recognition, based on a fine-tuned Whisper model.",
        examples=example_files,
        cache_examples="lazy",
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()