hunterschep committed on
Commit
4b8259d
·
verified ·
1 Parent(s): d954236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -94
app.py CHANGED
@@ -1,126 +1,146 @@
1
  import gradio as gr
2
- import torch, librosa, zipfile, os, json, tempfile, uuid
 
3
  from transformers import Wav2Vec2Processor, AutoModelForCTC
4
- from datetime import datetime, timedelta
 
5
  import firebase_admin
6
  from firebase_admin import credentials, firestore, storage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # ---------- Firebase init ----------
9
- firebase_config = json.loads(os.environ.get("firebase_creds"))
10
  cred = credentials.Certificate(firebase_config)
11
- firebase_admin.initialize_app(
12
- cred, {"storageBucket": "amis-asr-corrections-dem-8cf3d.firebasestorage.app"}
13
- )
14
  db = firestore.client()
15
  bucket = storage.bucket()
16
 
17
- # ---------- ASR model ----------
18
  MODEL_NAME = "eleferrand/XLSR_paiwan"
19
  processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
20
  model = AutoModelForCTC.from_pretrained(MODEL_NAME)
21
 
22
- # ---------- Core helpers ----------
23
- def transcribe(path):
24
- try:
25
- audio, _ = librosa.load(path, sr=16_000)
26
- inputs = processor(audio, sampling_rate=16_000, return_tensors="pt").input_values
27
- with torch.no_grad():
28
- logits = model(inputs).logits
29
- ids = torch.argmax(logits, dim=-1)
30
- text = processor.batch_decode(ids)[0]
31
- return text.replace("[UNK]", "")
32
- except Exception as e:
33
- return f"處理文件錯誤: {e}"
34
-
35
- def transcribe_both(path):
36
- txt = transcribe(path)
37
- return txt, txt # original & editable copies
38
 
39
- def store_correction(orig, corr, audio, age, native):
40
  try:
41
- audio_meta, audio_url = {}, None
42
- if audio and os.path.exists(audio):
43
- a, sr = librosa.load(audio, sr=44_100)
44
- audio_meta = {
45
- "duration": librosa.get_duration(y=a, sr=sr),
46
- "file_size": os.path.getsize(audio),
47
- }
48
- uid = f"{uuid.uuid4()}.wav"
49
- blob = bucket.blob(f"audio/pai/{uid}")
50
- blob.upload_from_filename(audio)
51
- audio_url = blob.generate_signed_url(expiration=timedelta(hours=1))
52
-
53
- db.collection("paiwan_transcriptions").add(
54
- {
55
- "transcription_info": {
56
- "original_text": orig,
57
- "corrected_text": corr,
58
- "language": "pai",
59
- },
60
- "audio_data": {"audio_metadata": audio_meta, "audio_file_url": audio_url},
61
- "user_info": {"native_paiwan_speaker": native, "age": age},
62
- "timestamp": datetime.now().isoformat(),
63
- "model_name": MODEL_NAME,
64
- }
65
- )
66
- return "校正保存成功! (Correction saved successfully!)"
67
  except Exception as e:
68
- return f"保存失敗: {e} (Error saving correction: {e})"
69
 
70
- def prepare_download(audio, orig, corr):
71
- if not audio:
72
  return None
73
  tmp_zip = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
74
  tmp_zip.close()
75
- with zipfile.ZipFile(tmp_zip.name, "w") as z:
76
- if os.path.exists(audio):
77
- z.write(audio, arcname="audio.wav")
78
- for name, txt in [("original_transcription.txt", orig),
79
- ("corrected_transcription.txt", corr)]:
80
- with open(name, "w", encoding="utf-8") as f:
81
- f.write(txt)
82
- z.write(name, arcname=name)
83
- os.remove(name)
 
 
 
 
84
  return tmp_zip.name
85
 
86
- # ---------- Interface ----------
87
  with gr.Blocks() as demo:
88
- gr.Markdown("# 排灣語自動語音識別逐字稿與修正系統 (Paiwan ASR Transcription & Correction System)")
89
-
90
- # Step 1
91
- gr.Markdown("### 步驟 1:音訊上傳 (Audio Upload)")
92
- gr.Markdown("上傳後請至步驟 2 按「產生逐字稿」,系統處理時請耐心等待…")
93
- audio_input = gr.Audio(["upload", "microphone"], type="filepath",
94
- label="音訊輸入 (Audio Input)")
 
95
 
96
- # Step 2
97
- gr.Markdown("### 步驟 2:產生與編輯逐字稿 (Generate & Edit Transcript)")
98
- trans_btn = gr.Button("產生逐字稿 (Generate Transcript)")
99
- original = gr.Textbox(label="原始逐字稿 (Original Transcription)",
100
- interactive=False, lines=6)
101
- corrected = gr.Textbox(label="更正逐字稿 (Corrected Transcription)",
102
- interactive=True, lines=6)
 
 
 
 
 
 
 
 
103
 
