# Source: app.py by Hammad712 (commit f19273b, "Update app.py", verified).
import gradio as gr
import torch
import librosa
import os
import uuid
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import Levenshtein
from pathlib import Path
# Build the Wav2Vec2 processor/model pair from a single pretrained checkpoint.
def load_model():
    """Load the pretrained Wav2Vec2 processor and CTC model.

    Both objects are created from the same checkpoint identifier,
    "jonatasgrosman/wav2vec2-large-xlsr-53-arabic".

    Returns:
        tuple: ``(processor, model)`` in that order.
    """
    checkpoint = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
    return (
        Wav2Vec2Processor.from_pretrained(checkpoint),
        Wav2Vec2ForCTC.from_pretrained(checkpoint),
    )
# Load once at import time; transcribe_audio() reads these module-level globals.
processor, model = load_model()
def save_audio(audio_data, folder="recorded_audios"):
    """Move an audio file into ``folder`` under a unique ``.wav`` name.

    The source file is moved, not copied, so ``audio_data`` no longer exists
    at its original location once this function returns.

    Args:
        audio_data (str): The file path of the audio file to move.
        folder (str): The directory where the audio file will be saved. It is
            created (including missing parents) if it does not exist.

    Returns:
        str: The file path of the saved audio file.

    Raises:
        OSError: If the source file cannot be moved (for example, it does
            not exist or the destination is not writable).
    """
    # Imported locally so this function carries its own dependency.
    import shutil

    # Ensure the folder exists
    Path(folder).mkdir(parents=True, exist_ok=True)

    # Generate a unique filename so repeated recordings never collide
    file_path = os.path.join(folder, f"{uuid.uuid4()}.wav")

    # shutil.move falls back to copy-and-delete when the source and the
    # destination live on different filesystems; a plain os.rename raises
    # OSError in that situation.
    shutil.move(audio_data, file_path)
    return file_path
def transcribe_audio(audio_file_path):
    """Return the text recognized in an audio file by the Wav2Vec2 model.

    Uses the module-level ``processor`` and ``model`` objects.

    Args:
        audio_file_path (str): Path to the audio file.

    Returns:
        str: The decoded transcription with surrounding whitespace stripped.
    """
    # Load the waveform, requesting a 16 kHz sampling rate.
    waveform, rate = librosa.load(audio_file_path, sr=16000)

    encoded = processor(
        waveform,
        sampling_rate=rate,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(encoded.input_values).logits

    # Pick the highest-scoring token id at each position, then decode.
    token_ids = torch.argmax(logits, dim=-1)
    decoded = processor.batch_decode(token_ids)
    return decoded[0].strip()
def levenshtein_similarity(transcription1, transcription2):
    """Calculate the Levenshtein similarity between two transcriptions.

    The edit distance is normalized by the length of the longer string, so
    the result is ``1 - distance / max(len(a), len(b))``.

    Args:
        transcription1 (str): The first transcription.
        transcription2 (str): The second transcription.

    Returns:
        float: A normalized similarity score between 0 and 1, where 1
            indicates identical transcriptions. Two empty transcriptions are
            treated as identical and score 1.0.
    """
    max_len = max(len(transcription1), len(transcription2))
    if max_len == 0:
        # Both strings are empty: they are identical, and dividing by the
        # zero length below would raise ZeroDivisionError.
        return 1.0

    distance = Levenshtein.distance(transcription1, transcription2)
    return 1 - distance / max_len  # Normalize to get similarity score
def evaluate_audio_similarity(original_audio_path, user_audio_path):
    """Transcribe two audio files and score how similar the transcriptions are.

    Args:
        original_audio_path (str): Path to the original audio file.
        user_audio_path (str): Path to the user's audio file.

    Returns:
        tuple: ``(original transcription, user transcription, similarity)``,
            where the similarity comes from ``levenshtein_similarity``.
    """
    # The original recording is transcribed first, then the user's.
    reference_text, attempt_text = (
        transcribe_audio(path)
        for path in (original_audio_path, user_audio_path)
    )
    score = levenshtein_similarity(reference_text, attempt_text)
    return reference_text, attempt_text, score
def perform_testing(original_audio, user_audio):
    """Compare an original recording with a user recording.

    Both inputs are stored with ``save_audio``, transcribed, and scored; the
    outcome is returned as a dictionary.

    Args:
        original_audio (str | None): File path of the original recording.
        user_audio (str | None): File path of the user's recording.

    Returns:
        dict: Either an ``"Error"`` entry when an input is missing, or the two
            transcriptions, the similarity score and a ``"Feedback"`` message.
    """
    # Debugging: report what was received for each input.
    print(
        "Original audio is None"
        if original_audio is None
        else f"Original audio path: {original_audio}"
    )
    print(
        "User audio is None"
        if user_audio is None
        else f"User audio path: {user_audio}"
    )

    # Guard clause: both recordings are required.
    if any(clip is None for clip in (original_audio, user_audio)):
        return {"Error": "Please provide both original and user audio."}

    # Save the recorded audio files
    saved_original = save_audio(original_audio)
    saved_user = save_audio(user_audio)

    reference_text, attempt_text, score = evaluate_audio_similarity(
        saved_original, saved_user
    )

    # Scores above 0.8 are reported as likely-correct pronunciation.
    if score > 0.8:
        feedback = "The pronunciation is likely correct based on transcription similarity."
    else:
        feedback = "The pronunciation may be incorrect based on transcription similarity."

    return {
        "Original Transcription": reference_text,
        "User Transcription": attempt_text,
        "Levenshtein Similarity Score": score,
        "Feedback": feedback,
    }
# Assemble the Gradio interface: two audio inputs, a JSON result, one button.
def gradio_app():
    """Build the Gradio Blocks interface for the similarity checker.

    The layout has a title, two audio inputs delivered as file paths, a JSON
    output panel, and a button that runs ``perform_testing`` on both inputs.

    Returns:
        gr.Blocks: The assembled (not yet launched) application.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Audio Transcription and Similarity Checker")

        reference_input = gr.Audio(label="Record Original Audio", type="filepath")
        attempt_input = gr.Audio(label="Record User Audio", type="filepath")
        json_view = gr.JSON(label="Output")

        # Clicking the button sends both recordings to perform_testing.
        run_button = gr.Button("Perform Testing")
        run_button.click(
            perform_testing,
            inputs=[reference_input, attempt_input],
            outputs=json_view,
        )
    return app
# Build the interface and launch the Gradio app. Both statements run at module
# level, so they execute whenever this file is run or imported.
demo = gradio_app()
demo.launch()