出蛰 committed on
Commit
7725096
·
1 Parent(s): ed10807

add qwen-audio demo

Browse files
Files changed (1) hide show
  1. app.py +266 -0
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """A simple web interactive chat demo based on gradio."""
7
+
8
+ from argparse import ArgumentParser
9
+ from pathlib import Path
10
+
11
+ import copy
12
+ import gradio as gr
13
+ import os
14
+ import re
15
+ import secrets
16
+ import tempfile
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from transformers import GenerationConfig
19
+ # from modelscope.hub.api import HubApi
20
+ from pydub import AudioSegment
21
+ import os
22
+ # YOUR_ACCESS_TOKEN = os.getenv('YOUR_ACCESS_TOKEN')
23
+
24
+ # api = HubApi()
25
+ # api.login(YOUR_ACCESS_TOKEN)
26
+
27
+
28
+ # DEFAULT_CKPT_PATH = snapshot_download('qwen/Qwen-Audio-Chat')
29
+ DEFAULT_CKPT_PATH = "Qwen/Qwen-Audio-Chat"
30
+
31
+ def _get_args():
32
+ parser = ArgumentParser()
33
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
34
+ help="Checkpoint name or path, default to %(default)r")
35
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
36
+
37
+ parser.add_argument("--share", action="store_true", default=False,
38
+ help="Create a publicly shareable link for the interface.")
39
+ parser.add_argument("--inbrowser", action="store_true", default=False,
40
+ help="Automatically launch the interface in a new tab on the default browser.")
41
+ parser.add_argument("--server-port", type=int, default=8000,
42
+ help="Demo server port.")
43
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
44
+ help="Demo server name.")
45
+
46
+ args = parser.parse_args()
47
+ return args
48
+
49
+
50
def _load_model_tokenizer(args):
    """Load the tokenizer, chat model and generation config for the demo.

    Args:
        args: Parsed command-line namespace; ``checkpoint_path`` selects the
            checkpoint and ``cpu_only`` selects the device map.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode and carries
        the generation config loaded from the same checkpoint.
    """
    checkpoint = args.checkpoint_path
    # Every from_pretrained call below shares these loading options.
    load_options = {"trust_remote_code": True, "resume_download": True}

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, **load_options)

    target_device = "cpu" if args.cpu_only else "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map=target_device,
        **load_options,
    )
    model = model.eval()
    model.generation_config = GenerationConfig.from_pretrained(
        checkpoint, **load_options
    )

    return model, tokenizer