104
- # Step 3
105
- gr.Markdown("### 步驟 3:使用者資訊 (User Information)")
106
  with gr.Row():
107
- age = gr.Slider(0, 100, step=1, value=25, label="年齡 (Age)")
108
- native = gr.Checkbox(value=True, label="母語排灣語使用者?(Native Paiwan Speaker?)")
 
 
 
 
109
 
110
- # Step 4
111
- gr.Markdown("### 步驟 4:儲存與下載 (Save & Download)")
112
  with gr.Row():
113
- save_btn = gr.Button("儲存 (Save)")
114
- save_msg = gr.Textbox(label="儲存狀態 (Save Status)", interactive=False)
 
 
 
115
  with gr.Row():
116
- dl_btn = gr.Button("下載 ZIP 檔案 (Download ZIP File)")
117
- dl_out = gr.File()
118
 
119
- # --- wiring ---
120
- trans_btn.click(transcribe_both, audio_input, [original, corrected])
121
- save_btn.click(store_correction,
122
- [original, corrected, audio_input, age, native],
123
- save_msg)
124
- dl_btn.click(prepare_download, [audio_input, original, corrected], dl_out)
 
 
 
 
125
 
126
  demo.launch()
 
# Standard library
import json
import os
import tempfile
import uuid
import zipfile
from datetime import datetime, timedelta

# Third-party
import firebase_admin
import gradio as gr
import librosa
import torch
from firebase_admin import credentials, firestore, storage
from transformers import Wav2Vec2Processor, AutoModelForCTC
11
+ tmpdir = None
12
+
13
def transcribe(audio_file):
    """Run the Paiwan wav2vec2 CTC model on one audio file.

    Loads the audio at 16 kHz, greedy-decodes the logits, and strips the
    ``[UNK]`` token from the result.  On any failure the error is returned
    as a user-facing string rather than raised (the Gradio UI displays it).
    """
    try:
        waveform, _sr = librosa.load(audio_file, sr=16000)
        features = processor(waveform, sampling_rate=16000, return_tensors="pt")
        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(features.input_values).logits
        best_ids = logits.argmax(dim=-1)
        decoded = processor.batch_decode(best_ids)[0]
        return decoded.replace("[UNK]", "")
    except Exception as e:
        return f"處理文件錯誤: {e}"
24
 
25
# ---------- Firebase initialisation ----------
# The service-account JSON is supplied through the `firebase_creds` env var.
# Fail fast with a clear message instead of the opaque TypeError that
# json.loads(None) would raise when the variable is missing.
_creds_json = os.environ.get("firebase_creds")
if not _creds_json:
    raise RuntimeError(
        "Missing 'firebase_creds' environment variable "
        "(expected the Firebase service-account JSON)."
    )
firebase_config = json.loads(_creds_json)
cred = credentials.Certificate(firebase_config)
firebase_admin.initialize_app(cred, {
    "storageBucket": "amis-asr-corrections-dem-8cf3d.firebasestorage.app"
})
db = firestore.client()          # Firestore document store for corrections
bucket = storage.bucket()        # Cloud Storage bucket for uploaded audio

# ---------- ASR model ----------
# Loaded once at startup; transcribe() reads these module-level globals.
MODEL_NAME = "eleferrand/XLSR_paiwan"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)
38
 
39
def transcribe_both(audio_file):
    """Transcribe once, return the text twice (read-only copy + editable copy)."""
    text = transcribe(audio_file)
    return text, text
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
def store_correction(original_transcription, corrected_transcription, audio_file, age, native_speaker):
    """Persist one correction record to Firestore, uploading the audio to Cloud Storage.

    Parameters:
        original_transcription: model output text shown to the user.
        corrected_transcription: user-edited text.
        audio_file: local filesystem path to the audio, or None/missing.
        age: user-reported age (from the UI slider).
        native_speaker: bool, native Paiwan speaker checkbox.

    Returns a human-readable status string (success or error); never raises —
    the Gradio UI displays whatever comes back.
    """
    try:
        audio_metadata = {}
        audio_file_url = None
        # Only upload when a real local file was provided.
        if audio_file and os.path.exists(audio_file):
            # Reload at 44.1 kHz purely to measure duration of the stored copy.
            audio, sr = librosa.load(audio_file, sr=44100)
            duration = librosa.get_duration(y=audio, sr=sr)
            file_size = os.path.getsize(audio_file)
            audio_metadata = {'duration': duration, 'file_size': file_size}
            # NOTE(review): `uuid` is not imported at the top of this file —
            # this branch raises NameError until `import uuid` is added.
            unique_id = str(uuid.uuid4())
            destination_path = f"audio/pai/{unique_id}.wav"
            blob = bucket.blob(destination_path)
            blob.upload_from_filename(audio_file)
            # Short-lived (1 h) signed URL stored alongside the record.
            audio_file_url = blob.generate_signed_url(expiration=timedelta(hours=1))
        combined_data = {
            'transcription_info': {'original_text': original_transcription, 'corrected_text': corrected_transcription, 'language': 'pai'},
            'audio_data': {'audio_metadata': audio_metadata, 'audio_file_url': audio_file_url},
            'user_info': {'native_paiwan_speaker': native_speaker, 'age': age},
            # NOTE(review): naive local timestamp — consider timezone-aware UTC.
            'timestamp': datetime.now().isoformat(), 'model_name': MODEL_NAME
        }
        db.collection('paiwan_transcriptions').add(combined_data)
        return "校正保存成功!"
    except Exception as e:
        # Best-effort: surface the failure to the UI instead of crashing.
        return f"保存失败: {e}"
