# apiprompting/app.py
import os
import gradio as gr
import torch
import base64
import requests
from io import BytesIO
from API_LLaVA.functions import get_model as llava_get_model, get_preanswer as llava_get_preanswer, from_preanswer_to_mask as llava_from_preanswer_to_mask
from API_LLaVA.hook import hook_logger as llava_hook_logger
from API_LLaVA.main import blend_mask as llava_blend_mask
from API_CLIP.main import get_model as clip_get_model, gen_mask as clip_gen_mask, blend_mask as clip_blend_mask
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MARKDOWN = """
<div align='center'>
<b style="font-size: 2em;">API: Attention Prompting on Image for Large Vision-Language Models</b>
<br>
<br>
<br>
[<a href="https://arxiv.org/abs/"> arXiv paper </a>]
[<a href="https://yu-rp.github.io/api-prompting/"> project page </a>]
[<a href="https://pypi.org/project/apiprompting/"> python package </a>]
[<a href="https://github.com/yu-rp/apiprompting"> code </a>]
</div>
"""
def init_clip_examples():
return gr.Examples(
[
[
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/clip_example_1_ori.jpg",
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/clip_example_1_masked.webp",
"What fruit is in the left part of the fridge?",
"-",
"fruit",
30,
3,
"BICUBIC",
0,
"On the left side of the fridge, there are strawberries and blueberries visible in containers.",
"In the left part of the fridge, there are strawberries in a clear plastic container."
],
[
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/example_2_ori.jpg",
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/clip_example_2_masked.webp",
"On the right desk, what is to the left of the laptop?",
"-",
"desk",
5,
3,
"BICUBIC",
0,
"On the right desk, to the left of the laptop, there is a small pile of books stacked next to the laptop.",
"To the left of the laptop on the right desk, there is a desk lamp."
],
],
[
image_input,
image_output,
text_query,
text_pre_answer,
text_highlight_text,
slider_enhance_coe,
slider_kernel_size,
radio_interpolate_method_name,
slider_mask_grayscale,
text_original_image_response,
text_API_image_response
],
label = "Examples for CLIP_Based API"
)
def init_llava_examples():
return gr.Examples(
[
[
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/llava_example_1_ori.jpg",
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/llava_example_1_masked.webp",
"Which direction is the red bird facing?",
" The red bird is facing away from the camera.</s>",
"red bird",
20,
3,
"BICUBIC",
100,
"The red bird in the image is facing to the right.",
"The red bird is facing towards the left in the image."
],
[
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/example_2_ori.jpg",
"https://raw.githubusercontent.com/yu-rp/asserts/main/API_Prompting/llava_example_2_masked.webp",
"On the right desk, what is to the left of the laptop?",
" On the right desk, there is a lamp to the left of the laptop.</s>",
" On the right desk, there is a lamp to the left of the laptop.</s>",
30,
3,
"BICUBIC",
0,
"On the right desk, to the left of the laptop, there is a small pile of books stacked next to the laptop.",
"To the left of the laptop on the right desk, there is a green desk lamp."
],
],
[
image_input,
image_output,
text_query,
text_pre_answer,
text_highlight_text,
slider_enhance_coe,
slider_kernel_size,
radio_interpolate_method_name,
slider_mask_grayscale,
text_original_image_response,
text_API_image_response
],
label = "Examples for LLaVA_Based API"
)
def get_base64_images(image):
    # Encode a PIL image as a base64 JPEG string for embedding in a data URL.
image = image.convert('RGB')
buffer = BytesIO()
image.save(buffer, format='JPEG')
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
return image_base64
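# A minimal usage sketch (hypothetical file path): the returned string is meant to be
# embedded in a data URL, exactly as the payload below does.
#   from PIL import Image
#   b64 = get_base64_images(Image.open("photo.jpg"))
#   data_url = f"data:image/jpeg;base64,{b64}"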
def vqa(image, question, api_key):
    # One-shot VQA call to the OpenAI chat completions endpoint (gpt-4-turbo-2024-04-09).
base64_image = get_base64_images(image)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
payload = {
"model": "gpt-4-turbo-2024-04-09",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": question
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail":"low"
}
}
]
}
],
"max_tokens": 300
}
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    response.raise_for_status()  # surface HTTP errors (e.g. an invalid key) instead of a KeyError below
    return response.json()["choices"][0]["message"]["content"]
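# A usage sketch, assuming a valid key is set in the environment (the image path is a stand-in):
#   answer = vqa(Image.open("photo.jpg"), "What is on the table?", os.environ["OPENAI_API_KEY"])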
def compare(input_image, output_image, query, api_key):
    # Ask GPT the same query about the original image and the API-masked image.
original_response = vqa(input_image, query, api_key)
api_response = vqa(output_image, query, api_key)
return original_response, api_response
def init_clip():
    # Load the CLIP model and the helpers used by the CLIP-based masking pipeline.
clip_model, clip_prs, clip_preprocess, _, clip_tokenizer = clip_get_model(
model_name = "ViT-L-14-336" if torch.cuda.is_available() else "ViT-L-14",
layer_index = 22, device= DEVICE)
return {"clip_model": clip_model, "clip_prs": clip_prs, "clip_preprocess": clip_preprocess, "clip_tokenizer": clip_tokenizer}
def init_llava():
    # Load LLaVA-v1.5-13B and attach the hook logger whose recorded state drives mask generation.
llava_tokenizer, llava_model, llava_image_processor, llava_context_len, llava_model_name = llava_get_model("llava-v1.5-13b", device= DEVICE)
llava_hl = llava_hook_logger(llava_model, DEVICE, layer_index = 20)
return {"llava_tokenizer": llava_tokenizer, "llava_model": llava_model, "llava_image_processor": llava_image_processor, "llava_context_len": llava_context_len, "llava_model_name": llava_model_name, "llava_hl": llava_hl}
def change_api_method(api_method):
    # Swap the loaded models and reset dependent UI components when the API type changes.
new_text_pre_answer = gr.Textbox(
label="LLaVA Response",
        info = 'Only used for the LLaVA-Based API. Press "Pre-Answer" to generate the response.',
placeholder="",
value = "",
lines=4,
interactive=False,
type="text")
new_image_output = gr.Image(
label="API Masked Image",
type="pil",
interactive=False,
height=512
)
if api_method == "CLIP_Based API":
model_dict = init_clip()
new_generate_llava_response_button = gr.Button("Pre-Answer", interactive=False)
# new_examples = init_clip_examples()
elif api_method == "LLaVA_Based API":
model_dict = init_llava()
new_generate_llava_response_button = gr.Button("Pre-Answer", interactive=True)
# new_examples = init_llava_examples()
else:
raise NotImplementedError
return model_dict, {}, new_generate_llava_response_button, new_text_pre_answer, new_image_output#, new_examples
def clear_cache(cache_dict):
    return {}
def clear_mask_cache(cache_dict):
    cache_dict.pop("llava_mask", None)
    cache_dict.pop("clip_mask", None)
    return cache_dict
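# Cache lifecycle sketch (hypothetical contents): changing the image or query clears
# everything via clear_cache, while editing only the hint text drops the stale masks
# but keeps whatever the pre-answer pass stored:
#   cache = {"llava_mask": mask, "other_state": state}
#   cache = clear_mask_cache(cache)  # -> {"other_state": state}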
def llava_pre_answer(image, query, cache_dict, model_dict):
    # Run LLaVA once to get a "pre-answer"; state recorded by the hook during this pass is cached for masking.
pre_answer, cache_dict_update = llava_get_preanswer(
model_dict["llava_model"],
model_dict["llava_model_name"],
model_dict["llava_hl"],
model_dict["llava_tokenizer"],
model_dict["llava_image_processor"],
model_dict["llava_context_len"],
query, image)
cache_dict.update(cache_dict_update)
return pre_answer, cache_dict
def generate_mask(
image,
query,
pre_answer,
highlight_text,
api_method,
enhance_coe,
kernel_size,
interpolate_method_name,
mask_grayscale,
cache_dict,
        model_dict):
    # Compute (or reuse a cached) mask from the hint text and blend it onto the input image.
    if api_method == "LLaVA_Based API":
        if highlight_text.strip() not in pre_answer:
            raise gr.Error("For the LLaVA-Based API, the hint text must be a substring of the pre-answer.")
        if cache_dict.get("llava_mask") is None:
            cache_dict["llava_mask"] = llava_from_preanswer_to_mask(highlight_text, pre_answer, cache_dict)
masked_image = llava_blend_mask(image, cache_dict["llava_mask"], enhance_coe, kernel_size, interpolate_method_name, mask_grayscale)
elif api_method == "CLIP_Based API":
# assert highlight_text in query
if "clip_mask" in cache_dict.keys() and cache_dict["clip_mask"] is not None:
pass
else:
cache_dict["clip_mask"] = clip_gen_mask(
model_dict["clip_model"],
model_dict["clip_prs"],
model_dict["clip_preprocess"],
DEVICE,
model_dict["clip_tokenizer"],
[image],
[highlight_text if highlight_text.strip() != "" else query])
masked_image = clip_blend_mask(image, *cache_dict["clip_mask"], enhance_coe, kernel_size, interpolate_method_name, mask_grayscale)
else:
raise NotImplementedError
return masked_image, cache_dict
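# A headless usage sketch for the CLIP-based pipeline, mirroring the first CLIP example
# above ("photo.jpg" is a stand-in path; these values are otherwise supplied by the sliders):
#   models = init_clip()
#   masked, cache = generate_mask(
#       Image.open("photo.jpg"), "What fruit is in the left part of the fridge?",
#       "-", "fruit", "CLIP_Based API",
#       enhance_coe=30, kernel_size=3, interpolate_method_name="BICUBIC",
#       mask_grayscale=0, cache_dict={}, model_dict=models)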
image_input = gr.Image(
label="Input Image",
type="pil",
interactive=True,
height=512
)
image_output = gr.Image(
label="API Masked Image",
type="pil",
interactive=False,
height=512
)
text_query = gr.Textbox(
label="Query",
placeholder="Enter a query about the image",
lines=2,
type="text")
text_pre_answer = gr.Textbox(
label="LLaVA Response",
    info = 'Only used for the LLaVA-Based API. Press "Pre-Answer" to generate the response.',
placeholder="",
lines=2,
interactive=False,
type="text")
text_highlight_text = gr.Textbox(
label = "Hint Text",
info = "The text based on which the mask will be generated. For LLaVA-Based API, it should be a substring of the pre-answer.",
placeholder="Enter the hint text",
lines=1,
type="text")
text_api_token = gr.Textbox(
label = "OpenAI API Token",
placeholder="Input your OpenAI API token",
lines=1,
type="text")
text_original_image_response = gr.Textbox(
label="GPT Response (Original Image)",
placeholder="",
lines=2,
interactive=False,
type="text")
text_API_image_response = gr.Textbox(
label="GPT Response (API-maksed Image)",
placeholder="",
lines=2,
interactive=False,
type="text")
radio_api_method = gr.Radio(
["CLIP_Based API", "LLaVA_Based API"] if torch.cuda.is_available() else ["CLIP_Based API"],
interactive=True,
value = "CLIP_Based API",
label="Type of API")
slider_mask_grayscale = gr.Slider(
minimum=0,
maximum=255,
step = 0.5,
value=100,
interactive=True,
info = "0: black mask, 255: white mask.",
label="Grayscale")
slider_enhance_coe = gr.Slider(
minimum=1,
maximum=50,
step = 1,
value=1,
interactive=True,
info = "The larger contrast, the greater the contrast between the bright and dark areas of the mask.",
label="Contrast")
slider_kernel_size = gr.Slider(
minimum=1,
maximum=9,
step = 2,
value=1,
interactive=True,
info = "The larger smoothness, the smoother the mask appears, reducing the rectangular shapes.",
label="Smoothness")
radio_interpolate_method_name = gr.Radio(
["BICUBIC", "BILINEAR","BOX","LANCZOS", "NEAREST"],
value = "BICUBIC",
interactive=True,
label="Interpolation Method",
info="The interpolation method used during mask resizing.")
generate_llava_response_button = gr.Button("Pre-Answer", interactive=False)
generate_mask_button = gr.Button("API Go!")
ask_gpt_button = gr.Button("GPT Go!")
with gr.Blocks() as demo:
gr.Markdown(MARKDOWN)
state_cache = gr.State({})
state_model = gr.State(init_clip())
with gr.Row():
image_input.render()
image_output.render()
with gr.Accordion("Query and API Processing"):
with gr.Row():
radio_api_method.render()
with gr.Row(equal_height=True):
with gr.Column():
text_query.render()
generate_llava_response_button.render()
text_pre_answer.render()
text_highlight_text.render()
with gr.Column():
slider_enhance_coe.render()
slider_kernel_size.render()
radio_interpolate_method_name.render()
slider_mask_grayscale.render()
with gr.Row():
generate_mask_button.render()
with gr.Accordion("GPT Response"):
text_api_token.render()
ask_gpt_button.render()
with gr.Row():
text_original_image_response.render()
text_API_image_response.render()
with gr.Accordion("Examples"):
clip_examples = init_clip_examples()
llava_examples = init_llava_examples()
radio_api_method.change(
fn=change_api_method,
inputs = [radio_api_method],
outputs=[state_model, state_cache, generate_llava_response_button, text_pre_answer, image_output]
)
image_input.change(
fn=clear_cache,
inputs = state_cache,
outputs=state_cache
)
text_query.change(
fn=clear_cache,
inputs = state_cache,
outputs=state_cache
)
text_highlight_text.change(
fn=clear_mask_cache,
inputs = state_cache,
outputs=state_cache
)
generate_llava_response_button.click(
fn=llava_pre_answer,
inputs=[image_input, text_query, state_cache, state_model],
outputs=[text_pre_answer, state_cache]
)
generate_mask_button.click(
fn=generate_mask,
inputs=[
image_input,
text_query,
text_pre_answer,
text_highlight_text,
radio_api_method,
slider_enhance_coe,
slider_kernel_size,
radio_interpolate_method_name,
slider_mask_grayscale,
state_cache,
state_model
],
outputs=[image_output, state_cache]
)
ask_gpt_button.click(
fn=compare,
inputs=[image_input, image_output, text_query, text_api_token],
outputs=[text_original_image_response, text_API_image_response]
)
demo.queue(max_size = 1).launch(show_error=True)
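# To run locally (assumes gradio, torch, and the bundled API_CLIP / API_LLaVA packages):
#   python app.py
# Note: the LLaVA_Based API option only appears when CUDA is available (see radio_api_method).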