# MingLi
# bug fix
# 6c97f40
# raw
# history blame
# 2.87 kB
import gradio as gr
import os
import tempfile
import subprocess
from transformers import pipeline
import torch
from zipfile import ZipFile
from fastapi import FastAPI
app = FastAPI()

# Choose device and checkpoint together: whisper-small.en on the first GPU
# when CUDA is available, otherwise the lighter whisper-tiny.en on CPU.
if torch.cuda.is_available():
    device, model_id = "cuda:0", "openai/whisper-small.en"
else:
    device, model_id = "cpu", "openai/whisper-tiny.en"

# Shared speech-recognition pipeline used for every request.
# chunk_length_s=30 makes the pipeline transcribe long recordings in
# 30-second chunks instead of a single window.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model_id,
    chunk_length_s=30,
    device=device,
)
def support_gbk(zip_file: ZipFile):
    """Re-decode legacy (non-UTF-8) member names of *zip_file* as GBK, in place.

    Without the UTF-8 flag, ``zipfile`` decodes member names as cp437. Archives
    created by GBK-locale tools store names in GBK, so those names come out
    garbled. For every member that is *not* flagged as UTF-8, this re-encodes
    the cp437 name back to its raw bytes and decodes them as GBK. It then
    updates both ``ZipInfo.filename`` and the archive's name lookup table, so
    ``read``/``extract``/``extractall`` use the readable name.

    Members flagged as UTF-8 are left untouched. So are names that are not
    valid GBK; they keep zipfile's cp437 decoding.

    Args:
        zip_file: An archive opened for reading.

    Returns:
        The same ``zip_file`` object, to allow ``with support_gbk(ZipFile(...))``.
    """
    name_to_info = zip_file.NameToInfo
    # Iterate over a snapshot, because the mapping is rewritten inside the loop.
    for name, info in name_to_info.copy().items():
        # General-purpose flag bit 11 (0x800) means the name was stored as
        # UTF-8 and is already decoded correctly. Re-encoding such a name as
        # cp437 would raise for any non-ASCII character.
        if info.flag_bits & 0x800:
            continue
        try:
            real_name = name.encode("cp437").decode("gbk")
        except UnicodeError:
            # Not GBK after all (or not cp437-decoded): keep the name zipfile
            # produced, rather than failing the whole upload.
            continue
        if real_name != name:
            info.filename = real_name
            del name_to_info[name]
            name_to_info[real_name] = info
    return zip_file
def handel(f):
    """Transcribe an uploaded audio file, or every file inside an uploaded zip.

    Args:
        f: The value of the ``gr.File`` component. It is either an object with
            a ``.name`` attribute holding the uploaded file's path, or the path
            string itself. It is falsy when nothing was uploaded.

    Returns:
        The transcripts produced by ``handel_files``. For a zip, its members
        are processed in sorted path order, so the output order is stable.

    Raises:
        gr.Error: If no file was uploaded; also propagated from
            ``handel_files``.
    """
    if not f:
        raise gr.Error("请上传文件")  # "Please upload a file"
    # Newer Gradio versions pass the path string; older ones pass a file wrapper.
    path = f if isinstance(f, str) else f.name
    if not path.lower().endswith(".zip"):
        return handel_files([path])
    # The context manager deletes the extracted files once transcription is
    # done, instead of leaving cleanup to garbage collection.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Close the archive right after extraction; it is not needed while
        # the transcription runs.
        with ZipFile(path, "r") as z:
            support_gbk(z)  # make GBK-encoded member names readable
            z.extractall(path=tmp_dir)
        # os.walk order depends on the filesystem; sort for deterministic output.
        extracted = sorted(
            os.path.join(root, name)
            for root, _, names in os.walk(tmp_dir)
            for name in names
        )
        return handel_files(extracted)