67
 
68
def prepare_download(audio_file, original_transcription, corrected_transcription):
    """Bundle the audio plus both transcripts into a ZIP and return its path.

    Parameters:
        audio_file: local path to the uploaded audio, or None.
        original_transcription: model output text (may be None/empty).
        corrected_transcription: user-edited text (may be None/empty).

    Returns the path of a freshly created .zip file, or None when no audio
    was supplied.  The file outlives this call (delete=False) so Gradio can
    serve it to the browser.
    """
    if audio_file is None:
        return None
    tmp_zip = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
    tmp_zip.close()
    with zipfile.ZipFile(tmp_zip.name, "w") as zf:
        if os.path.exists(audio_file):
            zf.write(audio_file, arcname="audio.wav")
        # writestr puts the text straight into the archive (UTF-8), instead of
        # the previous write-then-delete of scratch files in the CWD, which
        # could collide between concurrent requests.  `or ""` guards against
        # None transcripts from an empty UI textbox.
        zf.writestr("original_transcription.txt", original_transcription or "")
        zf.writestr("corrected_transcription.txt", corrected_transcription or "")
    return tmp_zip.name
87
 
88
# ---------- Gradio interface ----------
# Four-step UI: upload audio -> auto-transcribe -> collect user info -> save / download.
with gr.Blocks() as demo:
    title = gr.Markdown("排灣語自動語音識別校正系統 (Paiwan ASR Transcription & Correction System)")
    step1 = gr.Markdown(
        "步驟 1:音訊上傳與產生逐字稿 (Audio Upload & Automatic Transcription)\n\n上傳後系統將自動產生逐字稿,請耐心等待。"
    )
    with gr.Row():
        # filepath mode: handlers receive a local path string, not raw samples.
        audio_input = gr.Audio(
            sources=["upload", "microphone"], type="filepath", label="音訊輸入 (Audio Input)"
        )

    step2 = gr.Markdown("步驟 2:審閱與編輯逐字稿 (Step 2: Review & Edit Transcription)")
    with gr.Row():
        # Read-only model output on the left; editable copy on the right.
        original_text = gr.Textbox(
            label="原始逐字稿 (Original Transcription)", interactive=False, lines=5
        )
        corrected_text = gr.Textbox(
            label="更正逐字稿 (Corrected Transcription)", interactive=True, lines=5
        )
    # Automatically generate the transcription whenever the audio changes;
    # queue=True serialises requests through Gradio's job queue.
    audio_input.change(
        transcribe_both,
        inputs=audio_input,
        outputs=[original_text, corrected_text],
        queue=True
    )

    step3 = gr.Markdown("步驟 3:使用者資訊 (Step 3: User Information)")
    with gr.Row():
        age_input = gr.Slider(
            minimum=0, maximum=100, step=1, label="年齡 (Age)", value=25
        )
        native_speaker_input = gr.Checkbox(
            label="母語排灣語使用者? (Native Paiwan Speaker?)", value=True
        )

    step4 = gr.Markdown("步驟 4:儲存與下載 (Step 4: Save & Download)")
    with gr.Row():
        save_button = gr.Button("儲存 (Save)")
        save_status = gr.Textbox(
            label="儲存狀態 (Save Status)", interactive=False
        )

    with gr.Row():
        download_button = gr.Button("下載 ZIP 檔案 (Download ZIP File)")
        download_output = gr.File()

    # Save pushes the correction (plus audio) to Firebase and reports status.
    save_button.click(
        store_correction,
        inputs=[original_text, corrected_text, audio_input, age_input, native_speaker_input],
        outputs=save_status
    )
    # Download bundles audio + both transcripts into a ZIP.
    download_button.click(
        prepare_download,
        inputs=[audio_input, original_text, corrected_text],
        outputs=download_output
    )

demo.launch()