Niki Zhang committed on
Commit f950d25 · verified · 1 Parent(s): 891d27d

fix gpt usage

Files changed (1)
  1. app.py +95 -55
app.py CHANGED
@@ -7,7 +7,7 @@ import requests
 from packaging import version
 from PIL import Image, ImageDraw
 import functools
-
+from langchain.llms.openai import OpenAI
 from caption_anything.model import CaptionAnything
 from caption_anything.utils.image_editing_utils import create_bubble_frame
 from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
@@ -68,17 +68,38 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
     return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
 
 
+def validate_api_key(api_key):
+    api_key = str(api_key).strip()
+    print(api_key)
+    try:
+        test_llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
+        response = test_llm("Test API call")
+        print(response)
+        return True
+    except Exception as e:
+        print(f"API key validation failed: {e}")
+        return False
+
+
 def init_openai_api_key(api_key=""):
     text_refiner = None
     visual_chatgpt = None
     if api_key and len(api_key) > 30:
-        try:
-            text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
-            assert len(text_refiner.llm('hi')) > 0  # test
-            visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
-        except:
-            text_refiner = None
-            visual_chatgpt = None
+        print(api_key)
+        if validate_api_key(api_key):
+            try:
+                text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
+                assert len(text_refiner.llm('hi')) > 0  # test
+                visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
+            except Exception as e:
+                print(f"Error initializing TextRefiner or ConversationBot: {e}")
+                text_refiner = None
+                visual_chatgpt = None
+        else:
+            print("Invalid API key.")
+    else:
+        print("API key is too short.")
+    print(text_refiner)
     openai_available = text_refiner is not None
     if openai_available:
         return [gr.update(visible=True)]*6 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
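For readers skimming the hunk above: the new gate boils down to one cheap completion before any heavier setup is attempted. A minimal sketch of the same pattern (the helper name `check_openai_key` is illustrative, not from app.py; the langchain `OpenAI` call mirrors the one added above, and with `model_name="gpt-3.5-turbo"` older langchain versions route it through their chat wrapper):

```python
from langchain.llms.openai import OpenAI

def check_openai_key(api_key: str) -> bool:
    # Illustrative re-statement of validate_api_key: one throwaway completion
    # decides whether the OpenAI-backed features get initialized at all.
    try:
        llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0,
                     openai_api_key=str(api_key).strip())
        return len(llm("ping")) > 0   # any non-empty reply counts as a valid key
    except Exception as err:          # bad key, quota, network error, ...
        print(f"OpenAI key check failed: {err}")
        return False
```

Compared with the old bare `except:` in `init_openai_api_key`, the failure reason is now printed, so a mistyped key no longer fails silently.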
@@ -175,7 +196,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
 
 def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                     length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
-                    evt: gr.SelectData):
+                    out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
     click_index = evt.index
 
     if point_prompt == 'Positive':
@@ -212,11 +233,13 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     text = out['generated_captions']['raw_caption']
     input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
-    origin_image_input = image_input
-    # image_input = create_bubble_frame(image_input, None, (click_index[0], click_index[1]), input_mask,
-    #                                   input_points=input_points, input_labels=input_labels)
-    x, y = input_points[-1]
-
+
+    click_index_state = click_index
+    input_mask_state = input_mask
+    input_points_state = input_points
+    input_labels_state = input_labels
+    out_state = out
+
     if visual_chatgpt is not None:
         print('inference_click: add caption to chatGPT memory')
         new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
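One subtlety in the assignments above: rebinding `click_index_state` and friends inside `inference_click` does not by itself update the corresponding `gr.State` objects. The new values persist only because they are also yielded and mapped to those states in the handler's `outputs` list, wired further down in this diff. A minimal, self-contained sketch of that round trip (names here are illustrative, not from app.py):

```python
import gradio as gr

def remember(value, stored_state):
    # Reassigning the parameter alone changes nothing; returning it in the slot
    # mapped to the gr.State below is what actually persists the new value.
    stored_state = value
    return f"stored: {value}", stored_state

def recall(stored_state):
    return f"last stored value: {stored_state}"

with gr.Blocks() as demo:
    stored_state = gr.State("nothing yet")   # per-session storage
    value = gr.Textbox(label="value")
    message = gr.Textbox(label="message")
    store_btn = gr.Button("store")
    recall_btn = gr.Button("recall")
    store_btn.click(remember, inputs=[value, stored_state],
                    outputs=[message, stored_state])
    recall_btn.click(recall, inputs=[stored_state], outputs=[message])

# demo.launch()
```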
@@ -224,50 +247,59 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
             point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
             visual_chatgpt.point_prompt = point_prompt
 
-
     generated_caption = text
     print(generated_caption)
 
-    yield state, state, click_state, image_input, generated_caption
+    yield state, state, click_state, image_input, generated_caption, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
 
-    if not args.disable_gpt and model.text_refiner:
-        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
-                                                       enable_wiki=enable_wiki)
-        new_cap = refined_caption['caption']
-        if refined_caption['wiki']:
-            state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
-        state = state + [(None, f"caption: {new_cap}")]
-        # refined_image_input = create_bubble_frame(origin_image_input, None, (click_index[0], click_index[1]),
-        #                                           input_mask,
-        #                                           input_points=input_points, input_labels=input_labels)
-        yield state, state, click_state, image_input, new_cap
-
-def submit_caption(image_input, state,generated_caption):
-    print(state)
-    if state and isinstance(state[-1][1], dict):
-        params = state[-1][1]
-    else:
-        params = {}
 
-    click_index = params.get("click_index", (0, 0))
-    input_mask = params.get("input_mask", np.zeros((1, 1)))
-    input_points = params.get("input_points", [])
-    input_labels = params.get("input_labels", [])
 
-    click_index = params.get("click_index", (0, 0))
-    input_mask = params.get("input_mask", np.zeros((1, 1)))
-    input_points = params.get("input_points", [])
-    input_labels = params.get("input_labels", [])
 
-    image_input = create_bubble_frame(np.array(image_input), generated_caption, (click_index[0], click_index[1]), input_mask,
+def submit_caption(image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
+                   out_state, click_index_state, input_mask_state, input_points_state, input_labels_state):
+    print("state", state)
+
+    click_index = click_index_state
+    input_mask = input_mask_state
+    input_points = input_points_state
+    input_labels = input_labels_state
+    out = out_state
+    print("click", click_index)
+
+    origin_image_input = image_input
+
+    controls = {
+        'length': length,
+        'sentiment': sentiment,
+        'factuality': factuality,
+        'language': language
+    }
+
+    image_input = create_bubble_frame(np.array(image_input), generated_caption, click_index, input_mask,
                                       input_points=input_points, input_labels=input_labels)
 
-
     if generated_caption:
         state = state + [(None, f"RAW_Caption: {generated_caption}")]
-        txt2speech(generated_caption)
+
+
+    if not args.disable_gpt and text_refiner:
+        refined_caption = text_refiner.inference(query=generated_caption, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
+        new_cap = refined_caption['caption']
+        if refined_caption.get('wiki'):
+            state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
+        state = state + [(None, f"RAW_Caption: {new_cap}")]
+        print("new_cap", new_cap)
+        refined_image_input = create_bubble_frame(np.array(origin_image_input), new_cap, click_index, input_mask,
+                                                  input_points=input_points, input_labels=input_labels)
+        txt2speech(new_cap)
+        yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
+
+    else:
+        txt2speech(generated_caption)
+        yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
+
+
 
-    yield state,state,image_input
 
 
 def txt2speech(text):
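Condensed, the new `submit_caption` does the following with the values it pulls out of those states (a sketch, not the literal function: the `args.disable_gpt` check and the bubble-frame drawing are omitted; the `inference` call shape — `query`, `controls`, `context`, `enable_wiki`, returning a dict with `caption` and optionally `wiki` — is taken from the hunk above):

```python
def refine_and_speak(raw_caption, text_refiner, out, controls, enable_wiki, state):
    """Illustrative condensation of submit_caption's two branches."""
    if text_refiner is None:
        # No usable OpenAI key: keep the raw caption and read it out.
        txt2speech(raw_caption)
        return state, raw_caption

    refined = text_refiner.inference(query=raw_caption, controls=controls,
                                     context=out['context_captions'],
                                     enable_wiki=enable_wiki)
    if refined.get('wiki'):
        state = state + [(None, "Wiki: {}".format(refined['wiki']))]
    state = state + [(None, f"RAW_Caption: {refined['caption']}")]
    txt2speech(refined['caption'])   # the refined caption is what gets spoken
    return state, refined['caption']
```

The practical effect of the refactor: caption refinement and text-to-speech now run when the user presses the submit button, instead of inside the click handler as in the removed code.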
@@ -427,8 +459,8 @@ def create_ui():
         css=get_style()
     ) as iface:
         state = gr.State([])
+        out_state = gr.State(None)
         click_state = gr.State([[], [], []])
-        # chat_state = gr.State([])
         origin_image = gr.State(None)
         image_embedding = gr.State(None)
         text_refiner = gr.State(None)
@@ -436,8 +468,11 @@ def create_ui():
         original_size = gr.State(None)
         input_size = gr.State(None)
         generated_caption = gr.State("")
-        # img_caption = gr.State(None)
         aux_state = gr.State([])
+        click_index_state = gr.State((0, 0))
+        input_mask_state = gr.State(np.zeros((1, 1)))
+        input_points_state = gr.State([])
+        input_labels_state = gr.State([])
 
         gr.Markdown(title)
         gr.Markdown(description)
@@ -619,18 +654,22 @@ def create_ui():
            inference_click,
            inputs=[
                origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
-               image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
+               image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
+               out_state, click_index_state, input_mask_state, input_points_state, input_labels_state
            ],
-           outputs=[chatbot, state, click_state, image_input, generated_caption],
+           outputs=[chatbot, state, click_state, image_input, generated_caption, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state],
            show_progress=False, queue=True
        )
 
+
        submit_button_click.click(
-           submit_caption,
-           inputs=[image_input, state, generated_caption],
-           outputs=[chatbot,state,image_input],
-           show_progress=True, queue=True
-       )
+           submit_caption,
+           inputs=[image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
+                   out_state, click_index_state, input_mask_state, input_points_state, input_labels_state],
+           outputs=[chatbot, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state],
+           show_progress=True, queue=True
+       )
+
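The `outputs` lists above are positional: the n-th value yielded by the handler lands in the n-th output component, so `inference_click`'s ten-element yield matches its ten outputs and `submit_caption`'s eight-element yield matches its eight. A tiny generator-handler sketch of the same contract (illustrative names, not from app.py; generator handlers need the queue, hence `queue=True`):

```python
import gradio as gr

def count_up(steps, total_state):
    # Each yield must supply one value per entry in `outputs`, in order:
    # the Textbox first, then the State carrying the running total.
    for i in range(int(steps)):
        total_state = total_state + 1
        yield f"step {i}, running total {total_state}", total_state

with gr.Blocks() as demo:
    total_state = gr.State(0)
    steps = gr.Number(value=3, label="steps")
    log = gr.Textbox(label="log")
    run_btn = gr.Button("run")
    run_btn.click(count_up, inputs=[steps, total_state],
                  outputs=[log, total_state], queue=True, show_progress=False)

# demo.launch()
```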
 
@@ -651,3 +690,4 @@ if __name__ == '__main__':
     iface = create_ui()
     iface.queue(concurrency_count=5, api_open=False, max_size=10)
     iface.launch(server_name="0.0.0.0", enable_queue=True)
+