# Spaces: Runtime error
import os
import subprocess
import tempfile
from zipfile import ZipFile

import gradio as gr
import torch
from fastapi import FastAPI
from transformers import pipeline
app = FastAPI()

# With a GPU available, run on the first CUDA device using the larger
# English-only ("*.en") Whisper checkpoint; otherwise use the CPU and the
# tiny checkpoint.
device, model_id = (
    ("cuda:0", "openai/whisper-small.en")
    if torch.cuda.is_available()
    else ("cpu", "openai/whisper-tiny.en")
)

# Module-level speech-recognition pipeline shared by every request.
# chunk_length_s=30 makes the pipeline split long recordings into 30-second
# chunks, so inputs longer than one Whisper window can be transcribed.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_id,
    chunk_length_s=30,
    device=device,
)
def support_gbk(zip_file: ZipFile):
    """Re-decode legacy (non-UTF-8) member names of ``zip_file`` as GBK, in place.

    Without the UTF-8 flag, ``zipfile`` decodes member names as cp437. Archives
    created on Chinese Windows systems store such names as GBK bytes. This helper
    re-encodes each such name back to its raw bytes and decodes those as GBK.
    It updates both ``ZipInfo.filename`` and the name -> ZipInfo lookup table, so
    ``namelist()``, ``read()`` and ``extractall()`` use the corrected names.

    Entries are left untouched when:
      * the UTF-8 flag is set (``zipfile`` already decoded them correctly), or
      * the name cannot be round-tripped cp437 -> GBK (not GBK data).

    Args:
        zip_file: an open ``ZipFile`` opened for reading.

    Returns:
        The same ``zip_file`` object, so calls can be chained or used in ``with``.
    """
    name_to_info = zip_file.NameToInfo
    # Iterate over a snapshot: the mapping is re-keyed inside the loop.
    for name, info in list(name_to_info.items()):
        # Bit 11 (0x800) of the general-purpose flags marks UTF-8 names. Such
        # names are already correct and usually contain characters that cp437
        # cannot encode, so the original code crashed on them.
        if info.flag_bits & 0x800:
            continue
        try:
            real_name = name.encode("cp437").decode("gbk")
        except UnicodeError:
            # Not valid GBK (e.g. a Western legacy name): keep zipfile's decoding.
            continue
        if real_name != name:
            info.filename = real_name
            del name_to_info[name]
            name_to_info[real_name] = info
    return zip_file
def handel(f):
    """Gradio click handler: transcribe an uploaded audio file or ``.zip`` archive.

    A ``.zip`` upload is extracted into a temporary directory and every extracted
    file is passed to ``handel_files``. Member names are first re-decoded with
    ``support_gbk``. The temporary directory is removed once transcription
    finishes. Any other upload is passed to ``handel_files`` directly.

    Args:
        f: value of the ``gr.File`` input. Either an object exposing the uploaded
            file's path as ``.name`` or the path string itself; falsy when
            nothing was uploaded.

    Returns:
        The combined transcription text produced by ``handel_files``.

    Raises:
        gr.Error: if no file was uploaded. Errors raised by ``handel_files``
            propagate unchanged.
    """
    if not f:
        # "Please upload a file"
        raise gr.Error("请上传文件")
    # Accept either a file wrapper exposing ``.name`` or a plain path string.
    path = getattr(f, "name", f)
    if not path.lower().endswith(".zip"):
        return handel_files([path])

    # The context manager deletes the extracted (and converted) files after
    # transcription. The original relied on GC of an unreferenced TemporaryDirectory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with ZipFile(path, "r") as archive:
            support_gbk(archive)
            archive.extractall(path=tmp_dir)
        extracted = []
        for root, dirs, names in os.walk(tmp_dir):
            # macOS archivers add "__MACOSX/._<name>" metadata stubs. They keep the
            # audio extension but are not audio, so ffmpeg would fail on them.
            dirs[:] = [d for d in dirs if d != "__MACOSX"]
            extracted.extend(os.path.join(root, name) for name in names)
        # os.walk order is filesystem-dependent; sort for a stable output order.
        return handel_files(sorted(extracted))
def ffmpeg_convert(file_input, file_output):
    """Convert ``file_input`` into ``file_output`` using the ``ffmpeg`` CLI.

    ffmpeg is run with ``-y``, so an existing ``file_output`` is overwritten without
    an interactive prompt. Arguments are passed as a list (no shell), so paths
    with spaces or shell metacharacters are handled safely.

    Args:
        file_input: path of the source media file.
        file_output: path to write. No ``-f`` is given, so ffmpeg chooses the
            output format from this name.

    Raises:
        gr.Error: if the ``ffmpeg`` executable cannot be found, or if ffmpeg exits
            with a non-zero status.
    """
    command = ["ffmpeg", "-y", "-i", file_input, file_output]
    try:
        completed = subprocess.run(command)
    except FileNotFoundError as err:
        # Without this, a missing binary escapes as a bare FileNotFoundError
        # instead of a message the Gradio UI can display.
        # "ffmpeg is not installed or not on PATH"
        raise gr.Error("ffmpeg 未安装或不在 PATH 中") from err
    if completed.returncode:
        # "ffmpeg_convert failed, please check that the file format is correct"
        raise gr.Error("ffmpeg_convert 失败, 请检查文件格式是否正确")
def handel_files(f_ls):
    """Transcribe every supported audio file in ``f_ls`` and join the results.

    Extensions are matched case-insensitively:
      * ``.m4a`` / ``.mp3`` files are converted with ffmpeg into a ``.wav`` with
        the same stem in the same directory, and that file is transcribed;
      * ``.wav`` files are transcribed as-is;
      * any other file triggers a Gradio warning and is skipped.

    All conversions run before any transcription starts.

    Args:
        f_ls: iterable of filesystem paths.

    Returns:
        The transcripts of the accepted files, in input order, separated by a
        blank line ("" when no file was accepted).

    Raises:
        gr.Error: propagated from ``ffmpeg_convert`` when a conversion fails.
    """
    files = []
    for file in f_ls:
        stem, ext = os.path.splitext(file)
        ext = ext.lower()
        if ext in (".m4a", ".mp3"):
            # splitext only touches the final suffix. str.replace would also
            # rewrite a matching substring elsewhere in the path.
            file_output = stem + ".wav"
            ffmpeg_convert(file, file_output)
        elif ext == ".wav":
            # The old code ran ffmpeg with the same path as input and output.
            # ffmpeg refuses to edit files in place, so every .wav upload failed.
            file_output = file
        else:
            # "Invalid file <name> found, skipped"
            gr.Warning(f"存在不合法文件{os.path.basename(file)},已跳过处理")
            continue
        files.append(file_output)
    return "\n\n".join(whisper_handler(file) for file in files)
def whisper_handler(file):
    """Transcribe a single audio file with the module-level ``pipe``.

    Shows a Gradio info toast naming the file ("Processing file - <name>"),
    then returns the ``"text"`` entry of the pipeline result.
    """
    gr.Info(f"处理文件 - {os.path.basename(file)}")
    result = pipe(file)
    return result["text"]
with gr.Blocks() as blocks:
    # Upload widget restricted to a .zip archive or a single audio file.
    f = gr.File(file_types=[".zip", ".mp3", ".wav", ".m4a"])
    # "Submit" button
    b = gr.Button(value="提交")
    # "Result" text box that receives the transcription
    t = gr.Textbox(label="结果")
    b.click(fn=handel, inputs=[f], outputs=[t])

# Allow at most 3 events to wait in the request queue at once.
blocks.queue(max_size=3)
# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, blocks, path="/")