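"""Gradio web demo for Step-Audio voice chat.

Takes microphone or text input, generates a spoken and written reply with
Step-Audio, and caches a SenseVoice ASR transcript of each audio turn.

Example launch (assuming this script is saved as app.py; the model path is a
placeholder for a directory holding the Step-Audio checkpoints):

    python app.py --model-path /path/to/models --server-port 7860
"""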
import os
import tempfile
from pathlib import Path

import gradio as gr
import torchaudio
from stepaudio import StepAudio

from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

CACHE_DIR = "/tmp/gradio/"
os.makedirs(CACHE_DIR, exist_ok=True)  # generated audio is served from this cache directory

# The prompt below tells the model (in Chinese) to match the user's language
# and answer in short, conversational text.
system_prompt = {"role": "system", "content": "适配用户的语言,用简短口语化的文字回答"}


class CustomAsr:
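    """SenseVoice-based ASR used to transcribe the user's recorded audio.

    A minimal usage sketch (assumes the ``iic/SenseVoiceSmall`` weights can be
    fetched by funasr and a CUDA device is available; ``sample.wav`` is a
    placeholder path):

        asr = CustomAsr()
        text = asr.run("sample.wav")
    """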
    def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
        self.model = AutoModel(
            model=model_name,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
            device=device,
        )

    def run(self, audio_path):
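        """Transcribe the audio file at ``audio_path`` and return post-processed text."""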
        res = self.model.generate(
            input=audio_path,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,  # merge short VAD-split segments before decoding
            merge_length_s=15,
        )
        text = rich_transcription_postprocess(res[0]["text"])
        return text


def add_message(chatbot, history, mic, text):
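    """Append the user's text or recorded audio to the chat display and history."""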
    if not mic and not text:
        return chatbot, history, "Input is empty"

    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    elif mic and Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        history.append({"role": "user", "content": {"type": "audio", "audio": mic}})
    else:
        return chatbot, history, "Audio file does not exist"

    print(f"{history=}")
    return chatbot, history, None


def reset_state():
    """Reset the chat history."""
    return [], [system_prompt]


def save_tmp_audio(audio, sr):
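    """Save a waveform tensor to a temporary .wav under CACHE_DIR and return its path."""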
    with tempfile.NamedTemporaryFile(
        dir=CACHE_DIR, delete=False, suffix=".wav"
    ) as temp_audio:
        temp_audio_path = temp_audio.name
        torchaudio.save(temp_audio_path, audio, sr)

    return temp_audio_path


def predict(chatbot, history, audio_model, asr_model):
    """Generate a response from the model."""
    try:
        is_input_audio = False
        user_audio_path = None
        # Detect whether the user's latest input was audio or text
        if isinstance(history[-1]["content"], dict):
            is_input_audio = True
            user_audio_path = history[-1]["content"]["audio"]
        text, audio, sr = audio_model(history, "闫雨婷")  # "闫雨婷" selects the TTS voice
        print(f"predict {text=}")
        audio_path = save_tmp_audio(audio, sr)
        # Cache the ASR transcript of the user's audio to speed up the next inference
        if is_input_audio:
            asr_text = asr_model.run(user_audio_path)
            chatbot.append({"role": "user", "content": asr_text})
            history[-1]["content"] = asr_text
            print(f"{asr_text=}")
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        chatbot.append({"role": "assistant", "content": text})
        history.append({"role": "assistant", "content": text})
    except Exception as e:
        print(e)
        gr.Warning(f"Some error happend, retry submit")
    return chatbot, history


def _launch_demo(args, audio_model, asr_model):
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
        )
        # Keep the chat history in session state so it is not re-assembled on every turn
        history = gr.State([system_prompt])
        mic = gr.Audio(type="filepath")
        text = gr.Textbox(placeholder="Enter message ...")

        with gr.Row():
            clean_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            submit_btn = gr.Button("🚀 Submit")

        def on_submit(chatbot, history, mic, text):
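            """Validate the input, then run one round of model inference."""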
            chatbot, history, error = add_message(
                chatbot, history, mic, text
            )
            if error:
                gr.Warning(error)  # surface the input-validation message to the user
                return chatbot, history, None, None
            else:
                chatbot, history = predict(chatbot, history, audio_model, asr_model)
                return chatbot, history, None, None

        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text],
            outputs=[chatbot, history, mic, text],
            concurrency_limit=4,
            concurrency_id="gpu_queue",
        )
        clean_btn.click(
            reset_state,
            outputs=[chatbot, history],
            show_progress=True,
        )

        def regenerate(chatbot, history):
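            """Drop the trailing assistant messages and generate a new reply."""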
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            return predict(chatbot, history, audio_model, asr_model)

        regen_btn.click(
            regenerate,
            [chatbot, history],
            [chatbot, history],
            show_progress=True,
            concurrency_id="gpu_queue",
        )

    demo.queue().launch(
        share=False,
        server_port=args.server_port,
        server_name=args.server_name,
    )


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True, help="Model path.")
    parser.add_argument(
        "--server-port", type=int, default=7860, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
    )
    args = parser.parse_args()

    audio_model = StepAudio(
        tokenizer_path=os.path.join(args.model_path, "Step-Audio-Tokenizer"),
        tts_path=os.path.join(args.model_path, "Step-Audio-TTS-3B"),
        llm_path=os.path.join(args.model_path, "Step-Audio-Chat"),
    )
    asr_model = CustomAsr()
    _launch_demo(args, audio_model, asr_model)