Step-Audio / app.py
martin
fix aqta predict
930f36f
raw
history blame
6.12 kB
import gradio as gr
import time
from pathlib import Path
import torchaudio
from stepaudio import StepAudio
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# Directory in which save_tmp_audio() creates the reply .wav files.
CACHE_DIR = "/tmp/gradio/"
# System message that seeds every conversation history. English gloss of the
# content: "Match the user's language; answer in short, colloquial text."
# (The name's spelling is kept because other code refers to it.)
system_promtp = dict(
    role="system",
    content="适配用户的语言,用简短口语化的文字回答",
)
class CustomAsr:
    """Speech-to-text wrapper around a FunASR ``AutoModel``.

    The model (with the ``fsmn-vad`` VAD model attached) is constructed once in
    ``__init__``; ``run`` transcribes a single audio file.
    """

    def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
        vad_options = {"max_single_segment_time": 30000}
        self.model = AutoModel(
            model=model_name,
            vad_model="fsmn-vad",
            vad_kwargs=vad_options,
            device=device,
        )

    def run(self, audio_path):
        """Transcribe ``audio_path`` and return the post-processed text.

        Only the first entry of the ``generate`` result is used; its ``"text"``
        field is passed through ``rich_transcription_postprocess``.
        """
        outputs = self.model.generate(
            input=audio_path,
            cache={},
            language="auto",  # alternatives: "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
        first_text = outputs[0]["text"]
        return rich_transcription_postprocess(first_text)
def add_message(chatbot, history, mic, text):
    """Append the user's new turn to the UI transcript and the model history.

    Typed text takes precedence over a recorded/uploaded clip when both are
    present. Text that is empty or whitespace-only counts as no text.

    Args:
        chatbot: Message list shown in the ``gr.Chatbot`` (``type="messages"``).
        history: Message list passed to the audio model.
        mic: Filepath of the audio clip, or ``None``.
        text: Contents of the textbox, or ``None``/empty.

    Returns:
        ``(chatbot, history, error)``. ``error`` is ``None`` when a turn was
        appended. Otherwise it is a warning string, and both lists are unchanged.
    """
    if text and text.strip():
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    elif mic and Path(mic).exists():
        # The UI renders the file itself; the model history uses the audio
        # message form that predict() recognizes.
        chatbot.append({"role": "user", "content": {"path": mic}})
        history.append({"role": "user", "content": {"type": "audio", "audio": mic}})
    elif mic:
        # Without this branch the call reported success without appending
        # anything, so predict() would answer the previous turn again.
        return chatbot, history, "Audio file not found"
    else:
        return chatbot, history, "Input is empty"
    print(f"{history=}")
    return chatbot, history, None
def reset_state():
    """Reset the chat history.

    Returns:
        ``(chatbot, history)``: an empty transcript, and a history that holds
        only the system prompt.
    """
    # Return a copy of the module-level prompt dict so that no session's
    # history aliases the shared global. An in-place edit of one session's
    # first message then cannot leak into other sessions or later resets.
    return [], [dict(system_promtp)]
def save_tmp_audio(audio, sr):
    """Write ``audio`` to a new ``.wav`` file in ``CACHE_DIR`` and return its path.

    The file is created with ``delete=False`` so it outlives this call; nothing
    here removes it.

    Args:
        audio: Waveform handed to ``torchaudio.save`` (presumably a
            ``(channels, samples)`` tensor; confirm against StepAudio's output).
        sr: Sample rate passed to ``torchaudio.save``.

    Returns:
        Path of the written ``.wav`` file.
    """
    import os
    import tempfile

    # CACHE_DIR is not guaranteed to exist yet (e.g. a text-only first turn);
    # NamedTemporaryFile raises FileNotFoundError for a missing directory.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with tempfile.NamedTemporaryFile(
        dir=CACHE_DIR, delete=False, suffix=".wav"
    ) as temp_audio:
        temp_audio_path = temp_audio.name
    # Write only after our handle is closed. Reopening a NamedTemporaryFile by
    # name while it is still open is not portable (it fails on Windows).
    torchaudio.save(temp_audio_path, audio, sr)
    return temp_audio_path
