# instructblip / app.py
# Commit be13417 ("Initialization") by WhiteWolf21; file size 3.43 kB.
import gradio as gr
from lavis.models import load_model_and_preprocess
import torch
import argparse
# Instruction sent to the model together with every image. The free-text
# prompt box is not exposed in the UI, so this (or --prompt) is always used.
DEFAULT_PROMPT = (
    "Describe the image in detail and where are the violence objects position "
    "in the image (center, left, right, top, bottom)."
)


def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``model_name``, ``model_type`` and
        ``prompt`` attributes.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-name", default="blip2_vicuna_instruct")
    parser.add_argument("--model-type", default="vicuna7b")
    parser.add_argument(
        "--prompt",
        default=DEFAULT_PROMPT,
        help="Instruction sent to the model together with each image.",
    )
    return parser.parse_args(argv)


def build_inputs():
    """Create the Gradio input components in the order ``inference`` takes them.

    Returns:
        List of components: image, min length, max length, beam size,
        length penalty, repetition penalty, top-p, decoding method.
    """
    image_input = gr.Image(type="pil")
    min_len = gr.Slider(
        minimum=1,
        maximum=50,
        value=1,
        step=1,
        interactive=True,
        label="Min Length",
    )
    max_len = gr.Slider(
        minimum=10,
        maximum=500,
        value=250,
        step=5,
        interactive=True,
        label="Max Length",
    )
    sampling = gr.Radio(
        choices=["Beam search", "Nucleus sampling"],
        value="Beam search",
        label="Text Decoding Method",
        interactive=True,
    )
    top_p = gr.Slider(
        minimum=0.5,
        maximum=1.0,
        value=0.9,
        step=0.1,
        interactive=True,
        label="Top p",
    )
    beam_size = gr.Slider(
        minimum=1,
        maximum=10,
        value=5,
        step=1,
        interactive=True,
        label="Beam Size",
    )
    len_penalty = gr.Slider(
        minimum=-1,
        maximum=2,
        value=1,
        step=0.2,
        interactive=True,
        label="Length Penalty",
    )
    # transformers requires repetition_penalty to be strictly positive, so the
    # range must not include 0 or negative values (1.0 means "no penalty").
    repetition_penalty = gr.Slider(
        minimum=0.2,
        maximum=3,
        value=1,
        step=0.2,
        interactive=True,
        label="Repetition Penalty",
    )
    return [
        image_input,
        min_len,
        max_len,
        beam_size,
        len_penalty,
        repetition_penalty,
        top_p,
        sampling,
    ]


def generation_kwargs(min_len, max_len, beam_size, len_penalty,
                      repetition_penalty, top_p, decoding_method):
    """Convert raw UI values into keyword arguments for ``model.generate``.

    Counts are coerced to ``int`` and penalties/probabilities to ``float``,
    whichever numeric type the UI component delivers.

    Args:
        min_len: Minimum generated length.
        max_len: Maximum generated length.
        beam_size: Number of beams.
        len_penalty: Length penalty.
        repetition_penalty: Repetition penalty.
        top_p: Nucleus-sampling probability mass.
        decoding_method: ``"Beam search"`` or ``"Nucleus sampling"``.

    Returns:
        Dict of keyword arguments for ``model.generate``.
    """
    return {
        "length_penalty": float(len_penalty),
        "repetition_penalty": float(repetition_penalty),
        "num_beams": int(beam_size),
        "max_length": int(max_len),
        "min_length": int(min_len),
        "top_p": float(top_p),
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
    }


def main(argv=None):
    """Load the requested LAVIS model and serve it through a Gradio interface.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    inputs = build_inputs()

    # Both branches yield a torch.device, so `.to(device)` always sees one type.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Loading model...')
    model, vis_processors, _ = load_model_and_preprocess(
        name=args.model_name,
        model_type=args.model_type,
        is_eval=True,
        device=device,
    )
    print('Loading model done!')

    def inference(image, min_len, max_len, beam_size, len_penalty,
                  repetition_penalty, top_p, decoding_method, modeltype=None):
        """Generate a description of ``image`` with the loaded model.

        Args:
            image: PIL image from the Gradio image component.
            min_len, max_len, beam_size, len_penalty, repetition_penalty,
            top_p, decoding_method: Raw UI values; see ``generation_kwargs``.
            modeltype: Unused. It defaults to ``None`` because the interface
                supplies only eight inputs. Without the default, every call
                would fail with a missing-argument TypeError.

        Returns:
            The first generated string.
        """
        samples = {
            # unsqueeze(0) adds a leading batch dimension of size 1.
            "image": vis_processors["eval"](image).unsqueeze(0).to(device),
            "prompt": args.prompt,
        }
        output = model.generate(
            samples,
            **generation_kwargs(min_len, max_len, beam_size, len_penalty,
                                repetition_penalty, top_p, decoding_method),
        )
        return output[0]

    gr.Interface(
        fn=inference,
        inputs=inputs,
        outputs="text",
        allow_flagging="never",
    ).launch()


if __name__ == '__main__':
    main()