|
import gradio as gr |
|
from datasets import Audio |
|
from datasets import load_dataset |
|
from jiwer import wer, cer |
|
from transformers import pipeline |
|
|
|
from arabic_normalizer import ArabicTextNormalizer |
|
|
|
|
|
# Arabic "train" split of Common Voice 11, narrowed to the two columns this
# script reads: the audio clip and its reference sentence.
# NOTE(review): trust_remote_code = True presumably lets the dataset's own
# loading script run locally -- confirm this is acceptable for the deployment.
common_voice = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    name = "ar",
    split = "train",
    trust_remote_code = True,
).select_columns(["audio", "sentence"])

# Generation options handed to every Whisper pipeline below (the same dict
# object is passed to all three).
generate_kwargs = dict(language = "arabic", task = "transcribe")


def _load_asr(model_id):
    """Build a speech-recognition pipeline for ``model_id`` using ``generate_kwargs``."""
    return pipeline(
        "automatic-speech-recognition",
        model = model_id,
        generate_kwargs = generate_kwargs,
    )


# The three models being compared; all are loaded eagerly at import time.
asr_whisper_large = _load_asr("openai/whisper-large-v3")
asr_whisper_large_turbo = _load_asr("openai/whisper-large-v3-turbo")
asr_whisper_large_turbo_mboushaba = _load_asr("mboushaba/whisper-large-v3-turbo-arabic")

# Applied to both reference sentences and model outputs before scoring.
normalizer = ArabicTextNormalizer()
|
|
|
|
|
# Sampling rate the audio column is resampled to before transcription.
_TARGET_SAMPLING_RATE = 16000


def _transcribe_and_score(asr, audio, reference_text):
    """Transcribe ``audio`` with ``asr`` and score it against ``reference_text``.

    Args:
        asr: One of the module-level speech-recognition pipelines.
        audio: Decoded audio dict with ``"array"`` and ``"sampling_rate"`` keys.
        reference_text: Normalized reference sentence.

    Returns:
        ``(predicted_text, wer_score, cer_score)`` where ``predicted_text`` is
        the normalized transcription.
    """
    # A fresh input dict is built for every call so nothing is shared between
    # models (NOTE(review): the pipeline presumably consumes keys of the dict
    # it is given -- confirm). "raw" is used for all three models; the
    # original mixed "array" and "raw".
    output = asr({"raw": audio["array"], "sampling_rate": audio["sampling_rate"]})
    predicted_text = normalizer(output["text"])
    return (
        predicted_text,
        wer(reference_text, predicted_text),
        cer(reference_text, predicted_text),
    )


def generate_audio(index = None):
    """Select an audio sample, resample if needed, and transcribe it with all three ASR models.

    Args:
        index: Optional row position in ``common_voice``. When ``None`` (the
            default) a row is picked at random.

    Returns:
        A dict with the keys:
            ``"audio"``: ``(sampling_rate, array)`` tuple of the selected clip.
            ``"sentence_info"``: normalized reference sentence and sampling
                rate joined by ``"-"``.
            ``"predicted_text"``, ``"wer_score"``, ``"cer_score"``: normalized
                transcription and its error rates for ``asr_whisper_large``.
            ``"predicted_text_turbo"``, ``"wer_score_turbo"``,
                ``"cer_score_turbo"``: the same for ``asr_whisper_large_turbo``.
            ``"predicted_text_turbo_mboushaba"``,
                ``"wer_score_turbo_mboushaba"``,
                ``"cer_score_turbo_mboushaba"``: the same for
                ``asr_whisper_large_turbo_mboushaba``.
    """
    import random

    global common_voice

    # Re-cast the audio column only when it is not already at the target
    # rate, instead of rebuilding the dataset object on every call.
    if common_voice.features["audio"].sampling_rate != _TARGET_SAMPLING_RATE:
        common_voice = common_voice.cast_column(
            "audio", Audio(sampling_rate = _TARGET_SAMPLING_RATE)
        )

    # Honour an explicit index; otherwise draw one random row rather than
    # shuffling the whole split just to read its first element.
    if index is None:
        index = random.randrange(len(common_voice))
    example = common_voice[index]
    audio = example["audio"]

    reference_text = normalizer(example["sentence"])

    predicted_text, wer_score, cer_score = _transcribe_and_score(
        asr_whisper_large, audio, reference_text
    )
    predicted_text_turbo, wer_score_turbo, cer_score_turbo = _transcribe_and_score(
        asr_whisper_large_turbo, audio, reference_text
    )
    (
        predicted_text_turbo_mboushaba,
        wer_score_turbo_mboushaba,
        cer_score_turbo_mboushaba,
    ) = _transcribe_and_score(asr_whisper_large_turbo_mboushaba, audio, reference_text)

    sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])

    return {
        "audio": (audio["sampling_rate"], audio["array"]),
        "sentence_info": sentence_info,
        "predicted_text": predicted_text,
        "wer_score": wer_score,
        "cer_score": cer_score,
        "predicted_text_turbo": predicted_text_turbo,
        "wer_score_turbo": wer_score_turbo,
        "cer_score_turbo": cer_score_turbo,
        "predicted_text_turbo_mboushaba": predicted_text_turbo_mboushaba,
        "wer_score_turbo_mboushaba": wer_score_turbo_mboushaba,
        "cer_score_turbo_mboushaba": cer_score_turbo_mboushaba,
    }
