import gradio as gr
import torch
import librosa
from transformers import Wav2Vec2Processor, AutoModelForCTC
import zipfile
import os
import firebase_admin
from firebase_admin import credentials, firestore
from datetime import datetime
import json

# Initialize Firebase from the service-account JSON stored in the
# 'firebase_creds' environment variable (fails fast if it is missing)
firebase_config = json.loads(os.environ["firebase_creds"])
cred = credentials.Certificate(firebase_config)
firebase_admin.initialize_app(cred)
db = firestore.client()

# Load the ASR model and processor
MODEL_NAME = "eleferrand/xlsr53_Amis"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)


def transcribe(audio_file):
    """
    Transcribe the audio file with the loaded ASR model.
    Returns the transcription string, or an error message on failure.
    """
    try:
        # Load and resample the audio to 16 kHz
        audio, _ = librosa.load(audio_file, sr=16000)
        # Prepare the input tensor for the model
        input_values = processor(
            audio, sampling_rate=16000, return_tensors="pt"
        ).input_values
        # Run the model and decode the argmax token ids to text
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.replace("[UNK]", "")
    except Exception as e:
        return f"Error processing file: {e}"


def transcribe_both(audio_file):
    """
    Transcribe the audio and return:
      - the original transcription (for the non-editable textbox),
      - the same transcription (pre-filling the editable textbox), and
      - the processing time in seconds.
    """
    start_time = datetime.now()
    transcription = transcribe(audio_file)
    processing_time = (datetime.now() - start_time).total_seconds()
    return transcription, transcription, processing_time


def store_correction(original_transcription, corrected_transcription, audio_file, processing_time):
    """
    Store the transcriptions and additional metadata in Firestore.
    Saves:
      - original and corrected text,
      - timestamp,
      - processing time,
      - audio metadata (duration and file size, if available),
      - a placeholder for the audio URL, and
      - the model name.
    """
    try:
        audio_metadata = {}
        if audio_file and os.path.exists(audio_file):
            # Load the audio only to compute its duration
            audio, sr = librosa.load(audio_file, sr=16000)
            audio_metadata = {
                'duration': librosa.get_duration(y=audio, sr=sr),
                'file_size': os.path.getsize(audio_file)
            }
        correction_data = {
            'original_text': original_transcription,
            'corrected_text': corrected_transcription,
            'timestamp': datetime.now().isoformat(),
            'processing_time': processing_time,
            'audio_metadata': audio_metadata,
            'audio_url': None,
            'model_name': MODEL_NAME
        }
        db.collection('transcription_corrections').add(correction_data)
        return "Correction saved successfully!"
    except Exception as e:
        return f"Error saving correction: {e}"


def prepare_download(audio_file, original_transcription, corrected_transcription):
    """
    Prepare a ZIP file containing:
      - the uploaded audio file (renamed to audio.wav),
      - a text file with the original transcription, and
      - a text file with the corrected transcription.
    Returns the ZIP file's path, or None if no audio was provided.
    """
    if audio_file is None:
        return None

    zip_filename = "results.zip"
    with zipfile.ZipFile(zip_filename, "w") as zf:
        # Add the audio file (renamed inside the zip)
        if os.path.exists(audio_file):
            zf.write(audio_file, arcname="audio.wav")
        else:
            print("Audio file not found:", audio_file)

        # Add the original transcription as a text file
        orig_txt = "original_transcription.txt"
        with open(orig_txt, "w", encoding="utf-8") as f:
            f.write(original_transcription)
        zf.write(orig_txt, arcname="original_transcription.txt")
        os.remove(orig_txt)

        # Add the corrected transcription as a text file
        corr_txt = "corrected_transcription.txt"
        with open(corr_txt, "w", encoding="utf-8") as f:
            f.write(corrected_transcription)
        zf.write(corr_txt, arcname="corrected_transcription.txt")
        os.remove(corr_txt)
    return zip_filename
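
# Quick smoke test of the pipeline outside the UI. "sample.wav" is a
# hypothetical path used for illustration; uncomment to run locally.
#   original, edited, secs = transcribe_both("sample.wav")
#   print(store_correction(original, edited, "sample.wav", secs))
#   print(prepare_download("sample.wav", original, edited))
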
""" if audio_file is None: return None zip_filename = "results.zip" with zipfile.ZipFile(zip_filename, "w") as zf: # Add the audio file (renamed inside the zip) if os.path.exists(audio_file): zf.write(audio_file, arcname="audio.wav") else: print("Audio file not found:", audio_file) # Add the original transcription as a text file orig_txt = "original_transcription.txt" with open(orig_txt, "w", encoding="utf-8") as f: f.write(original_transcription) zf.write(orig_txt, arcname="original_transcription.txt") os.remove(orig_txt) # Add the corrected transcription as a text file corr_txt = "corrected_transcription.txt" with open(corr_txt, "w", encoding="utf-8") as f: f.write(corrected_transcription) zf.write(corr_txt, arcname="corrected_transcription.txt") os.remove(corr_txt) return zip_filename # Build the Gradio Blocks interface with improved styling with gr.Blocks(css=""" .container { max-width: 800px; margin: auto; } .title { text-align: center; } """) as demo: with gr.Column(elem_classes="container"): gr.Markdown("