finalf0 committed on
Commit
593598d
·
1 Parent(s): 3c1e619
Files changed (1) hide show
  1. app.py +20 -259
app.py CHANGED
@@ -1,262 +1,23 @@
1
- #!/usr/bin/env python
2
- # encoding: utf-8
3
  import spaces
4
- import gradio as gr
5
- from PIL import Image
6
- import traceback
7
- import re
8
  import torch
9
- import argparse
10
- from transformers import AutoModel, AutoTokenizer
11
-
12
- # README, How to run demo on different devices
13
- # For Nvidia GPUs support BF16 (like A100, H100, RTX3090)
14
- # python web_demo.py --device cuda --dtype bf16
15
-
16
- # For Nvidia GPUs do NOT support BF16 (like V100, T4, RTX2080)
17
- # python web_demo.py --device cuda --dtype fp16
18
-
19
- # For Mac with MPS (Apple silicon or AMD GPUs).
20
- # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo.py --device mps --dtype fp16
21
-
22
- # Argparser
23
- parser = argparse.ArgumentParser(description='demo')
24
- parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
25
- parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or fp16')
26
- args = parser.parse_args()
27
- device = args.device
28
- assert device in ['cuda', 'mps']
29
- if args.dtype == 'bf16':
30
- dtype = torch.bfloat16
31
- else:
32
- dtype = torch.float16
33
-
34
- # Load model
35
- model_path = 'openbmb/MiniCPM-V-2'
36
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
37
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
38
-
39
- model = model.to(device=device, dtype=dtype)
40
- model.eval()
41
-
42
-
43
-
44
- ERROR_MSG = "Error, please retry"
45
- model_name = 'MiniCPM-V 2.0'
46
-
47
- form_radio = {
48
- 'choices': ['Beam Search', 'Sampling'],
49
- #'value': 'Beam Search',
50
- 'value': 'Sampling',
51
- 'interactive': True,
52
- 'label': 'Decode Type'
53
- }
54
- # Beam Form
55
- num_beams_slider = {
56
- 'minimum': 0,
57
- 'maximum': 5,
58
- 'value': 3,
59
- 'step': 1,
60
- 'interactive': True,
61
- 'label': 'Num Beams'
62
- }
63
- repetition_penalty_slider = {
64
- 'minimum': 0,
65
- 'maximum': 3,
66
- 'value': 1.2,
67
- 'step': 0.01,
68
- 'interactive': True,
69
- 'label': 'Repetition Penalty'
70
- }
71
- repetition_penalty_slider2 = {
72
- 'minimum': 0,
73
- 'maximum': 3,
74
- 'value': 1.05,
75
- 'step': 0.01,
76
- 'interactive': True,
77
- 'label': 'Repetition Penalty'
78
- }
79
- max_new_tokens_slider = {
80
- 'minimum': 1,
81
- 'maximum': 4096,
82
- 'value': 1024,
83
- 'step': 1,
84
- 'interactive': True,
85
- 'label': 'Max New Tokens'
86
- }
87
-
88
- top_p_slider = {
89
- 'minimum': 0,
90
- 'maximum': 1,
91
- 'value': 0.8,
92
- 'step': 0.05,
93
- 'interactive': True,
94
- 'label': 'Top P'
95
- }
96
- top_k_slider = {
97
- 'minimum': 0,
98
- 'maximum': 200,
99
- 'value': 100,
100
- 'step': 1,
101
- 'interactive': True,
102
- 'label': 'Top K'
103
- }
104
- temperature_slider = {
105
- 'minimum': 0,
106
- 'maximum': 2,
107
- 'value': 0.7,
108
- 'step': 0.05,
109
- 'interactive': True,
110
- 'label': 'Temperature'
111
- }
112
-
113
-
114
- def create_component(params, comp='Slider'):
115
- if comp == 'Slider':
116
- return gr.Slider(
117
- minimum=params['minimum'],
118
- maximum=params['maximum'],
119
- value=params['value'],
120
- step=params['step'],
121
- interactive=params['interactive'],
122
- label=params['label']
123
- )
124
- elif comp == 'Radio':
125
- return gr.Radio(
126
- choices=params['choices'],
127
- value=params['value'],
128
- interactive=params['interactive'],
129
- label=params['label']
130
- )
131
- elif comp == 'Button':
132
- return gr.Button(
133
- value=params['value'],
134
- interactive=True
135
- )
136
-
137
- @spaces.GPU(duration=120)
138
- def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
139
- default_params = {"num_beams":3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
140
- if params is None:
141
- params = default_params
142
- if img is None:
143
- return -1, "Error, invalid image, please upload a new image", None, None
144
- try:
145
- image = img.convert('RGB')
146
- answer, context, _ = model.chat(
147
- image=image,
148
- msgs=msgs,
149
- context=None,
150
- tokenizer=tokenizer,
151
- **params
152
- )
153
- res = re.sub(r'(<box>.*</box>)', '', answer)
154
- res = res.replace('<ref>', '')
155
- res = res.replace('</ref>', '')
156
- res = res.replace('<box>', '')
157
- answer = res.replace('</box>', '')
158
- return -1, answer, None, None
159
- except Exception as err:
160
- print(err)
161
- traceback.print_exc()
162
- return -1, ERROR_MSG, None, None
163
-
164
-
165
- def upload_img(image, _chatbot, _app_session):
166
- image = Image.fromarray(image)
167
-
168
- _app_session['sts']=None
169
- _app_session['ctx']=[]
170
- _app_session['img']=image
171
- _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
172
- return _chatbot, _app_session
173
-
174
-
175
- def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
176
- if _app_cfg.get('ctx', None) is None:
177
- _chat_bot.append((_question, 'Please upload an image to start'))
178
- return '', _chat_bot, _app_cfg
179
-
180
- _context = _app_cfg['ctx'].copy()
181
- if _context:
182
- _context.append({"role": "user", "content": _question})
183
- else:
184
- _context = [{"role": "user", "content": _question}]
185
- print('<User>:', _question)
186
-
187
- if params_form == 'Beam Search':
188
- params = {
189
- 'sampling': False,
190
- 'num_beams': num_beams,
191
- 'repetition_penalty': repetition_penalty,
192
- "max_new_tokens": 896
193
- }
194
- else:
195
- params = {
196
- 'sampling': True,
197
- 'top_p': top_p,
198
- 'top_k': top_k,
199
- 'temperature': temperature,
200
- 'repetition_penalty': repetition_penalty_2,
201
- "max_new_tokens": 896
202
- }
203
- code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
204
- print('<Assistant>:', _answer)
205
-
206
- _context.append({"role": "assistant", "content": _answer})
207
- _chat_bot.append((_question, _answer))
208
- if code == 0:
209
- _app_cfg['ctx']=_context
210
- _app_cfg['sts']=sts
211
- return '', _chat_bot, _app_cfg
212
-
213
-
214
- def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
215
- if len(_chat_bot) <= 1:
216
- _chat_bot.append(('Regenerate', 'No question for regeneration.'))
217
- return '', _chat_bot, _app_cfg
218
- elif _chat_bot[-1][0] == 'Regenerate':
219
- return '', _chat_bot, _app_cfg
220
- else:
221
- _question = _chat_bot[-1][0]
222
- _chat_bot = _chat_bot[:-1]
223
- _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
224
- return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
225
-
226
-
227
-
228
- with gr.Blocks() as demo:
229
- with gr.Row():
230
- with gr.Column(scale=1, min_width=300):
231
- params_form = create_component(form_radio, comp='Radio')
232
- with gr.Accordion("Beam Search") as beams_according:
233
- num_beams = create_component(num_beams_slider)
234
- repetition_penalty = create_component(repetition_penalty_slider)
235
- with gr.Accordion("Sampling") as sampling_according:
236
- top_p = create_component(top_p_slider)
237
- top_k = create_component(top_k_slider)
238
- temperature = create_component(temperature_slider)
239
- repetition_penalty_2 = create_component(repetition_penalty_slider2)
240
- regenerate = create_component({'value': 'Regenerate'}, comp='Button')
241
- with gr.Column(scale=3, min_width=500):
242
- app_session = gr.State({'sts':None,'ctx':None,'img':None})
243
- bt_pic = gr.Image(label="Upload an image to start")
244
- chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
245
- txt_message = gr.Textbox(label="Input text")
246
-
247
- regenerate.click(
248
- regenerate_button_clicked,
249
- [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
250
- [txt_message, chat_bot, app_session]
251
- )
252
- txt_message.submit(
253
- respond,
254
- [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
255
- [txt_message, chat_bot, app_session]
256
- )
257
- bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])
258
-
259
- # launch
260
- #demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
261
- demo.launch()
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
 
2
  import spaces
3
+ import timm
4
+ import gradio as gr
 
 
5
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
# Build a SigLIP ViT (so400m, patch 14, 384px, WebLI-trained config) as a pure
# feature extractor. NOTE(review): pretrained=False means random weights — this
# looks like a construction smoke test; confirm whether real weights are meant
# to be loaded elsewhere before treating outputs as meaningful.
model = timm.create_model(
    'vit_so400m_patch14_siglip_384.webli',
    pretrained=False,       # random init — no checkpoint download
    num_classes=0,          # drop the classification head (embeddings only)
    dynamic_img_size=True,  # accept inputs whose resolution differs from 384
    dynamic_img_pad=True    # pad inputs to a multiple of the patch size
)
14
+
15
@spaces.GPU  # request GPU hardware for this call on a HF Space (ZeroGPU decorator)
def generate(prompt):
    """Placeholder handler: ignores `prompt` and returns the model's class object.

    NOTE(review): the Interface declares a gr.Gallery() output, which expects a
    list of images — a type object will not render there. Presumably this is a
    temporary smoke test of model construction; confirm the intended behavior.
    """
    return type(model)
18
+
19
# Wire up the demo UI and start the Gradio server (blocking call).
# NOTE(review): outputs=gr.Gallery() expects a list of images, but `generate`
# returns a type object — likely placeholder wiring; verify the intended
# output component before shipping.
gr.Interface(
    fn=generate,
    inputs=gr.Text(),
    outputs=gr.Gallery(),
).launch()