StevenChen16 commited on
Commit
67d22ff
·
verified ·
1 Parent(s): 589b5f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -1,28 +1,32 @@
1
  import torch
2
  import whisperx
3
  import gradio as gr
 
4
 
5
  # 检测设备类型
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
-
8
- # 设置计算类型
9
  compute_type = "float16" if device == "cuda" else "int8"
10
 
 
 
 
 
 
11
  # 加载 WhisperX 模型
12
- model = whisperx.load_model("large-v3", device=device, compute_type=compute_type)
13
 
14
  def transcribe(audio_path):
15
- # 使用WhisperX进行转录
16
  result = model.transcribe(audio_path)
17
  return result['text']
18
 
19
- # 创建Gradio接口
20
  iface = gr.Interface(
21
  fn=transcribe,
22
- inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
23
  outputs="text",
24
  title="WhisperX 语音转文字",
25
- description="上传音频文件,使用WhisperX模型进行转录。"
26
  )
27
 
28
  if __name__ == "__main__":
 
1
import torch
import whisperx
import gradio as gr
import spaces

# Pick the runtime device: CUDA when a GPU is visible, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Match the compute precision to the device: fp16 on GPU, int8 on CPU.
compute_type = "float16" if device == "cuda" else "int8"
9
 
10
@spaces.GPU  # run inside a GPU-allocated context on HF Spaces (ZeroGPU)
def load_model():
    """Load and return the WhisperX large-v3 model on the detected device."""
    # NOTE(review): `device`/`compute_type` were computed at import time, when
    # ZeroGPU may not yet expose CUDA — confirm the model actually lands on GPU.
    return whisperx.load_model("large-v3", device=device, compute_type=compute_type)


# Load the WhisperX model once at startup.
model = load_model()
17
 
18
def transcribe(audio_path):
    """Transcribe an audio file with the WhisperX model.

    Parameters
    ----------
    audio_path : str
        Filesystem path to the audio file (Gradio supplies a filepath).

    Returns
    -------
    str
        The transcription, assembled from the per-segment texts.
    """
    result = model.transcribe(audio_path)
    # WhisperX's transcribe returns {"segments": [...], "language": ...};
    # there is no top-level "text" key, so result['text'] raised KeyError.
    # Join the per-segment text to produce the full transcript instead.
    segments = result.get("segments", [])
    return " ".join(seg["text"].strip() for seg in segments)
22
 
23
# Build the Gradio interface: audio in (upload or microphone), text out.
iface = gr.Interface(
    fn=transcribe,
    # Gradio 4.x names this parameter `sources` (plural, a list of input
    # methods); `source=` was a regression — the keyword does not exist in 4.x.
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="WhisperX 语音转文字",
    description="上传音频文件,使用 WhisperX 模型进行转录。",
)
31
 
32
  if __name__ == "__main__":