semakoc committed · Commit 5543320 · verified · 1 Parent(s): fa92657

Update app.py

Files changed (1):
  1. app.py +140 -45
app.py CHANGED
@@ -1,45 +1,140 @@
- import gradio as gr
- import torch
- import librosa
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-
- # Load Wav2Vec2 Model
- MODEL_NAME = "facebook/wav2vec2-large-960h"
- processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
-
- def transcribe(audio_file):
-     """
-     Transcribes speech from an uploaded audio file or live microphone input.
-     """
-     try:
-         # Load and convert audio to 16kHz
-         audio, rate = librosa.load(audio_file, sr=16000)
-
-         # Convert audio to tensor format for Wav2Vec
-         input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
-
-         # Run the model for transcription
-         with torch.no_grad():
-             logits = model(input_values).logits
-
-         # Convert predicted tokens into text
-         predicted_ids = torch.argmax(logits, dim=-1)
-         transcription = processor.batch_decode(predicted_ids)[0]
-
-         return transcription
-
-     except Exception as e:
-         return "Error processing file"
-
- # UI Build
- interface = gr.Interface(
-     fn=transcribe,
-     inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Speak or Upload Audio"),
-     outputs="text",
-     title="Wav2Vec2 Speech-to-Text Transcription",
-     description="Speak into your microphone or upload an audio file to get an automatic transcription.",
-     live=True  # Real-time microphone processing
- )
-
- interface.launch(share=True)
+ import gradio as gr
+ import torch
+ import librosa
+ from transformers import Wav2Vec2Processor, AutoModelForCTC
+ import zipfile
+ import os
+ import firebase_admin
+ from firebase_admin import credentials, firestore
+ from datetime import datetime
+
+ # 🔹 Initialize Firebase
+ cred = credentials.Certificate('firebase_credentials.json')  # Your Firebase JSON key file
+ firebase_admin.initialize_app(cred)
+ db = firestore.client()
+
+ # Load the ASR model and processor
+ MODEL_NAME = "eleferrand/xlsr53_Amis"
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ model = AutoModelForCTC.from_pretrained(MODEL_NAME)
+
+ def transcribe(audio_file):
+     """
+     Transcribes the audio file using the loaded ASR model.
+     Returns the transcription string.
+     """
+     try:
+         # Load and resample the audio to 16kHz
+         audio, rate = librosa.load(audio_file, sr=16000)
+         # Prepare the input tensor for the model
+         input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
+
+         # Get model predictions (logits) and decode to text
+         with torch.no_grad():
+             logits = model(input_values).logits
+         predicted_ids = torch.argmax(logits, dim=-1)
+         transcription = processor.batch_decode(predicted_ids)[0]
+         return transcription.replace("[UNK]", "")
+     except Exception as e:
+         return f"Error processing file: {e}"
+
+ def transcribe_both(audio_file):
+     """
+     Calls the transcribe function and returns the transcription
+     for both the original (read-only) and the corrected (editable) textboxes.
+     """
+     transcription = transcribe(audio_file)
+     return transcription, transcription
+
+ def store_correction(original_transcription, corrected_transcription):
+     """
+     Stores the original and corrected transcription in Firestore.
+     """
+     try:
+         correction_data = {
+             'original_text': original_transcription,
+             'corrected_text': corrected_transcription,
+             'timestamp': datetime.now().isoformat()
+         }
+         db.collection('transcription_corrections').add(correction_data)
+         return "✅ Correction saved successfully!"
+     except Exception as e:
+         return f"⚠️ Error saving correction: {e}"
+
+ def prepare_download(audio_file, original_transcription, corrected_transcription):
+     """
+     Prepares a ZIP file containing:
+       - The uploaded audio file (saved as audio.wav)
+       - A text file with the original transcription
+       - A text file with the corrected transcription
+     Returns the path to the ZIP file.
+     """
+     if audio_file is None:
+         return None
+
+     zip_filename = "results.zip"
+     with zipfile.ZipFile(zip_filename, "w") as zf:
+         # Add the audio file (saved as audio.wav in the zip)
+         if os.path.exists(audio_file):
+             zf.write(audio_file, arcname="audio.wav")
+         else:
+             print("Audio file not found:", audio_file)
+
+         # Create and add the original transcription file
+         orig_txt = "original_transcription.txt"
+         with open(orig_txt, "w", encoding="utf-8") as f:
+             f.write(original_transcription)
+         zf.write(orig_txt, arcname="original_transcription.txt")
+         os.remove(orig_txt)
+
+         # Create and add the corrected transcription file
+         corr_txt = "corrected_transcription.txt"
+         with open(corr_txt, "w", encoding="utf-8") as f:
+             f.write(corrected_transcription)
+         zf.write(corr_txt, arcname="corrected_transcription.txt")
+         os.remove(corr_txt)
+     return zip_filename
+
+ # Build the Gradio Blocks interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# ASR Demo with Editable Transcription, Firestore Storage, and Download")
+
+     with gr.Row():
+         audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio")
+         transcribe_button = gr.Button("Transcribe Audio")
+
+     with gr.Row():
+         # The original transcription is displayed (non-editable)
+         original_text = gr.Textbox(label="Original Transcription", interactive=False)
+         # The corrected transcription is pre-filled with the original, but remains editable.
+         corrected_text = gr.Textbox(label="Corrected Transcription", interactive=True)
+
+     save_button = gr.Button("Save Correction to Database")
+     save_status = gr.Textbox(label="Save Status", interactive=False)
+
+     download_button = gr.Button("Download Results (ZIP)")
+     download_output = gr.File(label="Download ZIP")
+
+     # When the transcribe button is clicked, update both textboxes with the transcription.
+     transcribe_button.click(
+         fn=transcribe_both,
+         inputs=audio_input,
+         outputs=[original_text, corrected_text]
+     )
+
+     # When the "Save Correction" button is clicked, store the corrected transcription in Firestore.
+     save_button.click(
+         fn=store_correction,
+         inputs=[original_text, corrected_text],
+         outputs=save_status
+     )
+
+     # When the download button is clicked, package the audio file and both transcriptions into a zip.
+     download_button.click(
+         fn=prepare_download,
+         inputs=[audio_input, original_text, corrected_text],
+         outputs=download_output
+     )
+
+ # Launch the demo
+ demo.launch(share=True)
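
To verify that store_correction is actually landing documents in Firestore, a minimal read-back sketch (not part of this commit) follows. It assumes the same firebase_credentials.json key file and the 'transcription_corrections' collection that app.py uses above.

import firebase_admin
from firebase_admin import credentials, firestore

# Reuse the credentials file referenced in app.py (assumption: it sits in the working directory).
cred = credentials.Certificate("firebase_credentials.json")
firebase_admin.initialize_app(cred)
db = firestore.client()

# Stream every stored correction and print the original/corrected pair.
for doc in db.collection("transcription_corrections").stream():
    data = doc.to_dict()
    print(data["timestamp"], "|", data["original_text"], "->", data["corrected_text"])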
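Similarly, a quick sanity check of the archive that prepare_download writes (hedged: assumes results.zip already exists in the working directory after a download has been triggered):

import zipfile

# List the entries prepare_download packed into the archive.
with zipfile.ZipFile("results.zip") as zf:
    for info in zf.infolist():
        print(info.filename, info.file_size)
# Expected entries: audio.wav, original_transcription.txt, corrected_transcription.txt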