MambaVoice / app.py
federerjiang's picture
initial update
ae466a1
raw
history blame
3.12 kB
import os
import time
import tempfile
from math import floor
from typing import Optional, List, Dict, Any
import torch
import gradio as gr
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
# configuration
MODEL_NAME = "federerjiang/mambavoice-ja-v1"
BATCH_SIZE = 16
CHUNK_LENGTH_S = 15
FILE_LIMIT_MB = 1000
TOKEN = os.environ.get('HF_TOKEN', None)
# device setting
if torch.cuda.is_available():
torch_dtype = torch.bfloat16
device = "cuda:0"
else:
torch_dtype = torch.float32
device = "cpu"
# define the pipeline
pipe = pipeline(
model=MODEL_NAME,
chunk_length_s=CHUNK_LENGTH_S,
batch_size=BATCH_SIZE,
torch_dtype=torch_dtype,
device=device,
trust_remote_code=True,
token=TOKEN,
)
def get_prediction(inputs, prompt: Optional[str]=None):
generate_kwargs = {
"language": "japanese",
"task": "transcribe",
"length_penalty": 0,
"num_beams": 2,
}
if prompt:
generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
prediction = pipe(inputs, return_timestamps=True, generate_kwargs=generate_kwargs)
text = "".join([c['text'] for c in prediction['chunks']])
return text
def transcribe(inputs: str):
if inputs is None:
raise gr.Error("音声ファイルが送信されていません!リクエストを送信する前に、音声ファイルをアップロードまたは録音してください。")
with open(inputs, "rb") as f:
inputs = f.read()
inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
return get_prediction(inputs)
demo = gr.Blocks()
mf_transcribe = gr.Interface(
fn=transcribe,
inputs=[
gr.Audio(sources=["microphone"], type="filepath"),
],
outputs=["text"],
# layout="horizontal",
theme="huggingface",
title=f"オーディオをMambaVoice-v1で文字起こしする",
description=f"ボタンをクリックするだけで、長時間のマイク入力やオーディオ入力を文字起こしできます!デモではMambaVoice-v1モデルを使用しており、任意の長さの音声ファイルを文字起こしすることができます。",
allow_flagging="never",
)
file_transcribe = gr.Interface(
fn=transcribe,
inputs=[
gr.Audio(sources=["upload"], type="filepath", label="Audio file"),
],
outputs=["text"],
# layout="horizontal",
theme="huggingface",
title=f"オーディオをMambaVoice-v1で文字起こしする",
description=f"ボタンをクリックするだけで、長時間のマイク入力やオーディオ入力を文字起こしできます!デモではMambaVoice-v1モデルを使用しており、任意の長さの音声ファイルを文字起こしすることができます。",
allow_flagging="never",
)
with demo:
gr.TabbedInterface([mf_transcribe, file_transcribe], ["Microphone", "Audio file"])
demo.queue(max_size=10)
demo.launch()