Spaces:
Running
Running
import os | |
import gradio as gr | |
import torch | |
from API_LLaVA.functions import get_model as llava_get_model, get_preanswer as llava_get_preanswer, from_preanswer_to_mask as llava_from_preanswer_to_mask | |
from API_LLaVA.hook import hook_logger as llava_hook_logger | |
from API_LLaVA.main import blend_mask as llava_blend_mask | |
from API_CLIP.main import get_model as clip_get_model, gen_mask as clip_gen_mask, blend_mask as clip_blend_mask | |
# Run on the first CUDA GPU when one is present, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Header rendered at the top of the demo page (gr.Markdown accepts inline HTML).
# NOTE(review): the arXiv link has no paper ID, and the "python package" link
# points at roboflow/api, which looks like a template leftover — confirm the
# intended URLs.
MARKDOWN = """
<div align='center'>
<h1>
API: Attention Prompting on Image for Large Vision-Language Models
</h1>
<br>
[<a href="https://arxiv.org/abs/"> arXiv paper </a>]
[<a href="https://api.github.io"> project page </a>]
[<a href="https://github.com/roboflow/api"> python package </a>]
[<a href="https://github.com/yu-rp/apiprompting"> code </a>]
</div>
"""
def init_clip():
    """Load the CLIP backbone used by the CLIP-based API.

    Calls ``clip_get_model`` for "ViT-L-14-336" with ``layer_index=22`` on
    ``DEVICE``.

    Returns:
        dict with keys "clip_model", "clip_prs", "clip_preprocess" and
        "clip_tokenizer". The fourth value returned by ``clip_get_model`` is
        not kept.
    """
    model, prs, preprocess, _unused, tokenizer = clip_get_model(
        model_name="ViT-L-14-336",
        layer_index=22,
        device=DEVICE,
    )
    return dict(
        clip_model=model,
        clip_prs=prs,
        clip_preprocess=preprocess,
        clip_tokenizer=tokenizer,
    )
def init_llava():
    """Load LLaVA-v1.5-13B plus its hook logger for the LLaVA-based API.

    Calls ``llava_get_model("llava-v1.5-13b")`` on ``DEVICE``. It then builds a
    ``llava_hook_logger`` over the loaded model with ``layer_index=20``.

    Returns:
        dict with keys "llava_tokenizer", "llava_model",
        "llava_image_processor", "llava_context_len", "llava_model_name" and
        "llava_hl" (the hook logger).
    """
    tokenizer, model, image_processor, context_len, model_name = llava_get_model(
        "llava-v1.5-13b", device=DEVICE
    )
    hook_logger = llava_hook_logger(model, DEVICE, layer_index=20)
    return dict(
        llava_tokenizer=tokenizer,
        llava_model=model,
        llava_image_processor=image_processor,
        llava_context_len=context_len,
        llava_model_name=model_name,
        llava_hl=hook_logger,
    )
def change_api_method(api_method):
    """Load the models for the selected API and reset the dependent UI.

    Args:
        api_method: the selected radio choice, "CLIP_Based API" or
            "LLaVA_Based API".

    Returns:
        A 5-tuple ``(model_dict, cache, pre_answer_button, pre_answer_textbox,
        output_image)``:

        - the freshly loaded model dict;
        - an empty cache;
        - a "Pre-Answer" button that is interactive only for the LLaVA-based
          API;
        - an emptied "LLaVA Response" textbox;
        - a fresh output image component.

    Raises:
        NotImplementedError: if ``api_method`` is not a supported choice.
    """
    # Validate and load first, so an unsupported choice fails before any UI
    # update objects are built.
    if api_method == "CLIP_Based API":
        model_dict = init_clip()
    elif api_method == "LLaVA_Based API":
        model_dict = init_llava()
    else:
        raise NotImplementedError(f"Unsupported API method: {api_method!r}")

    # The pre-answer step only exists for the LLaVA-based API.
    new_generate_llava_response_button = gr.Button(
        "Pre-Answer", interactive=(api_method == "LLaVA_Based API")
    )
    # Fresh components clear any response/output left from the previous API.
    new_text_pre_answer = gr.Textbox(
        label="LLaVA Response",
        info='Only used for LLaVA-Based API. Press "Pre-Answer" to generate the response.',
        placeholder="",
        value="",
        lines=4,
        interactive=False,
        type="text",
    )
    new_image_output = gr.Image(
        label="API Masked Image",
        type="pil",
        interactive=False,
        height=512,
    )
    # Switching backends invalidates everything cached so far, hence the {}.
    return (
        model_dict,
        {},
        new_generate_llava_response_button,
        new_text_pre_answer,
        new_image_output,
    )
def clear_cache(cache_dict):
    """Return a brand-new, empty cache.

    The incoming ``cache_dict`` is ignored and left untouched. The caller
    replaces its session cache with the returned object.
    """
    fresh_cache = dict()
    return fresh_cache
