JustinLin610 committed on
Commit
75ec303
·
1 Parent(s): cf7e1f2

add app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """A simple web interactive chat demo based on gradio."""
7
+
8
+ from argparse import ArgumentParser
9
+
10
+ import gradio as gr
11
+ import mdtex2html
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ from transformers.generation import GenerationConfig
14
+
15
# Default model checkpoint: a Hugging Face Hub repo id or a local path.
# NOTE(review): capitalization looks off — the official hub id is
# 'Qwen/Qwen-7B-Chat'; confirm this string resolves before relying on it.
DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'
16
+
17
+
18
def _get_args():
    """Parse and return the command-line arguments for the web demo."""
    p = ArgumentParser()
    p.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                   help="Checkpoint name or path, default to %(default)r")
    p.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    p.add_argument("--share", action="store_true", default=False,
                   help="Create a publicly shareable link for the interface.")
    p.add_argument("--inbrowser", action="store_true", default=False,
                   help="Automatically launch the interface in a new tab on the default browser.")
    p.add_argument("--server-port", type=int, default=8000,
                   help="Demo server port.")
    p.add_argument("--server-name", type=str, default="127.0.0.1",
                   help="Demo server name.")
    return p.parse_args()
35
+
36
+
37
def _load_model_tokenizer(args):
    """Load the chat model, attach its generation config, and load the tokenizer.

    Returns a ``(model, tokenizer)`` pair; the model is put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    # Pin everything to CPU only when explicitly requested; otherwise let
    # transformers place the weights automatically.
    device_map = "cpu" if args.cpu_only else "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer
58
+
59
+
60
def postprocess(self, y):
    """Render each (message, response) pair in *y* through mdtex2html.

    Mutates *y* in place and returns it; ``None`` entries pass through
    unchanged, and a ``None`` history yields an empty list.
    """
    if y is None:
        return []

    def _to_html(text):
        # Missing halves of a turn stay None rather than being converted.
        return None if text is None else mdtex2html.convert(text)

    for idx, (msg, reply) in enumerate(y):
        y[idx] = (_to_html(msg), _to_html(reply))
    return y
69
+
70
+
71
# Monkey-patch gr.Chatbot so every rendered message goes through the
# mdtex2html-based postprocess defined above.
gr.Chatbot.postprocess = postprocess
72
+
73
+
74
+ def _parse_text(text):
75
+ lines = text.split("\n")
76
+ lines = [line for line in lines if line != ""]
77
+ count = 0
78
+ for i, line in enumerate(lines):
79
+ if "```" in line:
80
+ count += 1
81
+ items = line.split("`")
82
+ if count % 2 == 1:
83
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
84
+ else:
85
+ lines[i] = f"<br></code></pre>"
86
+ else:
87
+ if i > 0:
88
+ if count % 2 == 1:
89
+ line = line.replace("`", r"\`")
90
+ line = line.replace("<", "&lt;")
91
+ line = line.replace(">", "&gt;")
92
+ line = line.replace(" ", "&nbsp;")
93
+ line = line.replace("*", "&ast;")
94
+ line = line.replace("_", "&lowbar;")
95
+ line = line.replace("-", "&#45;")
96
+ line = line.replace(".", "&#46;")
97
+ line = line.replace("!", "&#33;")
98
+ line = line.replace("(", "&#40;")
99
+ line = line.replace(")", "&#41;")
100
+ line = line.replace("$", "&#36;")
101
+ lines[i] = "<br>" + line
102
+ text = "".join(lines)
103
+ return text
104
+
105
+
106
def _launch_demo(args, model, tokenizer):
    """Build the gradio Blocks UI for Qwen-7B-Chat and start the server.

    Wires four callbacks around a shared ``task_history`` state list:
    predict (streaming chat), regenerate, reset_user_input, reset_state.
    """

    def predict(_query, _chatbot, _task_history):
        # Stream the model's reply, re-rendering the last chat bubble on
        # every partial response and yielding the chatbot for live updates.
        print(f"User: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))
        full_response = ""

        for response in model.chat_stream(tokenizer, _query, history=_task_history):
            _chatbot[-1] = (_parse_text(_query), _parse_text(response))

            yield _chatbot
            full_response = _parse_text(response)

        print(f"History: {_task_history}")
        # History pairs the raw query with the HTML-parsed final response.
        _task_history.append((_query, full_response))
        print(f"Qwen-7B-Chat: {_parse_text(full_response)}")

    def regenerate(_chatbot, _task_history):
        # Re-run the last user turn; no-op when there is no history yet.
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        # Clear the input textbox after submission.
        return gr.update(value="")

    def reset_state(_task_history):
        # Wipe both the model-side history and the rendered chat window.
        _task_history.clear()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://modelscope.cn/api/v1/models/qwen/Qwen-7B-Chat/repo?
Revision=master&FilePath=assets/logo.jpeg&View=true" style="height: 80px"/><p>""")
        gr.Markdown("""<center><font size=8>Qwen-7B-Chat Bot</center>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen-7B-Chat, developed by Alibaba Cloud. \
(本WebUI基于Qwen-7B-Chat打造,实现聊天机器人功能。)</center>""")
        gr.Markdown("""\
<center><font size=4>Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a>
| <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp |
Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp |
&nbsp<a href="https://github.com/QwenLM/Qwen-7B">Github</a></center>""")

        chatbot = gr.Chatbot(label='Qwen-7B-Chat', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")

        # Submit streams a prediction, then a second handler clears the box.
        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen-7B. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen-7B的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    # queue() enables generator (streaming) callbacks before launching.
    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )
181
+
182
+
183
def main():
    """Entry point: parse CLI args, load the model, and serve the web demo."""
    cli_args = _get_args()
    chat_model, chat_tokenizer = _load_model_tokenizer(cli_args)
    _launch_demo(cli_args, chat_model, chat_tokenizer)
189
+
190
+
191
if __name__ == '__main__':
    # Allow running directly as a script: `python app.py [--options]`.
    main()