|
|
|
|
|
def update_ui():
    """Create four textboxes labelled ``Label 0`` to ``Label 3``.

    Returns:
        A list of the four ``gr.Textbox`` components, in label order.
    """
    return [gr.Textbox(label = f"Label {position}") for position in range(4)]
|
|
|
|
|
# One entry per compared model: (suffix of the result keys produced by
# generate_audio, transcript label, WER label, CER label). Label text is kept
# verbatim, including its spacing.
_MODEL_COLUMNS = (
    ("", "Whisper large output", "Word Error Rate", "Character Error Rate"),
    (
        "_turbo",
        "Whisper large turbo output",
        "Word Error Rate - TURBO ",
        "Character Error Rate - TURBO",
    ),
    (
        "_turbo_mboushaba",
        "Whisper large turbo mboushaba output",
        "Word Error Rate -  mboushaba TURBO ",
        "Character Error Rate - mboushaba TURBO",
    ),
)

with gr.Blocks() as demo:
    gr.HTML("""
    <h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
    gr.Markdown("""
    This is a demo to compare the outputs, WER & CER of three ASR models (Whisper large, large turbo and
    mboushaba large turbo Arabic) using the Arabic dataset from mozilla-foundation/common_voice_11_0
    """)
    num_samples_input = gr.Slider(minimum = 1, maximum = 10, step = 1, value = 4, label = "Number of audio samples")
    generate_button = gr.Button("Generate Samples")

    @gr.render(inputs = num_samples_input, triggers = [generate_button.click])
    def render(num_samples):
        """Render ``num_samples`` transcribed samples, one column per model.

        Runs when ``generate_button`` is clicked. Each sample gets an audio
        player labelled with its reference sentence, followed by a row with
        one column per entry of ``_MODEL_COLUMNS`` showing that model's
        transcript, WER and CER.

        Args:
            num_samples: Current value of ``num_samples_input``.
        """
        with gr.Column():
            # int(): the slider value may arrive as a float, which range()
            # would reject.
            for _ in range(int(num_samples)):
                data = generate_audio()

                gr.Audio(data["audio"], label = data["sentence_info"])
                with gr.Row():
                    for suffix, text_label, wer_label, cer_label in _MODEL_COLUMNS:
                        transcript = data["predicted_text" + suffix]
                        wer_value = data["wer_score" + suffix]
                        cer_value = data["cer_score" + suffix]
                        with gr.Column():
                            gr.Textbox(value = transcript, label = text_label)
                            gr.Textbox(value = f"WER: {wer_value:.2f}", label = wer_label)
                            gr.Textbox(value = f"CER: {cer_value:.2f}", label = cer_label)

demo.launch(show_error = True)
|
|