def clear_mask_cache(cache_dict):
    """Forget cached masks while keeping every other cached entry.

    Removes the "llava_mask" and "clip_mask" keys from ``cache_dict`` in place,
    when present. Returns that same dict object.
    """
    for mask_key in ("llava_mask", "clip_mask"):
        cache_dict.pop(mask_key, None)
    return cache_dict
def llava_pre_answer(image, query, cache_dict, model_dict):
    """Run LLaVA once to produce the "pre-answer" for the LLaVA-based API.

    Args:
        image: the input image (from ``image_input``).
        query: the user's question about the image.
        cache_dict: per-session cache, updated in place with the second value
            returned by ``llava_get_preanswer``. That value is presumably the
            hooked data later consumed by ``llava_from_preanswer_to_mask``
            (TODO confirm).
        model_dict: the dict built by ``init_llava``.

    Returns:
        ``(pre_answer, cache_dict)``: the response text and the updated cache.
    """
    pre_answer, cache_dict_update = llava_get_preanswer(
        model_dict["llava_model"],
        model_dict["llava_model_name"],
        model_dict["llava_hl"],
        model_dict["llava_tokenizer"],
        model_dict["llava_image_processor"],
        model_dict["llava_context_len"],
        query,
        image,
    )
    # A new pre-answer makes any previously cached LLaVA mask stale. Drop it so
    # generate_mask recomputes from the fresh data instead of reusing it.
    cache_dict.pop("llava_mask", None)
    cache_dict.update(cache_dict_update)
    return pre_answer, cache_dict
