Update app.py
Browse files
app.py
CHANGED
@@ -1,118 +1,217 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from pyannote.audio import Pipeline
|
3 |
import os
|
|
|
|
|
|
|
4 |
import torch
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
-
diarization = pipeline(
|
41 |
-
audio_path,
|
42 |
-
min_speakers=min_speakers if min_speakers > 0 else None,
|
43 |
-
max_speakers=max_speakers if max_speakers > 0 else None
|
44 |
-
)
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
except Exception as e:
|
63 |
-
|
64 |
-
|
65 |
-
# 讬爪讬专转 诪诪砖拽 诪砖转诪砖
|
66 |
-
demo = gr.Interface(
|
67 |
-
fn=process_audio,
|
68 |
-
inputs=[
|
69 |
-
gr.Audio(
|
70 |
-
label="拽讜讘抓 讗讜讚讬讜",
|
71 |
-
source="upload",
|
72 |
-
type="filepath"
|
73 |
-
),
|
74 |
-
gr.Number(
|
75 |
-
label="诪讬谞讬诪讜诐 讚讜讘专讬诐 (讗讜驻爪讬讜谞诇讬)",
|
76 |
-
value=0,
|
77 |
-
minimum=0,
|
78 |
-
step=1
|
79 |
-
),
|
80 |
-
gr.Number(
|
81 |
-
label="诪拽住讬诪讜诐 讚讜讘专讬诐 (讗讜驻爪讬讜谞诇讬)",
|
82 |
-
value=0,
|
83 |
-
minimum=0,
|
84 |
-
step=1
|
85 |
-
)
|
86 |
-
],
|
87 |
-
outputs=gr.Textbox(
|
88 |
-
label="转讜爪讗讜转 讛讝讬讛讜讬",
|
89 |
-
lines=10
|
90 |
-
),
|
91 |
-
title="讝讬讛讜讬 讚讜讘专讬诐 讘讛拽诇讟讜转",
|
92 |
-
description="""
|
93 |
-
讛注诇讛 拽讜讘抓 讗讜讚讬讜 诇讝讬讛讜讬 讛讚讜讘专讬诐 讛砖讜谞讬诐 讜讛讝诪谞讬诐 砖诇讛诐.
|
94 |
-
|
95 |
-
讛注专讜转:
|
96 |
-
- 讗诐 讬讚讜注 诇讱 诪住驻专 讛讚讜讘专讬诐, 讛讝谉 讗讜转讜 讻讚讬 诇砖驻专 讗转 讛讚讬讜拽
|
97 |
-
- 转讜诪讱 讘驻讜专诪讟讬诐: WAV, MP3, FLAC
|
98 |
-
- 诪讜诪诇抓 诇讛砖转诪砖 讘讛拽诇讟讜转 讘讗讬讻讜转 讟讜讘讛
|
99 |
-
- 诪砖讱 诪拽住讬诪诇讬: 2 砖注讜转
|
100 |
-
""",
|
101 |
-
examples=[
|
102 |
-
["example.wav", 2, 4],
|
103 |
-
["interview.mp3", 2, 2]
|
104 |
-
]
|
105 |
-
)
|
106 |
-
|
107 |
-
if __name__ == "__main__":
|
108 |
-
# 讛讚驻住转 诪讬讚注 注诇 讛住讘讬讘讛
|
109 |
-
space_name = os.getenv('SPACE_ID', 'unknown')
|
110 |
-
print(f"Space name: {space_name}")
|
111 |
-
print(f"GPU available: {torch.cuda.is_available()}")
|
112 |
-
|
113 |
-
# 讛驻注诇转 讛诪诪砖拽
|
114 |
-
demo.launch(
|
115 |
-
share=True,
|
116 |
-
enable_queue=True,
|
117 |
-
debug=True
|
118 |
-
)
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import sys
|
3 |
+
import logging
|
4 |
+
import warnings
|
5 |
import torch
|
6 |
+
import numpy as np
|
7 |
+
from typing import Optional, Union, Dict
|
8 |
|
9 |
+
# 讛讙讚专转 logging
|
10 |
+
logging.basicConfig(
|
11 |
+
level=logging.INFO,
|
12 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
13 |
+
)
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
# 讛转注诇诪讜转 诪讗讝讛专讜转 诪讬讜转专讜转
|
17 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
18 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
19 |
+
|
20 |
+
try:
|
21 |
+
import gradio as gr
|
22 |
+
from pyannote.audio import Pipeline
|
23 |
+
except ImportError as e:
|
24 |
+
logger.error(f"砖讙讬讗讛 讘讟注讬谞转 住驻专讬讜转: {str(e)}")
|
25 |
+
sys.exit(1)
|
26 |
+
|
27 |
+
class DiarizationPipeline:
|
28 |
+
def __init__(self):
|
29 |
+
self.pipeline = None
|
30 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
+
logger.info(f"Using device: {self.device}")
|
32 |
+
|
33 |
+
def initialize(self) -> Optional[str]:
|
34 |
+
"""诪讗转讞诇 讗转 讛驻讬讬驻诇讬讬谉 注诐 讟讬驻讜诇 砖讙讬讗讜转"""
|
35 |
+
try:
|
36 |
+
hf_token = os.getenv('HF_TOKEN')
|
37 |
+
if not hf_token:
|
38 |
+
return "讞住专 讟讜拽谉 HF_TOKEN. 讗谞讗 讛讙讚专 讗讜转讜 讘讛讙讚专讜转."
|
39 |
+
|
40 |
+
self.pipeline = Pipeline.from_pretrained(
|
41 |
+
"pyannote/[email protected]",
|
42 |
+
use_auth_token=hf_token
|
43 |
+
)
|
44 |
|
45 |
+
if self.device == "cuda":
|
46 |
+
self.pipeline = self.pipeline.to(torch.device("cuda"))
|
47 |
+
|
48 |
+
return None # 讗转讞讜诇 讛爪诇讬讞
|
49 |
+
|
50 |
+
except Exception as e:
|
51 |
+
error_msg = f"砖讙讬讗讛 讘讗转讞讜诇 讛诪讜讚诇: {str(e)}"
|
52 |
+
logger.error(error_msg)
|
53 |
+
return error_msg
|
54 |
+
|
55 |
+
def process_audio(
|
56 |
+
self,
|
57 |
+
audio_path: str,
|
58 |
+
min_speakers: Optional[int] = None,
|
59 |
+
max_speakers: Optional[int] = None
|
60 |
+
) -> Dict[str, Union[str, float, int]]:
|
61 |
+
"""诪注讘讚 拽讜讘抓 讗讜讚讬讜 讜诪讞讝讬专 转讜爪讗讜转 诪驻讜专讟讜转"""
|
62 |
|
63 |
+
try:
|
64 |
+
if not os.path.exists(audio_path):
|
65 |
+
return {"error": "拽讜讘抓 讛讗讜讚讬讜 诇讗 谞诪爪讗"}
|
66 |
+
|
67 |
+
file_size = os.path.getsize(audio_path) / (1024 * 1024) # MB
|
68 |
+
if file_size > 100: # 讛讙讘诇转 讙讜讚诇 拽讜讘抓
|
69 |
+
return {"error": f"讙讜讚诇 讛拽讜讘抓 ({file_size:.1f}MB) 讙讚讜诇 诪讚讬. 讛诪拽住讬诪讜诐 讛讜讗 100MB"}
|
70 |
+
|
71 |
+
# 讜讬讚讜讗 砖讛驻讬讬驻诇讬讬谉 诪讗讜转讞诇
|
72 |
+
if self.pipeline is None:
|
73 |
+
init_error = self.initialize()
|
74 |
+
if init_error:
|
75 |
+
return {"error": init_error}
|
76 |
+
|
77 |
+
# 注讬讘讜讚 讛讗讜讚讬讜
|
78 |
+
diarization = self.pipeline(
|
79 |
+
audio_path,
|
80 |
+
min_speakers=min_speakers if min_speakers and min_speakers > 0 else None,
|
81 |
+
max_speakers=max_speakers if max_speakers and max_speakers > 0 else None
|
82 |
+
)
|
83 |
+
|
84 |
+
# 注讬讘讜讚 讛转讜爪讗讜转
|
85 |
+
segments = []
|
86 |
+
speakers = set()
|
87 |
+
total_duration = 0
|
88 |
+
|
89 |
+
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
90 |
+
segment = {
|
91 |
+
"start": turn.start,
|
92 |
+
"end": turn.end,
|
93 |
+
"duration": turn.duration,
|
94 |
+
"speaker": speaker
|
95 |
+
}
|
96 |
+
segments.append(segment)
|
97 |
+
speakers.add(speaker)
|
98 |
+
total_duration += turn.duration
|
99 |
+
|
100 |
+
# 讬爪讬专转 驻诇讟 诪注讜爪讘
|
101 |
+
output_text = "转讜爪讗讜转 讝讬讛讜讬 讛讚讜讘专讬诐:\n\n"
|
102 |
|
103 |
+
for segment in segments:
|
104 |
+
output_text += (
|
105 |
+
f"[{segment['start']:.1f}s -> {segment['end']:.1f}s] "
|
106 |
+
f"{segment['speaker']}\n"
|
107 |
+
)
|
108 |
|
109 |
+
# 讛讜住驻转 住讟讟讬住讟讬拽讜转
|
110 |
+
output_text += f"\n住讬讻讜诐:\n"
|
111 |
+
output_text += f"诪住驻专 讚讜讘专讬诐 砖讝讜讛讜: {len(speakers)}\n"
|
112 |
+
output_text += f"诪砖讱 讻讜诇诇: {total_duration:.1f} 砖谞讬讜转\n"
|
113 |
+
output_text += f"讙讜讚诇 讛拽讜讘抓: {file_size:.1f}MB\n"
|
114 |
|
115 |
+
if min_speakers or max_speakers:
|
116 |
+
output_text += f"讛讙讘诇讜转 砖讛讜讙讚专讜: "
|
117 |
+
if min_speakers:
|
118 |
+
output_text += f"诪讬谞讬诪讜诐 {min_speakers} 讚讜讘专讬诐, "
|
119 |
+
if max_speakers:
|
120 |
+
output_text += f"诪拽住讬诪讜诐 {max_speakers} 讚讜讘专讬诐"
|
121 |
+
output_text += "\n"
|
122 |
+
|
123 |
+
return {
|
124 |
+
"text": output_text,
|
125 |
+
"num_speakers": len(speakers),
|
126 |
+
"duration": total_duration,
|
127 |
+
"file_size": file_size,
|
128 |
+
"segments": segments
|
129 |
+
}
|
130 |
+
|
131 |
+
except Exception as e:
|
132 |
+
error_msg = f"砖讙讬讗讛 讘注讬讘讜讚 讛讗讜讚讬讜: {str(e)}"
|
133 |
+
logger.error(error_msg)
|
134 |
+
return {"error": error_msg}
|
135 |
+
|
136 |
+
# 讬爪讬专转 讛诪注讟驻转 诇诪诪砖拽 诪砖转诪砖
|
137 |
+
def create_interface(pipeline: DiarizationPipeline) -> gr.Interface:
|
138 |
+
def process_wrapper(audio_path, min_speakers, max_speakers):
|
139 |
+
if audio_path is None:
|
140 |
+
return "诇讗 谞讘讞专 拽讜讘抓 讗讜讚讬讜"
|
141 |
|
142 |
+
result = pipeline.process_audio(audio_path, min_speakers, max_speakers)
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
+
if "error" in result:
|
145 |
+
return f"砖讙讬讗讛: {result['error']}"
|
146 |
+
return result["text"]
|
147 |
+
|
148 |
+
return gr.Interface(
|
149 |
+
fn=process_wrapper,
|
150 |
+
inputs=[
|
151 |
+
gr.Audio(
|
152 |
+
label="拽讜讘抓 讗讜讚讬讜",
|
153 |
+
source="upload",
|
154 |
+
type="filepath"
|
155 |
+
),
|
156 |
+
gr.Number(
|
157 |
+
label="诪讬谞讬诪讜诐 讚讜讘专讬诐 (讗讜驻爪讬讜谞诇讬)",
|
158 |
+
value=0,
|
159 |
+
minimum=0,
|
160 |
+
step=1
|
161 |
+
),
|
162 |
+
gr.Number(
|
163 |
+
label="诪拽住讬诪讜诐 讚讜讘专讬诐 (讗讜驻爪讬讜谞诇讬)",
|
164 |
+
value=0,
|
165 |
+
minimum=0,
|
166 |
+
step=1
|
167 |
+
)
|
168 |
+
],
|
169 |
+
outputs=gr.Textbox(
|
170 |
+
label="转讜爪讗讜转 讛讝讬讛讜讬",
|
171 |
+
lines=10
|
172 |
+
),
|
173 |
+
title="讝讬讛讜讬 讚讜讘专讬诐 讘讛拽诇讟讜转",
|
174 |
+
description="""
|
175 |
+
讛注诇讛 拽讜讘抓 讗讜讚讬讜 诇讝讬讛讜讬 讛讚讜讘专讬诐 讛砖讜谞讬诐 讜讛讝诪谞讬诐 砖诇讛诐.
|
176 |
|
177 |
+
讛注专讜转:
|
178 |
+
- 转讜诪讱 讘驻讜专诪讟讬诐: WAV, MP3, FLAC
|
179 |
+
- 讙讜讚诇 拽讜讘抓 诪拽住讬诪诇讬: 100MB
|
180 |
+
- 诪讜诪诇抓 诇讛砖转诪砖 讘讛拽诇讟讜转 讘讗讬讻讜转 讟讜讘讛
|
181 |
+
- 讗诐 讬讚讜注 诇讱 诪住驻专 讛讚讜讘专讬诐, 讛讝谉 讗讜转讜 诇砖讬驻讜专 讛讚讬讜拽
|
182 |
+
""",
|
183 |
+
examples=[
|
184 |
+
["example.wav", 2, 4],
|
185 |
+
["interview.mp3", 2, 2]
|
186 |
+
],
|
187 |
+
allow_flagging="never",
|
188 |
+
theme="default"
|
189 |
+
)
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
try:
|
193 |
+
# 讛讚驻住转 诪讬讚注 注诇 讛住讘讬讘讛
|
194 |
+
logger.info(f"Python version: {sys.version}")
|
195 |
+
logger.info(f"PyTorch version: {torch.__version__}")
|
196 |
+
logger.info(f"NumPy version: {np.__version__}")
|
197 |
+
logger.info(f"Space ID: {os.getenv('SPACE_ID', 'unknown')}")
|
198 |
+
logger.info(f"GPU available: {torch.cuda.is_available()}")
|
199 |
|
200 |
+
if torch.cuda.is_available():
|
201 |
+
logger.info(f"GPU model: {torch.cuda.get_device_name(0)}")
|
202 |
+
|
203 |
+
# 讬爪讬专转 讛驻讬讬驻诇讬讬谉 讜讛诪诪砖拽
|
204 |
+
pipeline = DiarizationPipeline()
|
205 |
+
demo = create_interface(pipeline)
|
206 |
|
207 |
+
# 讛驻注诇转 讛诪诪砖拽
|
208 |
+
demo.launch(
|
209 |
+
share=True,
|
210 |
+
enable_queue=True,
|
211 |
+
max_threads=4,
|
212 |
+
debug=True
|
213 |
+
)
|
214 |
|
215 |
except Exception as e:
|
216 |
+
logger.error(f"砖讙讬讗讛 拽专讬讟讬转: {str(e)}")
|
217 |
+
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|