import gradio as gr
import torch
import librosa
from transformers import Wav2Vec2Processor, AutoModelForCTC
import zipfile
import os
import firebase_admin
from firebase_admin import credentials, firestore
from datetime import datetime
import json
# Initialize Firebase
firebase_config = json.loads(os.environ.get('firebase_creds'))
cred = credentials.Certificate(firebase_config)  # service-account credentials parsed from the env var
firebase_admin.initialize_app(cred)
db = firestore.client()
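# NOTE: this assumes a `firebase_creds` environment variable (e.g. a Hugging Face
# Spaces secret) holding the full service-account JSON as a string; startup fails
# if it is missing or malformed.
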
# Load the ASR model and processor
MODEL_NAME = "eleferrand/xlsr53_Amis"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)
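# The processor couples a feature extractor (raw 16 kHz waveform -> input_values)
# with a CTC tokenizer (predicted token ids -> text); both come from the same
# fine-tuned XLSR-53 checkpoint as the model.
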
def transcribe(audio_file):
    """
    Transcribes the audio file using the loaded ASR model.
    Returns the transcription string.
    """
    try:
        # Load and resample the audio to 16kHz
        audio, rate = librosa.load(audio_file, sr=16000)
        # Prepare the input tensor for the model
        input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
        # Get model predictions (logits) and decode to text
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.replace("[UNK]", "")
    except Exception as e:
        return f"Error processing file: {e}"
def transcribe_both(audio_file):
    """
    Transcribes the audio and returns:
    - the original transcription (for the non-editable textbox),
    - the transcription (pre-filled for the editable textbox), and
    - the processing time (in seconds).
    """
    start_time = datetime.now()
    transcription = transcribe(audio_file)
    processing_time = (datetime.now() - start_time).total_seconds()
    return transcription, transcription, processing_time
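# Returning the transcription twice lets the UI pre-fill the editable textbox
# while keeping a read-only copy, from a single model pass.
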
def store_correction(original_transcription, corrected_transcription, audio_file, processing_time):
    """
    Stores the transcriptions and additional metadata in Firestore.
    Saves:
    - original & corrected text,
    - timestamp,
    - processing time,
    - audio metadata (duration & file size, if available),
    - a placeholder for the audio URL, and
    - the model name.
    """
    try:
        audio_metadata = {}
        if audio_file and os.path.exists(audio_file):
            # Load audio for metadata calculations
            audio, sr = librosa.load(audio_file, sr=16000)
            duration = librosa.get_duration(y=audio, sr=sr)
            file_size = os.path.getsize(audio_file)
            audio_metadata = {
                'duration': duration,
                'file_size': file_size
            }
        correction_data = {
            'original_text': original_transcription,
            'corrected_text': corrected_transcription,
            'timestamp': datetime.now().isoformat(),
            'processing_time': processing_time,
            'audio_metadata': audio_metadata,
            'audio_url': None,
            'model_name': MODEL_NAME
        }
        db.collection('transcription_corrections').add(correction_data)
        return "Correction saved successfully!"
    except Exception as e:
        return f"Error saving correction: {e}"
def prepare_download(audio_file, original_transcription, corrected_transcription):
    """
    Prepares a ZIP file containing:
    - the uploaded audio file (as audio.wav),
    - a text file with the original transcription, and
    - a text file with the corrected transcription.
    Returns the ZIP file's path.
    """
    if audio_file is None:
        return None
    zip_filename = "results.zip"
    with zipfile.ZipFile(zip_filename, "w") as zf:
        # Add the audio file (renamed inside the zip)
        if os.path.exists(audio_file):
            zf.write(audio_file, arcname="audio.wav")
        else:
            print("Audio file not found:", audio_file)
        # Add the original transcription as a text file
        orig_txt = "original_transcription.txt"
        with open(orig_txt, "w", encoding="utf-8") as f:
            f.write(original_transcription)
        zf.write(orig_txt, arcname="original_transcription.txt")
        os.remove(orig_txt)
        # Add the corrected transcription as a text file
        corr_txt = "corrected_transcription.txt"
        with open(corr_txt, "w", encoding="utf-8") as f:
            f.write(corrected_transcription)
        zf.write(corr_txt, arcname="corrected_transcription.txt")
        os.remove(corr_txt)
    return zip_filename
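# Caveat: the ZIP and intermediate .txt files use fixed names in the current
# working directory, so concurrent users could overwrite each other's files;
# per-request paths via the tempfile module would avoid that.
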
# Build the Gradio Blocks interface with improved styling
with gr.Blocks(css="""
    .container { max-width: 800px; margin: auto; }
    .title { text-align: center; }
""") as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("<h1 class='title'>ASR Demo with Editable Transcription</h1>")
        with gr.Row():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio")
            transcribe_button = gr.Button("Transcribe Audio", variant="primary")
        with gr.Row():
            original_text = gr.Textbox(label="Original Transcription", interactive=False, lines=5)
            corrected_text = gr.Textbox(label="Corrected Transcription", interactive=True, lines=5)
        # Hidden state to hold processing time
        proc_time_state = gr.State()
        with gr.Row():
            save_button = gr.Button("Save Correction to Database", variant="primary")
            save_status = gr.Textbox(label="Save Status", interactive=False)
        with gr.Accordion("Download Options", open=False):
            with gr.Row():
                download_button = gr.Button("Download Results (ZIP)")
                download_output = gr.File(label="Download ZIP")
    # Set up actions
    transcribe_button.click(
        fn=transcribe_both,
        inputs=audio_input,
        outputs=[original_text, corrected_text, proc_time_state]
    )
    save_button.click(
        fn=store_correction,
        inputs=[original_text, corrected_text, audio_input, proc_time_state],
        outputs=save_status
    )
    download_button.click(
        fn=prepare_download,
        inputs=[audio_input, original_text, corrected_text],
        outputs=download_output
    )
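
# Note on share=True: when run locally it requests a temporary public
# gradio.live link; on Hugging Face Spaces it has no effect, since the app
# is already publicly served.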
# Launch the demo
demo.launch(share=True)