71
+
72
+
73
+ def _parse_text(text):
74
+ lines = text.split("\n")
75
+ lines = [line for line in lines if line != ""]
76
+ count = 0
77
+ for i, line in enumerate(lines):
78
+ if "```" in line:
79
+ count += 1
80
+ items = line.split("`")
81
+ if count % 2 == 1:
82
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
83
+ else:
84
+ lines[i] = f"<br></code></pre>"
85
+ else:
86
+ if i > 0:
87
+ if count % 2 == 1:
88
+ line = line.replace("`", r"\`")
89
+ line = line.replace("<", "&lt;")
90
+ line = line.replace(">", "&gt;")
91
+ line = line.replace(" ", "&nbsp;")
92
+ line = line.replace("*", "&ast;")
93
+ line = line.replace("_", "&lowbar;")
94
+ line = line.replace("-", "&#45;")
95
+ line = line.replace(".", "&#46;")
96
+ line = line.replace("!", "&#33;")
97
+ line = line.replace("(", "&#40;")
98
+ line = line.replace(")", "&#41;")
99
+ line = line.replace("$", "&#36;")
100
+ lines[i] = "<br>" + line
101
+ text = "".join(lines)
102
+ return text
103
+
104
+
105
def _launch_demo(args, model, tokenizer):
    """Build the Gradio chat UI for Qwen-Audio-Chat and start serving it.

    Args:
        args: Parsed command-line namespace; ``share``, ``inbrowser``,
            ``server_port`` and ``server_name`` are forwarded to ``launch``.
        model: Loaded chat model; ``model.chat(tokenizer, message,
            history=...)`` is called once per turn.
        tokenizer: Tokenizer handed through to ``model.chat``.
    """
    # Audio clips cut out of timestamped answers are written below this directory.
    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )
    # Timestamp markers in the model output look like ``<|1.25|>``.
    ts_pattern = re.compile(r"<\|\d{1,2}\.\d+\|>")

    def predict(_chatbot, task_history):
        """Answer the newest query in ``task_history`` and update the chat rows.

        Audio entries in the history are folded into the text query that
        follows them as ``Audio N: <audio>path</audio>`` lines. When the
        answer contains an even, non-zero number of timestamp markers and the
        history holds an audio file, each (start, end) pair is cut from the
        most recent audio and appended to the chat as its own clip row.
        """
        query = task_history[-1][0]
        print("User: " + _parse_text(query))
        history_cp = copy.deepcopy(task_history)

        history_filter = []
        audio_idx = 1
        pre = ""
        # Kept local (not global): derived from this session's history only, so
        # it is always defined and never leaks another session's file path.
        last_audio = None
        for q, a in history_cp:
            if isinstance(q, (tuple, list)):
                last_audio = q[0]
                q = f'Audio {audio_idx}: <audio>{q[0]}</audio>'
                pre += q + '\n'
                audio_idx += 1
            else:
                pre += q
                history_filter.append((pre, a))
                pre = ""
        history, message = history_filter[:-1], history_filter[-1][0]
        response, history = model.chat(tokenizer, message, history=history)
        all_time_stamps = ts_pattern.findall(response)
        print(response)
        if all_time_stamps and len(all_time_stamps) % 2 == 0 and last_audio:
            ts_float = [
                float(t.replace("<|", "").replace("|>", "")) for t in all_time_stamps
            ]
            ts_float_pair = [ts_float[i:i + 2] for i in range(0, len(ts_float), 2)]
            # Read the source audio file.
            audio_format = os.path.splitext(last_audio)[-1].replace(".", "")
            audio_file = AudioSegment.from_file(last_audio, format=audio_format)
            chat_response = response.replace("<|", "").replace("|>", "")
            temp_dir = Path(uploaded_file_dir) / secrets.token_hex(20)
            temp_dir.mkdir(exist_ok=True, parents=True)
            # Show the text answer once, before the clips, so that appending a
            # clip row never gets overwritten by the next loop iteration.
            _chatbot[-1] = (_parse_text(query), chat_response)
            # Cut one clip per (start, end) pair; timestamps are in seconds.
            for start_s, end_s in ts_float_pair:
                audio_clip = audio_file[start_s * 1000: end_s * 1000]
                # Save the clip under a random name and show it in the chat.
                filename = temp_dir / f"tmp{secrets.token_hex(5)}.{audio_format}"
                audio_clip.export(filename, format=audio_format)
                _chatbot.append((None, (str(filename),)))
        else:
            _chatbot[-1] = (_parse_text(query), response)

        # NOTE(review): the HTML-rendered answer (not the raw one) is stored in
        # task_history and is therefore fed back to the model on later turns;
        # kept as-is to preserve existing behavior - confirm this is intended.
        full_response = _parse_text(response)

        task_history[-1] = (query, full_response)
        print("Qwen-Audio-Chat: " + _parse_text(full_response))
        return _chatbot

    def regenerate(_chatbot, task_history):
        """Discard the last answer (and its clip rows) and ask the model again."""
        if not task_history:
            return _chatbot
        item = task_history[-1]
        if item[1] is None:
            return _chatbot
        task_history[-1] = (item[0], None)
        # Clip rows have no user part; an answer may have produced several.
        while _chatbot and _chatbot[-1][0] is None:
            _chatbot.pop()
        if _chatbot:
            _chatbot[-1] = (_chatbot[-1][0], None)
        else:
            # predict() overwrites the last row, so make sure one exists.
            _chatbot.append((_parse_text(item[0]), None))
        return predict(_chatbot, task_history)

    def add_text(history, task_history, text):
        """Append a text query to both histories and clear the input box."""
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(text, None)]
        return history, task_history, ""

    def add_file(history, task_history, file):
        """Append an uploaded audio file to both histories."""
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def add_mic(history, task_history, file):
        """Append a microphone recording to both histories.

        The recording is renamed to carry a ``.wav`` suffix because the file
        extension is later used as the audio format in ``predict``.
        """
        if file is None:
            return history, task_history
        wav_path = file + '.wav'
        os.rename(file, wav_path)
        print("add_mic file:", file)
        print("add_mic history:", history)
        print("add_mic task_history:", task_history)
        task_history = task_history + [((wav_path,), None)]
        history = history + [((wav_path,), None)]
        print("task_history", task_history)
        return history, task_history

    def reset_user_input():
        """Clear the query textbox."""
        return gr.update(value="")

    def reset_state(task_history):
        """Empty the model history in place and return an empty chat."""
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""<p align="center"><img src="https://modelscope.cn/api/v1/models/qwen/Qwen-VL-Chat/repo?Revision=master&FilePath=assets/logo.jpg&View=true" style="height: 80px"/><p>""")  ## todo
        gr.Markdown("""<center><font size=8>Qwen-Audio-Chat Bot</center>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen-Audio-Chat, developed by Alibaba Cloud. </center>""")
        gr.Markdown("""\
<center><font size=4>Qwen-Audio <a href="https://modelscope.cn/models/qwen/Qwen-Audio/summary">🤖 </a>
| <a href="https://huggingface.co/Qwen/Qwen-Audio">🤗</a>&nbsp |
Qwen-Audio-Chat <a href="https://modelscope.cn/models/qwen/Qwen-Audio-Chat/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-Audio-Chat">🤗</a>&nbsp |
&nbsp<a href="https://github.com/QwenLM/Qwen-Audio">Github</a></center>""")

        chatbot = gr.Chatbot(label='Qwen-Audio-Chat', elem_classes="control-height", height=750)
        query = gr.Textbox(lines=2, label='Input')
        # Raw (query, answer) pairs passed to the model; kept apart from the
        # rendered rows shown in the chatbot component.
        task_history = gr.State([])
        mic = gr.Audio(source="microphone", type="filepath")

        with gr.Row():
            empty_bin = gr.Button("🧹 Clear History")
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")
            addfile_btn = gr.UploadButton("📁 Upload", file_types=["audio"])

        mic.change(add_mic, [chatbot, task_history, mic], [chatbot, task_history])
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen-Audio. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
""")

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
        file_directories=["/tmp/"]
    )
255
+
256
+
257
def main():
    """Parse the CLI options, load the model and tokenizer, and run the demo."""
    cli_args = _get_args()
    loaded_model, loaded_tokenizer = _load_model_tokenizer(cli_args)
    _launch_demo(cli_args, loaded_model, loaded_tokenizer)


if __name__ == "__main__":
    main()