# Step-Audio / app.py
# (Hugging Face Space file-viewer header removed: author "martin",
#  commit "initial" 67c46fd, 5.67 kB — kept here as a comment so the
#  file is valid Python.)
import gradio as gr
import time
from pathlib import Path
import torchaudio
from stepaudio import StepAudio
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# Directory where Gradio-served temporary audio files are written.
CACHE_DIR = "/tmp/gradio/"
# Default system prompt (Chinese): "match the user's language; answer with
# short, conversational text".
# NOTE(review): "promtp" is a typo, kept because the rest of the file
# references this exact name.
system_promtp = {"role": "system", "content": "适配用户的语言,用简短口语化的文字回答"}
class CustomAsr:
    """Thin wrapper around FunASR's SenseVoice model for speech-to-text."""

    def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
        # The VAD front-end splits long recordings into segments of at most
        # 30 s (max_single_segment_time is in milliseconds).
        self.model = AutoModel(
            model=model_name,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
            device=device,
        )

    def run(self, audio_path):
        """Transcribe the audio file at *audio_path* and return plain text."""
        results = self.model.generate(
            input=audio_path,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
        # Post-process rich transcription markup into clean display text.
        return rich_transcription_postprocess(results[0]["text"])
def add_message(chatbot, history, mic, text, asr_model):
    """Append the user's turn (typed text or recorded audio) to the chat.

    Args:
        chatbot: Gradio messages-format chat display list (mutated in place).
        history: LLM conversation history of role/content dicts (mutated).
        mic: filepath of the recorded audio, or None/empty.
        text: typed message, possibly empty; takes priority over mic.
        asr_model: object exposing ``run(audio_path) -> str``.

    Returns:
        (chatbot, history, error) where ``error`` is None on success or a
        user-facing message explaining why nothing was added.
    """
    if not mic and not text:
        return chatbot, history, "Input is empty"
    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    elif Path(mic).exists():
        # Show the raw audio in the UI, but put the ASR transcript in the
        # LLM history to speed up inference.
        chatbot.append({"role": "user", "content": {"path": mic}})
        text = asr_model.run(mic)
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    else:
        # mic was provided but the file is missing: report it instead of
        # silently returning success with nothing appended (original bug).
        return chatbot, history, "Audio file not found"
    print(f"{history=}")
    return chatbot, history, None
def reset_state():
    """Clear the visible transcript and re-seed history with the system prompt."""
    cleared_chatbot = []
    fresh_history = [system_promtp]
    return cleared_chatbot, fresh_history
def save_tmp_audio(audio, sr):
    """Write an audio tensor to a unique ``.wav`` file under ``CACHE_DIR``.

    Args:
        audio: waveform accepted by ``torchaudio.save`` — assumes shape
            (channels, samples); TODO confirm against StepAudio's output.
        sr: sample rate in Hz.

    Returns:
        Path of the written file. ``delete=False`` so the file survives for
        Gradio to serve; Gradio's delete_cache handles eventual cleanup.
    """
    import os
    import tempfile

    # The cache directory may not exist yet on a fresh container; creating
    # it here avoids a FileNotFoundError from NamedTemporaryFile.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with tempfile.NamedTemporaryFile(
        dir=CACHE_DIR, delete=False, suffix=".wav"
    ) as temp_audio:
        temp_audio_path = temp_audio.name
    # Write after the handle is closed so torchaudio owns the file cleanly.
    torchaudio.save(temp_audio_path, audio, sr)
    return temp_audio_path
def predict(chatbot, history, audio_model, speaker="闫雨婷"):
    """Generate an assistant reply (text + TTS audio) and append it.

    Args:
        chatbot: Gradio messages-format chat display list (mutated in place).
        history: LLM conversation history passed to the model (mutated).
        audio_model: callable ``(history, speaker) -> (text, audio, sr)``.
        speaker: TTS voice name; defaults to the previously hard-coded voice
            so existing callers are unaffected.

    Returns:
        The updated ``(chatbot, history)`` pair. On failure they are
        returned unchanged and a Gradio warning is shown instead of raising.
    """
    try:
        text, audio, sr = audio_model(history, speaker)
        print(f"predict {text=}")
        audio_path = save_tmp_audio(audio, sr)
        # Two chatbot entries per turn: the playable audio, then the text.
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        chatbot.append({"role": "assistant", "content": text})
        history.append({"role": "assistant", "content": text})
    except Exception as e:
        # Best-effort UI: keep the session alive and ask the user to retry.
        print(e)
        gr.Warning("Some error happened, please retry submitting")
    return chatbot, history
def _launch_demo(args, audio_model, asr_model):
    """Build the Gradio Blocks UI for Step Audio Chat and launch the server.

    Args:
        args: parsed CLI namespace providing ``server_port`` and ``server_name``.
        audio_model: StepAudio instance used to generate replies.
        asr_model: CustomAsr instance used to transcribe microphone input.
    """
    # Cached files (temp audio) are deleted after 24h; re-checked every 24h.
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
        )
        # Keep the chat history in session state so the prompt format does
        # not have to be rebuilt from the UI widgets on every turn.
        history = gr.State([system_promtp])
        mic = gr.Audio(type="filepath")
        text = gr.Textbox(placeholder="Enter message ...")
        with gr.Row():
            clean_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            submit_btn = gr.Button("🚀 Submit")

        def on_submit(chatbot, history, mic, text):
            # Append the user turn; on success run the model. Always clears
            # the mic and textbox inputs (the two trailing None outputs).
            chatbot, history, error = add_message(
                chatbot, history, mic, text, asr_model
            )
            if error:
                gr.Warning(error)  # show the warning message in the UI
                return chatbot, history, None, None
            else:
                chatbot, history = predict(chatbot, history, audio_model)
                return chatbot, history, None, None

        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text],
            outputs=[chatbot, history, mic, text],
            concurrency_limit=4,
            concurrency_id="gpu_queue",
        )
        clean_btn.click(
            reset_state,
            outputs=[chatbot, history],
            show_progress=True,
        )

        def regenerate(chatbot, history):
            # Drop the trailing assistant turns from both the display and
            # the model history, then re-run generation from what remains.
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            return predict(chatbot, history, audio_model)

        regen_btn.click(
            regenerate,
            [chatbot, history],
            [chatbot, history],
            show_progress=True,
            concurrency_id="gpu_queue",
        )
    demo.queue().launch(
        share=False,
        server_port=args.server_port,
        server_name=args.server_name,
    )
if __name__ == "__main__":
    from argparse import ArgumentParser
    import os

    # CLI: only --model-path is required; server host/port have defaults.
    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True, help="Model path.")
    parser.add_argument(
        "--server-port", type=int, default=7860, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
    )
    args = parser.parse_args()
    # The checkpoint directory is expected to contain the three Step-Audio
    # sub-model folders named below (tokenizer, TTS, chat LLM).
    audio_model = StepAudio(
        tokenizer_path=os.path.join(args.model_path, "Step-Audio-Tokenizer"),
        tts_path=os.path.join(args.model_path, "Step-Audio-TTS-3B"),
        llm_path=os.path.join(args.model_path, "Step-Audio-Chat"),
    )
    asr_model = CustomAsr()
    _launch_demo(args, audio_model, asr_model)