# Hugging Face Space (status when this file was captured: Runtime error)
import re | |
import gradio as gr | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline | |
import torch | |
import numpy as np | |
# Speech-to-text: the Whisper checkpoint supplies both the processor
# (feature extraction + token decoding) and the generation model.
whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(whisper_model_name)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

# Language detection is done with a zero-shot classification pipeline.
lang_detect_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# Helper: dependency-free resampling used by transcribe_audio
def _resample_linear(audio, source_rate, target_rate):
    """Resample 1-D float *audio* from *source_rate* to *target_rate* Hz.

    Uses plain linear interpolation so no extra dependency is needed.
    NOTE: there is no anti-aliasing filter; swap in a proper resampler if
    transcription quality on high-sample-rate input matters.
    """
    if source_rate == target_rate or audio.size == 0:
        return audio
    target_len = max(1, int(round(audio.size * target_rate / source_rate)))
    source_times = np.arange(audio.size) / source_rate
    target_times = np.arange(target_len) / target_rate
    return np.interp(target_times, source_times, audio).astype(np.float32)


# Function to transcribe audio to text using Whisper model
def transcribe_audio(audio_file):
    """Transcribe audio to text with the module-level Whisper model.

    Args:
        audio_file: One of
            * a ``(sample_rate, samples)`` tuple, the form
              ``gr.Audio(type="numpy")`` hands to its callback;
            * a list of sample arrays, concatenated in order;
            * a single array-like of samples.
            Inputs without an explicit sample rate are assumed to already
            be at 16 kHz, as the original implementation assumed.

    Returns:
        The decoded transcription string.

    Raises:
        ValueError: If no audio was provided, it holds no samples, or the
            reported sample rate is not positive.
    """
    target_rate = 16000  # sampling rate Whisper's processor is called with

    if audio_file is None:
        raise ValueError("No audio provided.")

    # Recognise the (sample_rate, samples) form: scalar first, sequence second.
    has_rate = (
        isinstance(audio_file, tuple)
        and len(audio_file) == 2
        and np.ndim(audio_file[0]) == 0
        and np.ndim(audio_file[1]) >= 1
    )
    if has_rate:
        sample_rate = int(audio_file[0])
        samples = audio_file[1]
    else:
        sample_rate = target_rate
        samples = audio_file

    if sample_rate <= 0:
        raise ValueError("Sample rate must be positive.")

    # A list is treated as several clips to be joined end to end.
    if isinstance(samples, list):
        audio = np.concatenate(samples) if samples else np.empty(0)
    else:
        audio = np.asarray(samples)

    if audio.size == 0:
        raise ValueError("No audio provided.")

    # Integer PCM must be scaled to floats in [-1, 1]; passing raw int16
    # amplitudes through would distort the extracted features.
    if np.issubdtype(audio.dtype, np.integer):
        full_scale = float(np.iinfo(audio.dtype).max)
        audio = audio.astype(np.float32) / full_scale
    else:
        audio = audio.astype(np.float32)

    if audio.ndim > 1:
        if has_rate and audio.ndim == 2:
            # Multi-channel recording laid out as (samples, channels) per the
            # Gradio docs: average the channels. Flattening would interleave
            # left/right samples and garble the waveform.
            audio = audio.mean(axis=1)
        else:
            # No layout information available: keep the original behaviour.
            audio = audio.flatten()

    audio = _resample_linear(audio, sample_rate, target_rate)

    # Prepare input features for Whisper and generate the transcription.
    input_features = processor(audio, return_tensors="pt", sampling_rate=target_rate)
    generated_ids = model.generate(input_features["input_features"])
    transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
    return transcription
# Function to detect the language of the transcription using zero-shot classification
def detect_language(text):
    """Classify *text* against a fixed set of language codes.

    Runs the module-level zero-shot classification pipeline and returns the
    first label together with the first score from its result (presumably
    the highest-ranked pair; confirm against the pipeline documentation).

    Args:
        text: The transcription to classify.

    Returns:
        A ``(language_code, score)`` tuple.
    """
    language_codes = ["en", "fr", "es", "de", "it", "pt", "zh", "ja", "ar", "hi"]
    prediction = lang_detect_model(text, candidate_labels=language_codes)
    top_label = prediction['labels'][0]
    top_score = prediction['scores'][0]
    return top_label, top_score
# Cleanup function to remove filler words and clean the transcription
# Filler words/phrases, plus a comma that directly follows one ("um, hello").
_FILLER_PATTERN = re.compile(
    r'\b(?:uh|um|like|you know|so|actually|basically)\b,?',
    flags=re.IGNORECASE,
)
_WHITESPACE_RUN = re.compile(r'\s+')


def cleanup_text(text):
    """Strip filler words from *text* and tidy spacing and capitalisation.

    Steps: remove the filler words/phrases (case-insensitive, whole words
    only) together with a comma immediately after them, collapse whitespace
    runs to single spaces, trim both ends, and upper-case the first character.

    NOTE: "like" and "so" are removed wherever they appear as whole words,
    including uses where they carry meaning ("I like it").

    Args:
        text: Raw transcription.

    Returns:
        The cleaned string; empty if nothing remains.
    """
    text = _FILLER_PATTERN.sub('', text)
    text = _WHITESPACE_RUN.sub(' ', text).strip()
    # str.capitalize() would lower-case the rest of the sentence (proper
    # nouns, "I"); only the first character should change.
    return text[:1].upper() + text[1:]
# Main function to process the audio and detect language
def process_audio(audio_file):
    """Transcribe a recording, detect its language, and clean up the text.

    This is the UI callback, so it never raises: any failure is reported as
    an ``"Error: ..."`` string in the first output slot.

    Args:
        audio_file: Value supplied by the audio input component. ``None``
            (nothing recorded or uploaded) is reported as an error without
            running the models.

    Returns:
        ``(cleaned_text, language, score)`` on success, or
        ``("Error: <message>", "", "")`` on failure.
    """
    # Fail fast with a readable message instead of an opaque model error.
    if audio_file is None:
        return "Error: No audio provided.", "", ""

    try:
        transcription = transcribe_audio(audio_file)  # Transcribe audio to text
        if not transcription.strip():  # Empty or whitespace-only transcription
            raise ValueError("Transcription is empty.")
        # Language is detected on the raw transcription, before cleanup.
        lang, score = detect_language(transcription)
        cleaned_text = cleanup_text(transcription)
    except Exception as e:
        # Top-level UI boundary: surface the message rather than crash the app.
        return f"Error: {str(e)}", "", ""
    return cleaned_text, lang, score
# Gradio interface
# NOTE(review): the nesting below is reconstructed; the original indentation
# was lost, so confirm the intended Row/Column layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Audio input, delivered to the callback in numpy form
            audio_input = gr.Audio(
                label="Record your voice",
                type="numpy",
                scale=1,
            )
            # One text box per value returned by process_audio
            output_text = gr.Textbox(label="Transcription", scale=1)
            output_lang = gr.Textbox(label="Detected Language", scale=1)
            output_score = gr.Textbox(label="Confidence Score", scale=1)
            # Button that triggers processing
            process_btn = gr.Button("Process Audio")

    # Event wiring must happen inside the Blocks context.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[output_text, output_lang, output_score],
    )

demo.launch(debug=True)