y0un92 commited on
Commit
be1029b
·
verified ·
1 Parent(s): e9aa51c

Create web_demo.py

Browse files
Files changed (1) hide show
  1. web_demo.py +258 -0
web_demo.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
# Model download
from modelscope import snapshot_download

# README, How to run demo on different devices

# For Nvidia GPUs.
# python web_demo.py --device cuda

# For Mac with MPS (Apple silicon or AMD GPUs).
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo.py --device mps

# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
args = parser.parse_args()
device = args.device
# Validate explicitly: `assert` is stripped when Python runs with -O,
# which would let an unsupported device slip through silently.
if device not in ('cuda', 'mps'):
    raise ValueError(f"Unsupported device {device!r}: expected 'cuda' or 'mps'")

# Load model
model_path = snapshot_download('OpenBMB/MiniCPM-Llama3-V-2_5')
if 'int4' in model_path:
    # int4 quantization relies on bitsandbytes, which has no MPS backend.
    if device == 'mps':
        print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
        exit()
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
39
+
40
+
41
+
42
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.5'


def _slider_params(minimum, maximum, value, step, label):
    """Build the kwargs dict consumed by create_component for a gr.Slider."""
    return {
        'minimum': minimum,
        'maximum': maximum,
        'value': value,
        'step': step,
        'interactive': True,
        'label': label,
    }


# Decode-type selector; defaults to Sampling.
form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type',
}

# Beam Search controls
num_beams_slider = _slider_params(0, 5, 3, 1, 'Num Beams')
repetition_penalty_slider = _slider_params(0, 3, 1.2, 0.01, 'Repetition Penalty')
# Sampling controls
repetition_penalty_slider2 = _slider_params(0, 3, 1.05, 0.01, 'Repetition Penalty')
max_new_tokens_slider = _slider_params(1, 4096, 1024, 1, 'Max New Tokens')

top_p_slider = _slider_params(0, 1, 0.8, 0.05, 'Top P')
top_k_slider = _slider_params(0, 200, 100, 1, 'Top K')
temperature_slider = _slider_params(0, 2, 0.7, 0.05, 'Temperature')
110
+
111
+
112
def create_component(params, comp='Slider'):
    """Instantiate a Gradio component from a params dict.

    Args:
        params: keyword values for the component; required keys depend on
            `comp` (sliders need minimum/maximum/value/step/interactive/label,
            radios need choices/value/interactive/label, buttons need value).
        comp: one of 'Slider', 'Radio', 'Button'.

    Returns:
        The constructed gradio component.

    Raises:
        ValueError: if `comp` is not one of the supported component names.
    """
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )
    # Fail loudly: the original fell through and returned None, which only
    # surfaced later as a confusing AttributeError at the call site.
    raise ValueError(f"Unsupported component type: {comp!r}")
134
+
135
+
136
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    """Run one inference turn against the global model and clean the output.

    Args:
        img: PIL image for this conversation, or None (treated as an error).
        msgs: conversation history as a list of {"role", "content"} dicts.
        ctx: unused; kept for interface compatibility with callers.
        params: generation kwargs forwarded to model.chat(); sensible
            defaults are applied when None.
        vision_hidden_states: unused; kept for interface compatibility.

    Returns:
        Tuple (code, answer, None, None) where code is 0 on success and
        -1 on invalid image or inference failure.
    """
    if params is None:
        params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if img is None:
        return -1, "Error, invalid image, please upload a new image", None, None
    try:
        image = img.convert('RGB')
        answer = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        # Strip grounding annotations non-greedily: the original greedy
        # `<box>.*</box>` also deleted the text between two separate boxes.
        answer = re.sub(r'<box>.*?</box>', '', answer)
        # Remove any stray markup tags left behind.
        for tag in ('<ref>', '</ref>', '<box>', '</box>'):
            answer = answer.replace(tag, '')
        return 0, answer, None, None
    except Exception as err:
        # UI boundary: report and degrade to a generic error message.
        print(err)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None
160
+
161
+
162
def upload_img(image, _chatbot, _app_session):
    """Register a newly uploaded image in the session and greet the user.

    Resets the per-image session state ('sts' and 'ctx') so the next turn
    starts a fresh conversation about this image.
    """
    pil_image = Image.fromarray(image)

    _app_session['sts'] = None
    _app_session['ctx'] = []
    _app_session['img'] = pil_image
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session
170
+
171
+
172
def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Handle one chat turn: build generation params, query the model, update state.

    Returns ('', chat_bot, app_cfg) so the textbox is cleared after submit.
    """
    # No context yet means no image has been uploaded.
    if _app_cfg.get('ctx', None) is None:
        _chat_bot.append((_question, 'Please upload an image to start'))
        return '', _chat_bot, _app_cfg

    # Work on a copy so the stored context is only replaced on success.
    _context = _app_cfg['ctx'].copy()
    _context.append({"role": "user", "content": _question})
    print('<User>:', _question)

    params = {"max_new_tokens": 896}
    if params_form == 'Beam Search':
        params.update(sampling=False,
                      num_beams=num_beams,
                      repetition_penalty=repetition_penalty)
    else:
        params.update(sampling=True,
                      top_p=top_p,
                      top_k=top_k,
                      temperature=temperature,
                      repetition_penalty=repetition_penalty_2)

    code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
    print('<Assistant>:', _answer)

    _context.append({"role": "assistant", "content": _answer})
    _chat_bot.append((_question, _answer))
    # Persist the new context/state only when inference succeeded.
    if code == 0:
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return '', _chat_bot, _app_cfg
209
+
210
+
211
def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Re-ask the most recent user question after discarding the last exchange."""
    # Only the upload greeting (or nothing) is present: nothing to redo.
    if len(_chat_bot) <= 1:
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    # Last row is already a regenerate notice; avoid stacking another.
    if _chat_bot[-1][0] == 'Regenerate':
        return '', _chat_bot, _app_cfg

    last_question = _chat_bot[-1][0]
    # Drop the last chat row and the last user/assistant pair from context,
    # then replay the question through the normal respond() path.
    _chat_bot = _chat_bot[:-1]
    _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
    return respond(last_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
222
+
223
+
224
+
225
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: decoding controls.
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            with gr.Accordion("Beam Search") as beam_accordion:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_accordion:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
        # Right column: image upload, chat transcript, and text input.
        with gr.Column(scale=3, min_width=500):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

    # Both generation triggers share the same inputs and outputs.
    generation_inputs = [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature]
    chat_outputs = [txt_message, chat_bot, app_session]

    regenerate.click(regenerate_button_clicked, generation_inputs, chat_outputs)
    txt_message.submit(respond, generation_inputs, chat_outputs)
    # Clear the transcript first, then register the new image in the session.
    bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session])

# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")