aaabiao committed
Commit 9e5e896 · verified · 1 parent: 2560819

Delete demo.py

Files changed (1)
  1. demo.py +0 -242
demo.py DELETED
@@ -1,242 +0,0 @@
"""A simple web interactive chat demo based on gradio."""

from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is one of the stop ids."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
        stop_ids = [2, 6, 7, 8]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the last token decodes to one of the given stop words."""

    def __init__(self, stops=None, encounters=1):  # encounters kept from the original signature; unused
        super().__init__()
        # The stop tensors are moved to CUDA, so this criteria assumes a GPU run.
        self.stops = [stop.to("cuda") for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        # Relies on the module-level `tokenizer` defined in the __main__ block below.
        last_token = input_ids[0][-1]
        for stop in self.stops:
            if tokenizer.decode(stop) == tokenizer.decode(last_token):
                return True
        return False
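
# Illustrative sketch (not from the original file) of how generate() consumes these
# criteria: after each decoding step it calls every StoppingCriteria in the list and
# halts as soon as one returns True, e.g.
#   criteria = StoppingCriteriaList([StopOnTokens()])
#   model.generate(input_ids, stopping_criteria=criteria)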


def parse_text(text):
    """Convert model output to HTML, escaping markdown characters inside code fences."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
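
# Hypothetical input/output pair for parse_text (illustrative, not from the original):
#   parse_text("```python\nprint(1)\n```")
#   -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'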


def predict(history, max_length, top_p, temperature):
    stop = StopOnTokens()  # defined above but unused; StoppingCriteriaSub below does the stopping
    # messages = [{"role": "system", "content": "You are a helpful assistant"}]
    messages = [{"role": "system", "content": ""}]
    # messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )

    # stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"]
    stop_words = ["</s>"]
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
            "input_ids"
        ].squeeze()
        for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": stopping_criteria,
        "repetition_penalty": 1.1,
    }
    # Run generation in a background thread so tokens can be consumed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        if new_token != "":
            history[-1][1] += new_token
            yield history
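
# The thread-plus-streamer pattern above is the standard TextIteratorStreamer recipe:
# model.generate() runs in a worker thread and pushes decoded text into the streamer,
# while this generator drains it and re-yields the growing history so gradio can
# redraw the Chatbot after each new token.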


def main(args):
    with gr.Blocks() as demo:
        # gr.Markdown(
        #     """\
        # <p align="center"><img src="https://raw.githubusercontent.com/01-ai/Yi/main/assets/img/Yi_logo_icon_light.svg" style="height: 80px"/><p>"""
        # )
        # gr.Markdown("""<center><font size=8>Yi-Chat Bot</center>""")
        gr.Markdown("""<center><font size=8>🦣MAmmoTH2</center>""")
        # gr.Markdown(
        #     """\
        # <center><font size=3>This WebUI is based on Yi-Chat, developed by 01-AI.</center>"""
        # )
        gr.Markdown(
            """\
<center><font size=4>
MAmmoTH2-8x7B-Plus <a style="text-decoration: none" href="https://huggingface.co/TIGER-Lab/MAmmoTH2-8x7B-Plus/">🤗</a></center>"""
            # <a style="text-decoration: none" href="https://www.modelscope.cn/models/01ai/Yi-34B-Chat/summary">🤖</a>&nbsp
            # &nbsp<a style="text-decoration: none" href="https://github.com/01-ai/Yi">Yi GitHub</a></center>
        )

        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Input...",
                        lines=10,
                        container=False,
                    )
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("🚀 Submit")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("🧹 Clear History")
                max_length = gr.Slider(
                    0,
                    32768,
                    value=4096,
                    step=1.0,
                    label="Maximum length",
                    interactive=True,
                )
                top_p = gr.Slider(
                    0, 1, value=1.0, step=0.01, label="Top P", interactive=True
                )
                temperature = gr.Slider(
                    0.01, 1, value=0.7, step=0.01, label="Temperature", interactive=True
                )

        def user(query, history):
            # return "", history + [[parse_text(query), ""]]
            return "", history + [[query, ""]]

        submitBtn.click(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        user_input.submit(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        emptyBtn.click(lambda: None, None, chatbot, queue=False)

    demo.queue()

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=args.inbrowser,
        share=args.share,
    )
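
# Note (added for clarity, not from the original file): user() appends [query, ""]
# to the chat history, and predict() then streams tokens into that empty slot via
# history[-1][1] += new_token.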


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="TIGER-Lab/MAmmoTH2-8B-Plus",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=True,
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8110, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )

    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        torch_dtype="auto",
        trust_remote_code=True,
    ).eval()

    main(args)
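
# Assuming the checkpoint is reachable (locally or on the Hugging Face Hub), the demo
# would be launched with something like:
#   python demo.py -c TIGER-Lab/MAmmoTH2-8B-Plus --server-port 8110
# and served at http://127.0.0.1:8110, per the argparse defaults above.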