Diplab committed
Commit 3f4cef5 · verified · 1 Parent(s): 37c5869

Create demo.py

Files changed (1)
  1. demo.py +619 -0
demo.py ADDED
@@ -0,0 +1,619 @@
import argparse
import os
import random
from collections import defaultdict

import cv2
import re

import numpy as np
from PIL import Image
import torch
import html
import gradio as gr

import torchvision.transforms as T
import torch.backends.cudnn as cudnn

from minigpt4.common.config import Config

from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='eval_configs/Florence.yaml',
                        help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
             "in xxx=yyy format will be merged into the config file "
             "(deprecated; use --cfg-options instead).",
    )
    args = parser.parse_args()
    return args

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

cudnn.benchmark = False
cudnn.deterministic = True

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

model = model.eval()

CONV_VISION = Conversation(
    system="",
    roles=(r"<s>[INST] ", r" [/INST]"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)

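# Note: the role tags above follow the Llama-2 chat convention ("<s>[INST] ... [/INST]"),
# which assumes the underlying language model was instruction-tuned with that template.
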
def extract_substrings(string):
    # first trim off any unfinished bracket after the last '}'
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings

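# For example (assuming the model's grounding output format), a generation such as
# '<p>a cat</p>{<10><20><50><60>}' yields ['a cat</p>{<10><20><50><60>']: the phrase plus
# its bbox payload, which the caller later splits on '</p>'.
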
def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou

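# Worked example: computeIoU((0, 0, 10, 10), (5, 5, 15, 15)) uses the inclusive-pixel
# convention (+1 on each side length), so intersection = 6 * 6 = 36, each box = 11 * 11 = 121,
# union = 121 + 121 - 36 = 206, and IoU = 36 / 206 ≈ 0.175.
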
def save_tmp_img(visual_img):
    file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    os.makedirs("tmp", exist_ok=True)  # ensure the temp directory exists
    file_path = os.path.join("tmp", "gradio" + file_name)
    visual_img.save(file_path)
    return file_path

def mask2bbox(mask):
    if mask is None:
        return ''
    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox

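# The returned string uses the model's 100x100 coordinate grid in <x_min><y_min><x_max><y_max>
# order; e.g. a sketched region spanning columns 20-40 and rows 10-30 becomes '{<20><10><40><30>}'.
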
def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text

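# These two helpers are inverses: escape_markdown puts a backslash before each angle bracket
# so raw tags like <p> render literally in the chat window, and reverse_escape strips those
# backslashes again before the text is parsed for bounding boxes.
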
colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (210, 210, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
]

color_map = {
    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
    color_id, color in enumerate(colors)
}

used_colors = colors

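# The tuples are in OpenCV's BGR channel order, and color_map flips them into '#rrggbb' hex
# for HTML; e.g. colors[0] = (255, 0, 0) is blue and maps to '#0000ff'.
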
def visualize_all_bbox_together(image, generation):
    if image is None:
        return None, ''

    generation = html.unescape(generation)
    print('gen begin', generation)
    image_width, image_height = image.size
    image = image.resize([500, int(500 / image_width * image_height)])
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            bbox_list = string.split('<delim>')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        integers = re.findall(r'-?\d+', generation)

        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top])
        else:
            # no valid bbox detected to visualize
            return None, ''

    if len(entities) == 0:
        return None, ''

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    elif isinstance(image, str):
        if os.path.exists(image):
            pil_img = Image.open(image).convert("RGB")
            image = np.array(pil_img)[:, :, [2, 1, 0]]
            image_h = pil_img.height
            image_w = pil_img.width
        else:
            raise ValueError(f"invalid image path, {image}")
    elif isinstance(image, torch.Tensor):

        image_tensor = image.cpu()
        reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
        reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
        image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
        pil_img = T.ToPILImage()(image_tensor)
        image_h = pil_img.height
        image_w = pil_img.width
        image = np.array(pil_img)[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    indices = list(range(len(entities)))

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.5
    # thickness of text
    text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2

    # num_bboxes = sum(len(x[-1]) for x in entities)
    used_colors = colors  # random.sample(colors, k=num_bboxes)

    color_id = -1
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)

            color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
            new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)

            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                        text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
                            prev_bbox['phrase'] == entity_name:
                        skip_flag = True
                        break
                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break
                if not skip_flag:
                    alpha = 0.5
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    return pil_image, generation_colored

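# Hypothetical usage: given a PIL image and a generation such as
# '<p>lesion</p>{<30><40><60><70>}', this returns the image with one labeled box (coordinates
# rescaled from the 100x100 grid) plus an HTML-colored copy of the caption text.
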
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
                                                                    interactive=True), chat_state, img_list


def image_upload_trigger(upload_flag, replace_flag, img_list):
    # set the upload flag to true when receiving a new image.
    # if there is an old image (and an old conversation), set the replace flag to true to reset the conv later.
    upload_flag = 1
    if img_list:
        replace_flag = 1
    return upload_flag, replace_flag


def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    # set the upload flag to true when receiving a new image.
    # if there is an old image (and an old conversation), set the replace flag to true to reset the conv later.
    upload_flag = 1
    if img_list or replace_flag == 1:
        replace_flag = 1

    return upload_flag, replace_flag

def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
    if len(user_message) == 0:
        text_box_show = 'Input should not be empty!'
    else:
        text_box_show = ''

    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']
    else:
        mask = None

    if '[identify]' in user_message:
        # check if the user provided a bbox in the text input
        integers = re.findall(r'-?\d+', user_message)
        if len(integers) != 4:  # no bbox in text, so derive one from the sketched mask
            bbox = mask2bbox(mask)
            user_message = user_message + bbox

    if chat_state is None:
        chat_state = CONV_VISION.copy()

    if upload_flag:
        if replace_flag:
            chat_state = CONV_VISION.copy()  # new image, reset everything
            replace_flag = 0
            chatbot = []
            img_list = []
        llm_message = chat.upload_img(gr_img, chat_state, img_list)
        upload_flag = 0

    chat.ask(user_message, chat_state)

    chatbot = chatbot + [[user_message, None]]

    if '[identify]' in user_message:
        visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
        if visual_img is not None:
            file_path = save_tmp_img(visual_img)
            chatbot = chatbot + [[(file_path,), None]]

    return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag

def gradio_answer(chatbot, chat_state, img_list, temperature):
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              temperature=temperature,
                              max_new_tokens=500,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state


def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    if len(img_list) > 0:
        if not isinstance(img_list[0], torch.Tensor):
            chat.encode_img(img_list)
    streamer = chat.stream_answer(conv=chat_state,
                                  img_list=img_list,
                                  temperature=temperature,
                                  max_new_tokens=500,
                                  max_length=2000)
    output = ''
    for new_output in streamer:
        escaped = escape_markdown(new_output)
        output += escaped
        chatbot[-1][1] = output
        yield chatbot, chat_state
    chat_state.messages[-1][1] = '</s>'
    return chatbot, chat_state

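# Because gradio_stream_answer is a generator, Gradio re-renders the chatbot after every
# yield, which is what produces the token-by-token streaming effect in the chat window.
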
def gradio_visualize(chatbot, gr_img):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']

    unescaped = reverse_escape(chatbot[-1][1])
    visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
    if visual_img is not None:
        if len(generation_color):
            chatbot[-1][1] = generation_color
        file_path = save_tmp_img(visual_img)
        chatbot = chatbot + [[None, (file_path,)]]

    return chatbot


def gradio_taskselect(idx):
    prompt_list = [
        '',
        '[grounding] describe this image in detail',
        '[refer] ',
        '[detection] ',
        '[identify] what is this ',
        '[vqa] '
    ]
    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Send the command to generate a grounded image description',
        '**Hint:** Type in a phrase about an object in the image and send the command',
        '**Hint:** Type in a caption or phrase, and see object locations in the image',
        '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button on the top right of the image before redrawing',
        '**Hint:** Send a question to get a short answer',
    ]
    return prompt_list[idx], instruct_list[idx]

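# The bracketed prefixes ([grounding], [refer], [detection], [identify], [vqa]) are the task
# identifiers the model is presumably instruction-tuned on; gradio_taskselect just pre-fills
# the text box with the prefix matching the clicked shortcut.
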
chat = Chat(model, vis_processor, device=device)

title = """<h1 align="center">FlorenceMedVQA: A Versatile Vision-Language
Model for Comprehensive Medical Application</h1>"""
# description = 'Welcome to your AI-Powered Medical Chatbot!'

About_Chatbot = '''
This chatbot is designed to provide accurate and reliable answers to your medical questions.
Whether you're seeking general health advice or have specific inquiries about symptoms, treatments, or medical conditions, our AI is here to assist you 24/7.
While this tool offers valuable information, it’s important to consult with a healthcare professional for personalized medical advice.
'''

text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
                        scale=8)

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown()
    gr.Markdown()

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", tool='sketch', brush_radius=20)

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            clear = gr.Button("Restart")

            gr.Markdown(About_Chatbot)

        with gr.Column():
            chat_state = gr.State(value=None)
            img_list = gr.State(value=[])
            chatbot = gr.Chatbot()

            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[['VQA']],
                type="index",
                label='Task Shortcuts',
            )
            task_inst = gr.Markdown('**Hint:** Upload your image and chat')
            with gr.Row():
                text_input.render()
                send = gr.Button("Send", variant='primary', size='sm', scale=1)

    upload_flag = gr.State(value=0)
    replace_flag = gr.State(value=0)
    image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])

    dataset.click(
        gradio_taskselect,
        inputs=[dataset],
        outputs=[text_input, task_inst],
        show_progress="hidden",
        postprocess=False,
        queue=False,
    )

    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    send.click(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

demo.launch()