def predict(chatbot, history, audio_model, asr_model):
    """Generate the assistant's reply to the latest user turn.

    Calls ``audio_model(history, ...)`` for the reply text and waveform and
    saves the waveform with ``save_tmp_audio``. It then appends the reply audio
    and text to ``chatbot``, and the reply text to ``history``.

    If the latest user turn is audio, ``asr_model.run`` transcribes it. The
    transcript is shown in ``chatbot`` and replaces that turn's content in
    ``history``.

    All model work happens before either list is modified. Any exception is
    therefore reported (traceback printed, Gradio warning shown) and the lists
    are returned as they were.

    Returns:
        ``(chatbot, history)``.
    """
    import traceback

    try:
        is_input_audio = False
        user_audio_path = None
        # Detect whether the user's input was audio or text: add_message stores
        # audio turns as a dict ({"type": "audio", "audio": path}), text as str.
        if isinstance(history[-1]["content"], dict):
            is_input_audio = True
            user_audio_path = history[-1]["content"]["audio"]
        # NOTE(review): the second argument looks like the voice/speaker name
        # for the spoken reply; confirm against StepAudio.__call__.
        text, audio, sr = audio_model(history, "闫雨婷")
        print(f"predict {text=}")
        audio_path = save_tmp_audio(audio, sr)
        # Cache the ASR transcript of the user's audio in the history to speed
        # up the next inference: later turns send text instead of the clip.
        if is_input_audio:
            asr_text = asr_model.run(user_audio_path)
            chatbot.append({"role": "user", "content": asr_text})
            history[-1]["content"] = asr_text
            print(f"{asr_text=}")
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        chatbot.append({"role": "assistant", "content": text})
        history.append({"role": "assistant", "content": text})
    except Exception:
        # UI boundary: keep the app alive, but print the full traceback.
        # str(e) alone is often empty or ambiguous for model/IO errors.
        traceback.print_exc()
        gr.Warning("Some error happened, retry submit")
    return chatbot, history
def _launch_demo(args, audio_model, asr_model):
    """Build the "Step Audio Chat" Gradio UI and launch it with queuing enabled.

    Args:
        args: Parsed CLI namespace; ``server_port`` and ``server_name`` are read.
        audio_model: Reply generator, passed through to ``predict``.
        asr_model: Object with ``run(audio_path)``, passed through to ``predict``.
    """
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
        )
        # Keep the chat history in session state so it does not have to be
        # re-assembled into the model's message format on every turn.
        history = gr.State([system_promtp])
        mic = gr.Audio(type="filepath")
        text = gr.Textbox(placeholder="Enter message ...")
        with gr.Row():
            clean_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            submit_btn = gr.Button("🚀 Submit")

        def on_submit(chatbot, history, mic, text):
            chatbot, history, error = add_message(chatbot, history, mic, text)
            if error:
                gr.Warning(error)  # show the validation message in the UI
            else:
                chatbot, history = predict(chatbot, history, audio_model, asr_model)
            # Both inputs are cleared once the submission has been handled.
            return chatbot, history, None, None

        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text],
            outputs=[chatbot, history, mic, text],
            concurrency_limit=4,
            concurrency_id="gpu_queue",
        )
        clean_btn.click(
            reset_state,
            outputs=[chatbot, history],
            show_progress=True,
        )

        def regenerate(chatbot, history):
            # Drop trailing assistant turns so the last user turn is answered again.
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            # After Clear (or before the first turn) there is no user message to
            # answer; predict() would otherwise reply to the system prompt.
            if not history or history[-1]["role"] != "user":
                gr.Warning("Nothing to regenerate")
                return chatbot, history
            return predict(chatbot, history, audio_model, asr_model)

        regen_btn.click(
            regenerate,
            [chatbot, history],
            [chatbot, history],
            show_progress=True,
            # Match the submit event. Both share the "gpu_queue" group, and
            # Gradio limits a group by the lowest concurrency_limit among its
            # events, so the default (1) here would also throttle submits.
            concurrency_limit=4,
            concurrency_id="gpu_queue",
        )
    demo.queue().launch(
        share=False,
        server_port=args.server_port,
        server_name=args.server_name,
    )
if __name__ == "__main__":
    import os
    from argparse import ArgumentParser

    # Command line: model root directory plus where to serve the demo.
    cli = ArgumentParser()
    cli.add_argument("--model-path", type=str, required=True, help="Model path.")
    cli.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
    cli.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
    cli_args = cli.parse_args()

    def _sub_model(dirname):
        """Return the path of a sub-model directory under --model-path."""
        return os.path.join(cli_args.model_path, dirname)

    audio_model = StepAudio(
        tokenizer_path=_sub_model("Step-Audio-Tokenizer"),
        tts_path=_sub_model("Step-Audio-TTS-3B"),
        llm_path=_sub_model("Step-Audio-Chat"),
    )
    asr_model = CustomAsr()
    _launch_demo(cli_args, audio_model, asr_model)