import gradio as gr
import torch
import librosa
from transformers import Wav2Vec2Processor, AutoModelForCTC
import zipfile
import os
import firebase_admin
from firebase_admin import credentials, firestore
from datetime import datetime
import json

# Initialize Firebase from the service-account JSON stored in the
# 'firebase_creds' environment variable (a JSON string, not a file path)
firebase_config = json.loads(os.environ["firebase_creds"])
cred = credentials.Certificate(firebase_config)
firebase_admin.initialize_app(cred)
db = firestore.client()
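
# Illustrative only (assumption, not from the original code): 'firebase_creds'
# is expected to hold the raw service-account JSON as a single string, e.g.
#   export firebase_creds='{"type": "service_account", "project_id": "...", ...}'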

# Load the ASR model and processor
MODEL_NAME = "eleferrand/xlsr53_Amis"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)
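
# Optional sketch (assumption, not in the original): inference below runs on
# the CPU; with a GPU available, the model and input tensors could be moved
# there before the forward pass, e.g.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)  # and later: input_values = input_values.to(device)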

def transcribe(audio_file):
    """
    Transcribes the audio file using the loaded ASR model.
    Returns the transcription string.
    """
    try:
        # Load and resample the audio to 16kHz
        audio, rate = librosa.load(audio_file, sr=16000)
        # Prepare the input tensor for the model
        input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values

        # Get model predictions (logits) and decode to text
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        # Strip the CTC unknown-token marker before returning
        return transcription.replace("[UNK]", "")
    except Exception as e:
        return f"Error processing file: {e}"

def transcribe_both(audio_file):
    """
    Transcribes the audio and returns:
      - the original transcription (for the non-editable textbox),
      - the transcription (pre-filled for the editable textbox), and
      - the processing time (in seconds).
    """
    start_time = datetime.now()
    transcription = transcribe(audio_file)
    processing_time = (datetime.now() - start_time).total_seconds()
    return transcription, transcription, processing_time

def store_correction(original_transcription, corrected_transcription, audio_file, processing_time):
    """
    Stores the transcriptions and additional metadata in Firestore.
    Saves:
      - original & corrected text,
      - timestamp,
      - processing time,
      - audio metadata (duration & file size, if available),
      - a placeholder for the audio URL, and
      - the model name.
    """
    try:
        audio_metadata = {}
        if audio_file and os.path.exists(audio_file):
            # Load audio for metadata calculations
            audio, sr = librosa.load(audio_file, sr=16000)
            duration = librosa.get_duration(y=audio, sr=sr)
            file_size = os.path.getsize(audio_file)
            audio_metadata = {
                'duration': duration,
                'file_size': file_size
            }
        correction_data = {
            'original_text': original_transcription,
            'corrected_text': corrected_transcription,
            'timestamp': datetime.now().isoformat(),
            'processing_time': processing_time,
            'audio_metadata': audio_metadata,
            'audio_url': None,
            'model_name': MODEL_NAME
        }
        db.collection('transcription_corrections').add(correction_data)
        return "Correction saved successfully!"
    except Exception as e:
        return f"Error saving correction: {e}"

def prepare_download(audio_file, original_transcription, corrected_transcription):
    """
    Prepares a ZIP file containing:
      - the uploaded audio file (as audio.wav),
      - a text file with the original transcription, and
      - a text file with the corrected transcription.
    Returns the ZIP file's path.
    """
    if audio_file is None:
        return None

    zip_filename = "results.zip"
    with zipfile.ZipFile(zip_filename, "w") as zf:
        # Add the audio file (stored as audio.wav regardless of original format)
        if os.path.exists(audio_file):
            zf.write(audio_file, arcname="audio.wav")
        else:
            print("Audio file not found:", audio_file)

        # Add the original transcription as a text file
        orig_txt = "original_transcription.txt"
        with open(orig_txt, "w", encoding="utf-8") as f:
            f.write(original_transcription)
        zf.write(orig_txt, arcname="original_transcription.txt")
        os.remove(orig_txt)

        # Add the corrected transcription as a text file
        corr_txt = "corrected_transcription.txt"
        with open(corr_txt, "w", encoding="utf-8") as f:
            f.write(corrected_transcription)
        zf.write(corr_txt, arcname="corrected_transcription.txt")
        os.remove(corr_txt)
    return zip_filename
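
# Note (assumption, not in the original): the ZIP is written to a fixed path in
# the working directory, so concurrent users could overwrite each other's file;
# a per-request name, e.g. via tempfile.mkstemp(suffix=".zip"), would avoid this.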

# Build the Gradio Blocks interface with improved styling
with gr.Blocks(css="""
    .container { max-width: 800px; margin: auto; }
    .title { text-align: center; }
""") as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("<h1 class='title'>ASR Demo with Editable Transcription</h1>")
        with gr.Row():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio")
            transcribe_button = gr.Button("Transcribe Audio", variant="primary")
        with gr.Row():
            original_text = gr.Textbox(label="Original Transcription", interactive=False, lines=5)
            corrected_text = gr.Textbox(label="Corrected Transcription", interactive=True, lines=5)
        # Hidden state to hold processing time
        proc_time_state = gr.State()
        with gr.Row():
            save_button = gr.Button("Save Correction to Database", variant="primary")
            save_status = gr.Textbox(label="Save Status", interactive=False)
        with gr.Accordion("Download Options", open=False):
            with gr.Row():
                download_button = gr.Button("Download Results (ZIP)")
                download_output = gr.File(label="Download ZIP")
        
        # Set up actions
        transcribe_button.click(
            fn=transcribe_both,
            inputs=audio_input,
            outputs=[original_text, corrected_text, proc_time_state]
        )
        
        save_button.click(
            fn=store_correction,
            inputs=[original_text, corrected_text, audio_input, proc_time_state],
            outputs=save_status
        )
        
        download_button.click(
            fn=prepare_download,
            inputs=[audio_input, original_text, corrected_text],
            outputs=download_output
        )

# Launch the demo
demo.launch(share=True)