def generate_mask(
    image,
    query,
    pre_answer,
    highlight_text,
    api_method,
    enhance_coe,
    kernel_size,
    interpolate_method_name,
    mask_grayscale,
    cache_dict,
    model_dict,
):
    """Compute (or reuse) the API mask and blend it onto ``image``.

    Masks are cached in ``cache_dict`` under "llava_mask" or "clip_mask" and
    reused until another handler removes them.

    Args:
        image: the input image.
        query: the user's question. The CLIP-based API uses it as the mask
            text when ``highlight_text`` is blank.
        pre_answer: the LLaVA response text (LLaVA-based API only).
        highlight_text: the hint text the mask is generated from.
        api_method: "LLaVA_Based API" or "CLIP_Based API".
        enhance_coe, kernel_size, interpolate_method_name, mask_grayscale:
            blending options (the Contrast, Smoothness, Interpolation Method and
            Grayscale controls), forwarded unchanged to the blend function.
        cache_dict: per-session cache, updated in place.
        model_dict: the models loaded for the current API.

    Returns:
        ``(masked_image, cache_dict)``.

    Raises:
        gr.Error: if no image is given. For the LLaVA-based API, also if there
            is no pre-answer yet or the hint is not a substring of it.
        NotImplementedError: if ``api_method`` is not a supported choice.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    if api_method == "LLaVA_Based API":
        # Explicit errors instead of `assert`: asserts are stripped under
        # `python -O` and reach the user without a message.
        if not pre_answer:
            raise gr.Error('Press "Pre-Answer" to generate the LLaVA response first.')
        if highlight_text.strip() not in pre_answer:
            raise gr.Error("The hint text must be a substring of the LLaVA response.")
        if cache_dict.get("llava_mask") is None:
            cache_dict["llava_mask"] = llava_from_preanswer_to_mask(
                highlight_text, pre_answer, cache_dict
            )
        masked_image = llava_blend_mask(
            image,
            cache_dict["llava_mask"],
            enhance_coe,
            kernel_size,
            interpolate_method_name,
            mask_grayscale,
        )
    elif api_method == "CLIP_Based API":
        if cache_dict.get("clip_mask") is None:
            # A blank hint falls back to using the whole query as the CLIP text.
            clip_text = highlight_text if highlight_text.strip() != "" else query
            cache_dict["clip_mask"] = clip_gen_mask(
                model_dict["clip_model"],
                model_dict["clip_prs"],
                model_dict["clip_preprocess"],
                DEVICE,
                model_dict["clip_tokenizer"],
                [image],
                [clip_text],
            )
        # clip_gen_mask's result is a sequence whose items are splatted into
        # clip_blend_mask ahead of the blending options.
        masked_image = clip_blend_mask(
            image,
            *cache_dict["clip_mask"],
            enhance_coe,
            kernel_size,
            interpolate_method_name,
            mask_grayscale,
        )
    else:
        raise NotImplementedError(f"Unsupported API method: {api_method!r}")
    return masked_image, cache_dict
# --- UI components ----------------------------------------------------------
# Created up front and placed into the layout later with `.render()`.

# Input / output images, both handled as PIL images at 512 px height.
image_input = gr.Image(type="pil", height=512, label="Input Image", interactive=True)
image_output = gr.Image(type="pil", height=512, label="API Masked Image", interactive=False)

# Free-text inputs.
text_query = gr.Textbox(
    type="text",
    lines=4,
    label="Query",
    placeholder="Enter a query about the image",
)
# Read-only; filled by the "Pre-Answer" button.
text_pre_answer = gr.Textbox(
    type="text",
    lines=4,
    interactive=False,
    label="LLaVA Response",
    placeholder="",
    info='Only used for LLaVA-Based API. Press "Pre-Answer" to generate the response.',
)
text_highlight_text = gr.Textbox(
    type="text",
    lines=1,
    label="Hint Text.",
    placeholder="Enter the hint text",
    info="The text based on which the mask will be generated. For CLIP-Based API, it should be a substring of the query. For LLaVA-Based API, it should be a substring of the pre-answer.",
)

# Without CUDA only the CLIP-based API is offered.
radio_api_method = gr.Radio(
    choices=["CLIP_Based API"] + (["LLaVA_Based API"] if torch.cuda.is_available() else []),
    value="CLIP_Based API",
    label="Type of API",
    interactive=True,
)

# Mask-blending controls.
# NOTE(review): step=0.5 allows fractional grayscale levels; confirm the blend
# functions expect a float here.
slider_mask_grayscale = gr.Slider(
    minimum=0, maximum=255, step=0.5, value=100,
    label="Grayscale",
    info="0: black mask, 255: white mask.",
    interactive=True,
)
slider_enhance_coe = gr.Slider(
    minimum=1, maximum=50, step=1, value=1,
    label="Contrast",
    info="The larger contrast, the greater the contrast between the bright and dark areas of the mask.",
    interactive=True,
)
# Odd values only (1, 3, 5, 7, 9).
slider_kernel_size = gr.Slider(
    minimum=1, maximum=9, step=2, value=1,
    label="Smoothness",
    info="The larger smoothness, the smoother the mask appears, reducing the rectangular shapes.",
    interactive=True,
)
radio_interpolate_method_name = gr.Radio(
    choices=["BICUBIC", "BILINEAR", "BOX", "LANCZOS", "NEAREST"],
    value="BICUBIC",
    label="Interpolation Method",
    info="The interpolation method used during mask resizing.",
    interactive=True,
)

# Actions. "Pre-Answer" starts disabled because CLIP is the default API.
generate_llava_response_button = gr.Button(value="Pre-Answer", interactive=False)
generate_mask_button = gr.Button(value="API Go!")
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Per-session state: cached intermediates and the loaded models. CLIP is
    # loaded here, at import time, because it is the default API.
    # NOTE(review): Gradio documents that gr.State deep-copies its initial
    # value; if that copy is made per session, each session holds its own CLIP
    # model — confirm memory use for multi-user deployments.
    state_cache = gr.State({})
    state_model = gr.State(init_clip())

    # --- Layout -------------------------------------------------------------
    with gr.Row():
        with gr.Column():
            image_input.render()
        with gr.Column():
            image_output.render()
    with gr.Row():
        radio_api_method.render()
    with gr.Row():
        with gr.Column():
            with gr.Row():
                text_query.render()
            with gr.Row():
                generate_llava_response_button.render()
            with gr.Row():
                text_pre_answer.render()
            with gr.Row():
                text_highlight_text.render()
        with gr.Column():
            with gr.Row():
                slider_enhance_coe.render()
            with gr.Row():
                slider_kernel_size.render()
            with gr.Row():
                radio_interpolate_method_name.render()
            with gr.Row():
                slider_mask_grayscale.render()
            generate_mask_button.render()

    # --- Events -------------------------------------------------------------
    # Switching API reloads models and resets cache, pre-answer and output.
    radio_api_method.change(
        fn=change_api_method,
        inputs=[radio_api_method],
        outputs=[state_model, state_cache, generate_llava_response_button, text_pre_answer, image_output],
    )
    # A new image or query invalidates every cached intermediate. Also clear the
    # displayed pre-answer: otherwise a stale LLaVA response stays visible and
    # "API Go!" would pair it with an emptied cache.
    image_input.change(
        fn=lambda cache_dict: (clear_cache(cache_dict), ""),
        inputs=state_cache,
        outputs=[state_cache, text_pre_answer],
    )
    text_query.change(
        fn=lambda cache_dict: (clear_cache(cache_dict), ""),
        inputs=state_cache,
        outputs=[state_cache, text_pre_answer],
    )
    # A new hint only invalidates the masks.
    text_highlight_text.change(
        fn=clear_mask_cache,
        inputs=state_cache,
        outputs=state_cache,
    )
    generate_llava_response_button.click(
        fn=llava_pre_answer,
        inputs=[image_input, text_query, state_cache, state_model],
        outputs=[text_pre_answer, state_cache],
    )
    generate_mask_button.click(
        fn=generate_mask,
        inputs=[
            image_input,
            text_query,
            text_pre_answer,
            text_highlight_text,
            radio_api_method,
            slider_enhance_coe,
            slider_kernel_size,
            radio_interpolate_method_name,
            slider_mask_grayscale,
            state_cache,
            state_model,
        ],
        outputs=[image_output, state_cache],
    )

# NOTE(review): max_size=1 rejects new events while one is already waiting; the
# per-keystroke `.change` events above may share this queue — confirm they do
# not surface "queue full" errors.
demo.queue(max_size=1).launch(show_error=True)