Niki Zhang committed on
Commit c5a524a · verified · 1 Parent(s): 2a1fba2

Update app.py

Files changed (1):
  1. app.py +41 -41
app.py CHANGED
@@ -171,6 +171,30 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     return state, state, image_input, click_state, image_input, image_input, image_embedding, \
         original_size, input_size
 
+def store_click(image_input, point_prompt, click_mode, state, click_state, evt: gr.SelectData):
+    click_index = evt.index
+    if point_prompt == 'Positive':
+        coordinate = [click_index[0], click_index[1], 1]
+    else:
+        coordinate = [click_index[0], click_index[1], 0]
+
+    if click_mode == 'Continuous':
+        click_state[0].append(coordinate)
+    elif click_mode == 'Single':
+        click_state[0] = [coordinate]  # Overwrite with latest click
+
+    return state, click_state
+
+def generate_caption(image_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt):
+    last_click = click_state[0][-1]
+    point_prompt = 'Positive' if last_click[2] == 1 else 'Negative'
+
+    return inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt, gr.SelectData(index=(last_click[0], last_click[1])))
+
+
+
+
+
 def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                     length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
                     evt: gr.SelectData):
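Note on the new click bookkeeping: store_click records every click in click_state[0] as an [x, y, label] triple, and generate_caption later replays only the most recent one. A minimal walk-through of that state shape (coordinate values are illustrative, not from app.py):

    # click_state keeps the reset shape used later in create_ui: [[], [], []]
    click_state = [[], [], []]

    # 'Continuous' mode accumulates clicks; label 1 = positive, 0 = negative
    click_state[0].append([200, 150, 1])   # positive click at (200, 150)
    click_state[0].append([320, 400, 0])   # negative click at (320, 400)

    # 'Single' mode instead overwrites with the latest click
    click_state[0] = [[320, 400, 0]]

Whether the replay call in generate_caption runs as written depends on the surrounding scope and the Gradio version: click_mode is not among its parameters, and recent Gradio releases build gr.SelectData from a target and a data dict rather than an index keyword.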
@@ -185,18 +209,6 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     input_points = prompt['input_point']
     input_labels = prompt['input_label']
 
-    click_state[0] = input_points
-    click_state[1] = input_labels
-    state = state + [("Image point: {}, Input label: {}".format(click_state[0], click_state[1]), None)]
-
-    return state, click_state
-
-def submit_caption(image_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt):
-    coordinates = json.dumps(list(zip(click_state[0], click_state[1])))
-    prompt = get_click_prompt(coordinates, click_state, 'Single')
-    input_points = prompt['input_point']
-    input_labels = prompt['input_label']
-
     controls = {'length': length,
                 'sentiment': sentiment,
                 'factuality': factuality,
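For context, the prompt dict built by get_click_prompt and consumed by model.inference is only read through its input_point and input_label keys here, which matches the usual SAM-style point-prompt layout. A hedged sketch of its shape (the exact field set is an assumption, not taken from app.py):

    # Hypothetical illustration of get_click_prompt's output
    prompt = {
        'prompt_type': ['click'],                 # assumed field name
        'input_point': [[200, 150], [320, 400]],  # (x, y) pixel coordinates
        'input_label': [1, 0],                    # 1 = positive, 0 = negative click
    }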
@@ -217,39 +229,37 @@ def submit_caption(image_input, enable_wiki, language, sentiment, factuality, le
     enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
     out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
 
+    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
     state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
-
-    update_click_state(click_state, out['generated_captions']['raw_caption'], 'Single')
+    update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
     text = out['generated_captions']['raw_caption']
     input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
     origin_image_input = image_input
-    image_input = create_bubble_frame(image_input, text, (input_points[-1][0], input_points[-1][1]), input_mask,
+    image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
                                       input_points=input_points, input_labels=input_labels)
-
+    x, y = input_points[-1]
+
     if visual_chatgpt is not None:
-        new_image_path = get_new_image_name('chat_image', func_name='upload')
-        image_input.save(new_image_path)
-        visual_chatgpt.current_image = new_image_path
-        img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
-        Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
-        AI_prompt = "Received."
-        visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
-        visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
+        print('inference_click: add caption to chatGPT memory')
+        new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
+        Image.open(out["crop_save_path"]).save(new_crop_save_path)
+        point_prompt = f'You should primarily use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
+        visual_chatgpt.point_prompt = point_prompt
 
+    yield state, state, click_state, image_input
     if not args.disable_gpt and model.text_refiner:
         refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
                                                        enable_wiki=enable_wiki)
+        # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
         new_cap = refined_caption['caption']
         if refined_caption['wiki']:
            state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
         state = state + [(None, f"caption: {new_cap}")]
-        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (input_points[-1][0], input_points[-1][1]),
+        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
                                                   input_mask,
                                                   input_points=input_points, input_labels=input_labels)
-        return state, state, click_state, refined_image_input
-
-    return state, state, click_state, image_input
+        yield state, state, click_state, refined_image_input
 
 
 def get_sketch_prompt(mask: Image.Image):
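With the two yield statements restored above, inference_click is a generator, so Gradio streams its outputs: the first yield pushes the raw caption and masked image to the UI immediately, and the second swaps in the refined caption once the text refiner returns. A minimal sketch of the same pattern, assuming any Gradio version with generator support:

    import time
    import gradio as gr

    def slow_caption(text):
        yield "raw caption for: " + text       # rendered immediately
        time.sleep(2)                          # stands in for the text refiner
        yield "refined caption for: " + text   # replaces the first value

    with gr.Blocks() as demo:
        box = gr.Textbox()
        btn = gr.Button("Caption")
        btn.click(slow_caption, inputs=[box], outputs=[box])

    demo.launch()

One caveat: streaming only happens when the registered handler is itself a generator function, so a wrapper like generate_caption that plain-returns inference_click's generator hands Gradio the generator object as a single output value; delegating with yield from would preserve the streaming.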
@@ -556,10 +566,8 @@ def create_ui():
     )
     clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
 
-    submit_button_click.click(submit_caption,
-                              [origin_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt],
-                              [chatbot, state, click_state, image_input])
-
+    submit_button_click.click(generate_caption, inputs=[origin_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt], outputs=[chatbot, state, click_state, image_input])
+
     image_input.clear(
         lambda: (None, [], [], [[], [], []], "", "", ""),
         [],
@@ -588,15 +596,7 @@ def create_ui():
     image_embedding, original_size, input_size])
     example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
     # select coordinate
-    image_input.select(
-        inference_click,
-        inputs=[
-            origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
-            image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
-        ],
-        outputs=[chatbot, state, click_state, image_input],
-        show_progress=False, queue=True
-    )
+    image_input.select(store_click, inputs=[origin_image, point_prompt, click_mode, state, click_state], outputs=[state, click_state])
 
     submit_button_sketcher.click(
         inference_traject,
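The select wiring above relies on Gradio's implicit event injection: a parameter annotated evt: gr.SelectData is filled in by the framework and is deliberately absent from the inputs list, which is why image_input.select passes only five components to store_click. A minimal sketch of the pattern (component names are illustrative):

    import gradio as gr

    def on_select(clicks, evt: gr.SelectData):
        clicks.append(evt.index)   # evt.index is the (x, y) position of the click
        return clicks

    with gr.Blocks() as demo:
        image = gr.Image()
        clicks = gr.State([])
        # SelectData comes from the event itself; only `clicks` is listed in inputs
        image.select(on_select, inputs=[clicks], outputs=[clicks])

    demo.launch()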