Spaces: Running
Niki Zhang committed: fix gpt usage
app.py CHANGED
@@ -7,7 +7,7 @@ import requests
 from packaging import version
 from PIL import Image, ImageDraw
 import functools
-
+from langchain.llms.openai import OpenAI
 from caption_anything.model import CaptionAnything
 from caption_anything.utils.image_editing_utils import create_bubble_frame
 from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
@@ -68,17 +68,38 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
     return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
 
 
+def validate_api_key(api_key):
+    api_key = str(api_key).strip()
+    print(api_key)
+    try:
+        test_llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
+        response = test_llm("Test API call")
+        print(response)
+        return True
+    except Exception as e:
+        print(f"API key validation failed: {e}")
+        return False
+
+
 def init_openai_api_key(api_key=""):
     text_refiner = None
     visual_chatgpt = None
     if api_key and len(api_key) > 30:
-
-
-
-
-
-
-
+        print(api_key)
+        if validate_api_key(api_key):
+            try:
+                text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
+                assert len(text_refiner.llm('hi')) > 0  # test
+                visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
+            except Exception as e:
+                print(f"Error initializing TextRefiner or ConversationBot: {e}")
+                text_refiner = None
+                visual_chatgpt = None
+        else:
+            print("Invalid API key.")
+    else:
+        print("API key is too short.")
+    print(text_refiner)
     openai_available = text_refiner is not None
     if openai_available:
         return [gr.update(visible=True)]*6 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
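Note: the new validate_api_key spends one real gpt-3.5-turbo completion per check. If that cost ever matters, a lighter-weight alternative (a hypothetical sketch, not part of this commit; the helper name validate_api_key_cheap is made up) could simply hit the OpenAI /v1/models endpoint with the requests module that app.py already imports:

import requests

def validate_api_key_cheap(api_key: str) -> bool:
    # Hypothetical alternative to the completion-based check above:
    # a key that can list models is a key that authenticates.
    resp = requests.get(
        "https://api.openai.com/v1/models",
        headers={"Authorization": f"Bearer {str(api_key).strip()}"},
        timeout=10,
    )
    return resp.status_code == 200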
@@ -175,7 +196,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
 
 def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                     length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
-                    evt: gr.SelectData):
+                    out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
     click_index = evt.index
 
     if point_prompt == 'Positive':
@@ -212,11 +233,13 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     text = out['generated_captions']['raw_caption']
     input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
-
-
-
-
-
+
+    click_index_state = click_index
+    input_mask_state = input_mask
+    input_points_state = input_points
+    input_labels_state = input_labels
+    out_state = out
+
     if visual_chatgpt is not None:
         print('inference_click: add caption to chatGPT memory')
         new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
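Note: the assignments above only persist because those values are yielded back out of the handler and routed into gr.State outputs (see the event wiring further down); rebinding a parameter inside the function does not, by itself, update the server-side State. A minimal standalone sketch of that store-then-reuse pattern, assuming nothing beyond gradio itself (names like first_step/stash are illustrative, not from app.py):

import gradio as gr

def first_step(text, stash_state):
    stash_state = text.upper()               # value produced in step one
    return f"stored: {stash_state}", stash_state  # must be returned into the State output

def second_step(stash_state):
    return f"reused: {stash_state}"          # step two reads it back, no globals needed

with gr.Blocks() as demo:
    stash = gr.State("")                     # per-session storage, like out_state etc.
    inp = gr.Textbox(label="input")
    out1 = gr.Textbox(label="after step 1")
    out2 = gr.Textbox(label="after step 2")
    gr.Button("step 1").click(first_step, inputs=[inp, stash], outputs=[out1, stash])
    gr.Button("step 2").click(second_step, inputs=[stash], outputs=[out2])

if __name__ == "__main__":
    demo.launch()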
@@ -224,50 +247,59 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
         point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
         visual_chatgpt.point_prompt = point_prompt
 
-
     generated_caption = text
     print(generated_caption)
 
-    yield state, state, click_state, image_input, generated_caption
-
-    if not args.disable_gpt and model.text_refiner:
-        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
-                                                       enable_wiki=enable_wiki)
-        new_cap = refined_caption['caption']
-        if refined_caption['wiki']:
-            state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
-        state = state + [(None, f"caption: {new_cap}")]
-        # refined_image_input = create_bubble_frame(origin_image_input, None, (click_index[0], click_index[1]),
-        #                                           input_mask,
-        #                                           input_points=input_points, input_labels=input_labels)
-        yield state, state, click_state, image_input, new_cap
-
-def submit_caption(image_input, state,generated_caption):
-    print(state)
-    if state and isinstance(state[-1][1], dict):
-        params = state[-1][1]
-    else:
-        params = {}
-
-    click_index = params.get("click_index", (0, 0))
-    input_mask = params.get("input_mask", np.zeros((1, 1)))
-    input_points = params.get("input_points", [])
-    input_labels = params.get("input_labels", [])
-
-    click_index = params.get("click_index", (0, 0))
-    input_mask = params.get("input_mask", np.zeros((1, 1)))
-    input_points = params.get("input_points", [])
-    input_labels = params.get("input_labels", [])
-
-
+    yield state, state, click_state, image_input, generated_caption, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
+
+
+
+
+def submit_caption(image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
+                   out_state, click_index_state, input_mask_state, input_points_state, input_labels_state):
+    print("state", state)
+
+    click_index = click_index_state
+    input_mask = input_mask_state
+    input_points = input_points_state
+    input_labels = input_labels_state
+    out = out_state
+    print("click", click_index)
+
+    origin_image_input = image_input
+
+    controls = {
+        'length': length,
+        'sentiment': sentiment,
+        'factuality': factuality,
+        'language': language
+    }
+
+    image_input = create_bubble_frame(np.array(image_input), generated_caption, click_index, input_mask,
                                       input_points=input_points, input_labels=input_labels)
 
-
     if generated_caption:
         state = state + [(None, f"RAW_Caption: {generated_caption}")]
-
-
-    yield state,state,image_input
+
+    if not args.disable_gpt and text_refiner:
+        refined_caption = text_refiner.inference(query=generated_caption, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
+        new_cap = refined_caption['caption']
+        if refined_caption.get('wiki'):
+            state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
+        state = state + [(None, f"RAW_Caption: {new_cap}")]
+        print("new_cap", new_cap)
+        refined_image_input = create_bubble_frame(np.array(origin_image_input), new_cap, click_index, input_mask,
+                                                  input_points=input_points, input_labels=input_labels)
+        txt2speech(new_cap)
+        yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
+
+    else:
+        txt2speech(generated_caption)
+        yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
 
 
 def txt2speech(text):
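Note: an observation about the new submit_caption, not something this commit changes. out_state is declared with a default of None (see the State declarations below), so if the Submit button is pressed before any click has run inference_click, the refiner branch would evaluate out['context_captions'] on None and raise a TypeError. A possible guard, sketched here as a hypothetical addition at the top of submit_caption:

# Hypothetical guard (not in this commit): bail out early if nothing has been clicked yet.
if out_state is None or not generated_caption:
    yield (state, state, image_input, click_index_state, input_mask_state,
           input_points_state, input_labels_state, out_state)
    return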
@@ -427,8 +459,8 @@ def create_ui():
         css=get_style()
     ) as iface:
         state = gr.State([])
+        out_state = gr.State(None)
         click_state = gr.State([[], [], []])
-        # chat_state = gr.State([])
         origin_image = gr.State(None)
         image_embedding = gr.State(None)
         text_refiner = gr.State(None)
@@ -436,8 +468,11 @@ def create_ui():
         original_size = gr.State(None)
         input_size = gr.State(None)
         generated_caption = gr.State("")
-        # img_caption = gr.State(None)
         aux_state = gr.State([])
+        click_index_state = gr.State((0, 0))
+        input_mask_state = gr.State(np.zeros((1, 1)))
+        input_points_state = gr.State([])
+        input_labels_state = gr.State([])
 
         gr.Markdown(title)
         gr.Markdown(description)
@@ -619,18 +654,22 @@ def create_ui():
             inference_click,
             inputs=[
                 origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
-                image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
+                image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
+                out_state, click_index_state, input_mask_state, input_points_state, input_labels_state
             ],
-            outputs=[chatbot, state, click_state, image_input, generated_caption],
+            outputs=[chatbot, state, click_state, image_input, generated_caption, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state],
             show_progress=False, queue=True
         )
 
+
         submit_button_click.click(
-
-
-
-
-
+            submit_caption,
+            inputs=[image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
+                    out_state, click_index_state, input_mask_state, input_points_state, input_labels_state],
+            outputs=[chatbot, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state],
+            show_progress=True, queue=True
+        )
+
 
 
 
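Note: both handlers wired above are generators (they yield a raw caption first and the refined one later), which is why the events keep queue=True; without queuing Gradio cannot stream the intermediate update. A minimal standalone sketch of that streaming pattern, assuming only a Gradio version close to the one this Space runs (stream_caption and the 1-second sleep are illustrative stand-ins, not from app.py):

import time
import gradio as gr

def stream_caption(text):
    yield f"RAW: {text}"                 # first yield: show the raw caption immediately
    time.sleep(1)                        # stand-in for the slow text_refiner call
    yield f"REFINED: {text.capitalize()}"  # second yield: replace it with the refined result

with gr.Blocks() as demo:
    box = gr.Textbox(label="caption")
    out = gr.Textbox(label="result")
    gr.Button("submit").click(stream_caption, inputs=[box], outputs=[out], queue=True)

demo.queue()
demo.launch()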
@@ -651,3 +690,4 @@ if __name__ == '__main__':
     iface = create_ui()
     iface.queue(concurrency_count=5, api_open=False, max_size=10)
     iface.launch(server_name="0.0.0.0", enable_queue=True)
+