def ffmpeg_convert(file_input, file_output):
    """Convert *file_input* into *file_output* by running the ``ffmpeg`` CLI.

    An existing *file_output* is overwritten (``-y``).

    Args:
        file_input: Path of the media file to read.
        file_output: Path of the file ffmpeg should write.

    Raises:
        gr.Error: If the ``ffmpeg`` executable cannot be found, or if ffmpeg
            exits with a non-zero status (e.g. the input is not a valid
            media file).
    """
    # -nostdin plus a /dev/null stdin keeps ffmpeg from reading (or blocking
    # on) the server process's standard input.
    command = ["ffmpeg", "-nostdin", "-y", "-i", file_input, file_output]
    try:
        result = subprocess.run(command, stdin=subprocess.DEVNULL)
    except FileNotFoundError as err:
        # "ffmpeg not found, please make sure ffmpeg is installed"
        raise gr.Error("未找到 ffmpeg, 请确认已安装 ffmpeg") from err
    if result.returncode:
        # "ffmpeg_convert failed, please check that the file format is correct"
        raise gr.Error("ffmpeg_convert 失败, 请检查文件格式是否正确")
def handel_files(f_ls):
    """Convert each supported audio file to WAV with ffmpeg and transcribe it.

    Files whose extension is not ``.m4a``, ``.mp3`` or ``.wav``
    (case-insensitive) are skipped with a Gradio warning. Every accepted file
    is re-encoded, including ``.wav`` inputs, which lets ffmpeg reject invalid
    ones. Output goes into a private temporary directory, which is deleted
    once transcription finishes.

    Args:
        f_ls: Iterable of input file paths.

    Returns:
        The transcripts of the accepted files, in input order, joined by a
        blank line (``""`` when no file was accepted).

    Raises:
        gr.Error: Propagated from ``ffmpeg_convert`` when a conversion fails.
    """
    supported = (".m4a", ".mp3", ".wav")
    with tempfile.TemporaryDirectory() as out_dir:
        files = []
        for index, file in enumerate(f_ls):
            if not file.lower().endswith(supported):
                # "Invalid file {name} found, skipped"
                gr.Warning(f"存在不合法文件{os.path.basename(file)},已跳过处理")
                continue
            # Give each input its own sub-directory and name the output
            # "<full basename>.wav". Then "a.mp3" and "a.wav", or two "a.mp3"
            # from different folders, can never overwrite each other's output,
            # and directory names in the path are never rewritten.
            slot = os.path.join(out_dir, str(index))
            os.mkdir(slot)
            file_output = os.path.join(slot, os.path.basename(file) + ".wav")
            ffmpeg_convert(file, file_output)
            files.append(file_output)
        # All conversions run first, so a file ffmpeg rejects aborts the
        # request before any transcription starts.
        return "\n\n".join(whisper_handler(file) for file in files)
def whisper_handler(file):
    """Transcribe a single audio file with the module-level Whisper pipeline.

    Before running the model, it shows a Gradio info toast naming the file:
    the basename up to its first dot.

    Args:
        file: Path to the audio file.

    Returns:
        The ``"text"`` field of the pipeline result.
    """
    stem = os.path.basename(file).split(".")[0]
    gr.Info(f"处理文件 - {stem}")  # "Processing file - <stem>"
    result = pipe(file)
    return result["text"]
# UI: an upload slot, a submit button ("提交" = "Submit") and a result box
# ("结果" = "Result"). Clicking the button runs `handel` on the upload.
with gr.Blocks() as blocks:
    f = gr.File(file_types=[".zip", ".mp3", ".wav", ".m4a"])
    b = gr.Button(value="提交")
    t = gr.Textbox(label="结果")
    b.click(fn=handel, inputs=f, outputs=t)

# Enable request queuing and cap the number of pending requests at 3.
blocks.queue(max_size=3)

# Serve the UI at the root path of the FastAPI app; `app` is rebound to the
# app returned by Gradio.
app = gr.mount_gradio_app(app, blocks, path="/")