# Step-Audio Gradio chat demo.
import gradio as gr
import time
from pathlib import Path
import torchaudio
from stepaudio import StepAudio
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
CACHE_DIR = "/tmp/gradio/"
system_promtp = {"role": "system", "content": "适配用户的语言,用简短口语化的文字回答"}
class CustomAsr:
    """Thin wrapper around FunASR's SenseVoice model for speech-to-text.

    Bundles an fsmn-vad voice-activity-detection front end so long
    recordings are segmented before transcription.
    """

    def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
        # NOTE(review): assumes a CUDA device is available by default.
        self.model = AutoModel(
            model=model_name,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
            device=device,
        )

    def run(self, audio_path):
        """Transcribe the audio file at *audio_path* and return plain text."""
        results = self.model.generate(
            input=audio_path,
            cache={},
            language="auto",  # auto-detect: "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
        raw_text = results[0]["text"]
        # Strip SenseVoice's rich markup (emotion/event tags) down to text.
        return rich_transcription_postprocess(raw_text)
def add_message(chatbot, history, mic, text):
    """Append the user's turn (typed text or recorded audio) to the UI and history.

    Args:
        chatbot: Gradio "messages"-format list shown in the UI.
        history: model-facing message list (text or audio-dict content).
        mic: filepath of a recorded audio clip, or None.
        text: typed message, or None/empty.

    Returns:
        (chatbot, history, error) — error is None on success, else a
        user-facing message describing why nothing was appended.
    """
    if not mic and not text:
        return chatbot, history, "Input is empty"
    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    elif mic and Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        history.append({"role": "user", "content": {"type": "audio", "audio": mic}})
    else:
        # Fix: previously a missing audio file fell through silently with
        # error=None, so the caller went on to predict() with no user turn.
        return chatbot, history, "Audio file not found"
    print(f"{history=}")
    return chatbot, history, None
def reset_state():
    """Clear the chat: empty chatbot plus a history re-seeded with the system prompt."""
    fresh_history = [system_promtp]
    return [], fresh_history
def save_tmp_audio(audio, sr):
    """Write an audio tensor to a uniquely named .wav under CACHE_DIR.

    Args:
        audio: waveform tensor accepted by torchaudio.save (channels, frames)
               — TODO confirm layout against StepAudio's output.
        sr: sample rate in Hz.

    Returns:
        Path of the written .wav file.
    """
    import os
    import tempfile

    # Fix: CACHE_DIR may not exist yet (first run / cleaned /tmp), which made
    # NamedTemporaryFile raise FileNotFoundError.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with tempfile.NamedTemporaryFile(
        dir=CACHE_DIR, delete=False, suffix=".wav"
    ) as temp_audio:
        temp_audio_path = temp_audio.name
    # Write after the handle is closed so torchaudio owns the file exclusively.
    torchaudio.save(temp_audio_path, audio, sr)
    return temp_audio_path
def predict(chatbot, history, audio_model, asr_model):
    """Generate an assistant reply (text + synthesized speech) for the last user turn.

    Mutates chatbot/history in place and returns them. On any failure the
    exception is logged and a Gradio warning is shown; the inputs are
    returned unchanged past the point of failure.

    Args:
        chatbot: UI message list (audio shown as {"path": ...} entries).
        history: model-facing message list; last entry is the user turn.
        audio_model: StepAudio callable — returns (text, audio, sr).
        asr_model: CustomAsr used to transcribe audio user turns.
    """
    import traceback

    try:
        is_input_audio = False
        user_audio_path = None
        # Detect whether the latest user turn is audio (dict payload) or text.
        if isinstance(history[-1]["content"], dict):
            is_input_audio = True
            user_audio_path = history[-1]["content"]["audio"]
        text, audio, sr = audio_model(history, "闫雨婷")
        print(f"predict {text=}")
        audio_path = save_tmp_audio(audio, sr)
        # Cache the ASR transcript of the user's audio in history so the next
        # inference feeds plain text instead of re-processing the audio.
        if is_input_audio:
            asr_text = asr_model.run(user_audio_path)
            chatbot.append({"role": "user", "content": asr_text})
            history[-1]["content"] = asr_text
            print(f"{asr_text=}")
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        chatbot.append({"role": "assistant", "content": text})
        history.append({"role": "assistant", "content": text})
    except Exception:
        # Fix: print(e) lost the stack trace; also corrected the typo and
        # pointless f-string in the user-facing warning.
        traceback.print_exc()
        gr.Warning("Some error happened, retry submit")
    return chatbot, history
def _launch_demo(args, audio_model, asr_model):
    """Build the Gradio Blocks UI and launch the (blocking) web server.

    Args:
        args: parsed CLI namespace providing server_port and server_name.
        audio_model: StepAudio callable used by predict().
        asr_model: CustomAsr instance used to transcribe audio input.
    """
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
        )
        # Keep the model-facing chat history in gr.State so it does not have
        # to be rebuilt from the UI messages on every turn.
        history = gr.State([system_promtp])
        mic = gr.Audio(type="filepath")
        text = gr.Textbox(placeholder="Enter message ...")
        with gr.Row():
            clean_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            submit_btn = gr.Button("🚀 Submit")

        def on_submit(chatbot, history, mic, text):
            # Validate/append the user turn; only run inference on success.
            # Always clears the mic and textbox (trailing None, None outputs).
            chatbot, history, error = add_message(
                chatbot, history, mic, text
            )
            if error:
                gr.Warning(error)  # surface the validation problem to the user
                return chatbot, history, None, None
            else:
                chatbot, history = predict(chatbot, history, audio_model, asr_model)
                return chatbot, history, None, None

        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text],
            outputs=[chatbot, history, mic, text],
            concurrency_limit=4,
            # Shared queue id with regenerate so GPU work is serialized together.
            concurrency_id="gpu_queue",
        )
        clean_btn.click(
            reset_state,
            outputs=[chatbot, history],
            show_progress=True,
        )

        def regenerate(chatbot, history):
            # Drop all trailing assistant turns (UI shows up to three per
            # reply: audio, text, and possibly an ASR echo), then re-run
            # prediction on the remaining history ending at the user turn.
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            return predict(chatbot, history, audio_model, asr_model)

        regen_btn.click(
            regenerate,
            [chatbot, history],
            [chatbot, history],
            show_progress=True,
            concurrency_id="gpu_queue",
        )
    demo.queue().launch(
        share=False,
        server_port=args.server_port,
        server_name=args.server_name,
    )
if __name__ == "__main__":
    # CLI entry point: parse arguments, load the three Step-Audio model
    # components from --model-path, then start the demo server.
    import os
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("--model-path", type=str, required=True, help="Model path.")
    cli.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
    cli.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
    args = cli.parse_args()

    root = args.model_path
    audio_model = StepAudio(
        tokenizer_path=os.path.join(root, "Step-Audio-Tokenizer"),
        tts_path=os.path.join(root, "Step-Audio-TTS-3B"),
        llm_path=os.path.join(root, "Step-Audio-Chat"),
    )
    asr_model = CustomAsr()
    _launch_demo(args, audio_model, asr_model)