import random

import gradio as gr
from datasets import Audio, load_dataset
from jiwer import wer, cer
from transformers import pipeline

from arabic_normalizer import ArabicTextNormalizer
# Load the Arabic subset of Common Voice 11 and keep only the columns we use.
common_voice = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    name="ar",
    split="train",
    trust_remote_code=True,
)
common_voice = common_voice.select_columns(["audio", "sentence"])

# Whisper models expect 16 kHz audio; cast once here instead of on every call.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
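# NOTE: the full Arabic "train" split is sizeable. If download time or disk
# space is a concern, streaming is one alternative (a sketch, assuming random
# access is not required; streamed datasets are iterated, not indexed):
# common_voice = load_dataset("mozilla-foundation/common_voice_11_0", name="ar",
#                             split="train", streaming=True, trust_remote_code=True)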
# Decoding options shared by all three Whisper pipelines.
generate_kwargs = {
    "language": "arabic",
    "task": "transcribe",
}

# Initialize the ASR pipelines under comparison.
asr_whisper_large = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    generate_kwargs=generate_kwargs,
)
asr_whisper_large_turbo = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    generate_kwargs=generate_kwargs,
)
asr_whisper_large_turbo_mboushaba = pipeline(
    "automatic-speech-recognition",
    model="mboushaba/whisper-large-v3-turbo-arabic",
    generate_kwargs=generate_kwargs,
)
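# NOTE: on GPU hardware, placing the pipelines on the accelerator and using
# half precision speeds inference up considerably (a sketch; the device index
# and dtype are assumptions about the host, not part of the original app):
# import torch
# asr_whisper_large = pipeline("automatic-speech-recognition",
#                              model="openai/whisper-large-v3",
#                              generate_kwargs=generate_kwargs,
#                              device=0, torch_dtype=torch.float16)
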
normalizer = ArabicTextNormalizer()


def generate_audio():
    """Pick a random dataset sample and transcribe it with all three ASR models."""
    # Draw a random index directly; shuffling the entire dataset just to take
    # its first element would be far more expensive.
    example = common_voice[random.randint(0, len(common_voice) - 1)]
    audio = example["audio"]

    # Normalized ground-truth transcription (the reference for WER/CER).
    reference_text = normalizer(example["sentence"])
    # The ASR pipeline pops keys out of dict inputs as it preprocesses them,
    # so each model gets its own shallow copy of this audio payload.
    audio_input = {
        "raw": audio["array"],
        "sampling_rate": audio["sampling_rate"],
    }
    # Run ASR directly on the resampled audio array.
    asr_output = asr_whisper_large(dict(audio_input))
    asr_output_turbo = asr_whisper_large_turbo(dict(audio_input))
    asr_output_turbo_mboushaba = asr_whisper_large_turbo_mboushaba(dict(audio_input))
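    # NOTE: Common Voice clips are short, but for longer recordings these
    # pipelines support chunked long-form decoding (a sketch; the 30-second
    # chunk length is an untuned assumption):
    # asr_output = asr_whisper_large(dict(audio_input), chunk_length_s=30)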

    # Extract and normalize the transcription from each model's output.
    predicted_text = normalizer(asr_output["text"])
    predicted_text_turbo = normalizer(asr_output_turbo["text"])
    predicted_text_turbo_mboushaba = normalizer(asr_output_turbo_mboushaba["text"])

    # Compute WER and CER for each model against the reference.
    wer_score = wer(reference_text, predicted_text)
    cer_score = cer(reference_text, predicted_text)
    wer_score_turbo = wer(reference_text, predicted_text_turbo)
    cer_score_turbo = cer(reference_text, predicted_text_turbo)
    wer_score_turbo_mboushaba = wer(reference_text, predicted_text_turbo_mboushaba)
    cer_score_turbo_mboushaba = cer(reference_text, predicted_text_turbo_mboushaba)
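    # Background: jiwer's wer() is the word-level edit-distance ratio
    # WER = (S + D + I) / N, with S, D, I the substitutions, deletions and
    # insertions against the N reference words; cer() is the same ratio over
    # characters. Illustration (not taken from the dataset):
    #   wer("كتب الولد", "كتب ولد")  # 1 substitution / 2 words -> 0.5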

    # Display label: the reference sentence plus the sampling rate in Hz.
    sentence_info = f"{reference_text} - {audio['sampling_rate']} Hz"

    return {
        "audio": (audio["sampling_rate"], audio["array"]),
        "sentence_info": sentence_info,
        "predicted_text": predicted_text,
        "wer_score": wer_score,
        "cer_score": cer_score,
        "predicted_text_turbo": predicted_text_turbo,
        "wer_score_turbo": wer_score_turbo,
        "cer_score_turbo": cer_score_turbo,
        "predicted_text_turbo_mboushaba": predicted_text_turbo_mboushaba,
        "wer_score_turbo_mboushaba": wer_score_turbo_mboushaba,
        "cer_score_turbo_mboushaba": cer_score_turbo_mboushaba,
    }


def update_ui():
    """Currently unused helper; builds four placeholder textboxes."""
    return [gr.Textbox(label=f"Label {i}") for i in range(4)]


with gr.Blocks() as demo:
    gr.HTML("""<h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
    gr.Markdown("""
    This demo compares the transcriptions, WER, and CER of three ASR models
    (Whisper large-v3, Whisper large-v3-turbo, and a fine-tuned Arabic
    Whisper large-v3-turbo) on the Arabic split of the
    mozilla-foundation/common_voice_11_0 dataset.
    """)

    num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4,
                                  label="Number of audio samples")
    generate_button = gr.Button("Generate Samples")

    @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
    def render(num_samples):
        with gr.Column():
            for _ in range(num_samples):
                # Generate one sample plus its transcriptions and metrics.
                data = generate_audio()
                gr.Audio(data["audio"], label=data["sentence_info"])
                with gr.Row():
                    with gr.Column():
                        gr.Textbox(value=data["predicted_text"],
                                   label="Whisper large output")
                        gr.Textbox(value=f"WER: {data['wer_score']:.2f}",
                                   label="Word Error Rate")
                        gr.Textbox(value=f"CER: {data['cer_score']:.2f}",
                                   label="Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value=data["predicted_text_turbo"],
                                   label="Whisper large turbo output")
                        gr.Textbox(value=f"WER: {data['wer_score_turbo']:.2f}",
                                   label="Word Error Rate - TURBO")
                        gr.Textbox(value=f"CER: {data['cer_score_turbo']:.2f}",
                                   label="Character Error Rate - TURBO")
                    with gr.Column():
                        gr.Textbox(value=data["predicted_text_turbo_mboushaba"],
                                   label="Whisper large turbo mboushaba output")
                        gr.Textbox(value=f"WER: {data['wer_score_turbo_mboushaba']:.2f}",
                                   label="Word Error Rate - mboushaba TURBO")
                        gr.Textbox(value=f"CER: {data['cer_score_turbo_mboushaba']:.2f}",
                                   label="Character Error Rate - mboushaba TURBO")
demo.launch(show_error=True)