zhuohaoyu commited on
Commit
207dbd2
·
1 Parent(s): 01ab9cb
Files changed (1) hide show
  1. app.py +208 -0
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This code is based on Qwen's web Demo. It has been modified from
17
+ # its original forms to accommodate CodeShell.
18
+
19
+ # Copyright (c) Alibaba Cloud.
20
+ #
21
+ # This source code is licensed under the license found in the
22
+ # LICENSE file in the root directory of this source tree.
23
+
24
+ """A simple web interactive chat demo based on gradio."""
25
+ import os
26
+ from argparse import ArgumentParser
27
+
28
+ import gradio as gr
29
+ import mdtex2html
30
+
31
+ import torch
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+ from transformers.generation import GenerationConfig
34
+
35
+
36
+ DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat'
37
+
38
+
39
+ def _get_args():
40
+ parser = ArgumentParser()
41
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
42
+ help="Checkpoint name or path, default to %(default)r")
43
+ parser.add_argument("--device", type=str, default="cuda:0", help="GPU device.")
44
+
45
+ parser.add_argument("--share", action="store_true", default=False,
46
+ help="Create a publicly shareable link for the interface.")
47
+ parser.add_argument("--inbrowser", action="store_true", default=False,
48
+ help="Automatically launch the interface in a new tab on the default browser.")
49
+ parser.add_argument("--server-port", type=int, default=8000,
50
+ help="Demo server port.")
51
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
52
+ help="Demo server name.")
53
+
54
+ args = parser.parse_args()
55
+ return args
56
+
57
+
def _load_model_tokenizer(args):
    """Load tokenizer, model and generation config from one checkpoint.

    Args:
        args: parsed CLI namespace; reads ``checkpoint_path`` and ``device``.

    Returns:
        (model, tokenizer, generation_config) where the model is placed on
        ``args.device`` in bfloat16 and switched to eval mode.
    """
    ckpt = args.checkpoint_path

    tokenizer = AutoTokenizer.from_pretrained(
        ckpt,
        trust_remote_code=True,
        resume_download=True,
    )

    # bfloat16 halves memory versus fp32; eval() disables dropout etc.
    model = AutoModelForCausalLM.from_pretrained(
        ckpt,
        device_map=args.device,
        trust_remote_code=True,
        resume_download=True,
        torch_dtype=torch.bfloat16,
    ).eval()

    generation_config = GenerationConfig.from_pretrained(
        ckpt,
        trust_remote_code=True,
        resume_download=True,
    )

    return model, tokenizer, generation_config
76
+
77
+
def postprocess(self, y):
    """Render chat history entries as HTML.

    Replaces ``gr.Chatbot.postprocess`` so each (message, response) pair in
    the history *y* is converted from Markdown/TeX to HTML with
    ``mdtex2html`` before display. Mutates *y* in place and returns it;
    ``None`` history becomes an empty list.
    """
    if y is None:
        return []
    for idx, (msg, reply) in enumerate(y):
        y[idx] = (
            mdtex2html.convert(msg) if msg is not None else None,
            mdtex2html.convert(reply) if reply is not None else None,
        )
    return y


# Install the Markdown/TeX-aware renderer on the Chatbot component.
gr.Chatbot.postprocess = postprocess
90
+
91
+
92
+ def _parse_text(text):
93
+ lines = text.split("\n")
94
+ lines = [line for line in lines if line != ""]
95
+ count = 0
96
+ for i, line in enumerate(lines):
97
+ if "```" in line:
98
+ count += 1
99
+ items = line.split("`")
100
+ if count % 2 == 1:
101
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
102
+ else:
103
+ lines[i] = f"<br></code></pre>"
104
+ else:
105
+ if i > 0:
106
+ if count % 2 == 1:
107
+ line = line.replace("`", r"\`")
108
+ line = line.replace("<", "&lt;")
109
+ line = line.replace(">", "&gt;")
110
+ line = line.replace(" ", "&nbsp;")
111
+ line = line.replace("*", "&ast;")
112
+ line = line.replace("_", "&lowbar;")
113
+ line = line.replace("-", "&#45;")
114
+ line = line.replace(".", "&#46;")
115
+ line = line.replace("!", "&#33;")
116
+ line = line.replace("(", "&#40;")
117
+ line = line.replace(")", "&#41;")
118
+ line = line.replace("$", "&#36;")
119
+ lines[i] = "<br>" + line
120
+ text = "".join(lines)
121
+ return text
122
+
123
+
124
+ def _gc():
125
+ import gc
126
+ gc.collect()
127
+ if torch.cuda.is_available():
128
+ torch.cuda.empty_cache()
129
+
130
+
def _launch_demo(args, model, tokenizer, config):
    """Build the Gradio Blocks UI around *model* and serve it.

    Wires four callbacks (predict / regenerate / clear-input / clear-state)
    to the chat widgets, then blocks in ``launch()`` using the CLI serving
    options from *args*.
    """

    def predict(_query, _chatbot, _task_history):
        # Generator callback: streams partial responses so the chatbot
        # updates token-by-token while model.chat() yields.
        print(f"User: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))
        full_response = ""

        # NOTE(review): model.chat is provided by the remote-code model
        # class (trust_remote_code=True); assumed to yield the cumulative
        # response text at each step — confirm against CodeShell's modeling
        # code.
        for response in model.chat(_query, _task_history, tokenizer, generation_config=config, stream=True):
            # Strip the model's end-of-turn marker variants from the stream.
            response = response.replace('|end|', '')
            response = response.replace('|<end>|', '')
            _chatbot[-1] = (_parse_text(_query), _parse_text(response))

            yield _chatbot
            full_response = _parse_text(response)

        print(f"History: {_task_history}")
        # History stores the raw query with the HTML-rendered final answer.
        _task_history.append((_query, full_response))
        print(f"CodeShell-Chat: {_parse_text(full_response)}")

    def regenerate(_chatbot, _task_history):
        # Re-run the last user turn: pop it from both the history and the
        # displayed transcript, then stream a fresh answer via predict().
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        # Clears the textbox after a submit.
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        # "Clear History" handler: wipe both stores and release memory.
        _task_history.clear()
        _chatbot.clear()
        _gc()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>CodeShell-Chat Bot</center>""")

        chatbot = gr.Chatbot(label='CodeShell-Chat', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        # Server-side conversation history, one (query, response) per turn.
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")

        # Submit triggers both generation and clearing the input box.
        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of CodeShell. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    # queue() is required for generator (streaming) callbacks.
    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )
197
+
198
+
def main():
    """CLI entry point: parse options, load the model stack, serve the UI."""
    args = _get_args()
    model, tokenizer, config = _load_model_tokenizer(args)
    _launch_demo(args, model, tokenizer, config)


if __name__ == '__main__':
    main()