import json
from io import BytesIO

import gradio as gr
import requests
from PIL import Image


def encode_image(image):
    # Serialize the PIL image to an in-memory JPEG for the multipart upload.
    # Convert to RGB first, since JPEG cannot store an alpha channel.
    image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    buffered.seek(0)
    return buffered


def query_api(image, prompt, decoding_method):
    # Send the image and the accumulated dialogue prompt to the remote BLIP-2 server.
    url = "http://34.132.142.70:5000/api/generate"

    headers = {
        'User-Agent': 'BLIP-2 HuggingFace Space'
    }

    data = {"prompt": prompt, "use_nucleus_sampling": decoding_method == "Nucleus sampling"}

    image = encode_image(image)
    files = {"image": image}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    else:
        return "Error: " + response.text


def prepend_question(text):
    text = text.strip().lower()
    return "question: " + text


def prepend_answer(text):
    text = text.strip().lower()
    return "answer: " + text


def get_prompt_from_history(history):
    prompts = []

    for i in range(len(history)):
        if i % 2 == 0:
            prompts.append(prepend_question(history[i]))
        else:
            prompts.append(prepend_answer(history[i]))

    return "\n".join(prompts)


def postp_answer(text):
    # Strip the "answer: " / "a: " prefix the model may emit before its reply.
    if text.startswith("answer: "):
        return text[len("answer: "):]
    elif text.startswith("a: "):
        return text[len("a: "):]
    else:
        return text
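
# e.g. postp_answer("answer: a dog on a couch") -> "a dog on a couch"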


def prep_question(text):
    # Normalize the user's question: drop any "question: " / "q: " prefix
    # and make sure it ends with a question mark.
    if text.startswith("question: "):
        text = text[len("question: "):]
    elif text.startswith("q: "):
        text = text[len("q: "):]

    if not text.endswith("?"):
        text += "?"

    return text
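
# e.g. prep_question("q: what color is the couch") -> "what color is the couch?"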


def inference(image, text_input, decoding_method, history=[]):
    # "history" alternates question, answer, question, answer, ...
    # and is threaded between calls via Gradio's "state" input/output.
    text_input = prep_question(text_input)
    history.append(text_input)

    prompt = get_prompt_from_history(history)

    output = query_api(image, prompt, decoding_method)
    output = [postp_answer(output[0])]
    history += output

    # Pair up (question, answer) turns for the chatbot display.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    return chat, history
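
# Caveat: if query_api() returns an error string instead of a list, output[0] above
# is just its first character; the demo assumes the backend responds successfully.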


inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Textbox(lines=2, label="Text input"),
    gr.inputs.Radio(choices=['Nucleus sampling', 'Beam search'], type="value", default="Nucleus sampling", label="Text Decoding Method"),
    "state",
]

outputs = ["chatbot", "state"]


title = "BLIP-2"

description = """Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image and ask a question about it. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data, including but not limited to text and images, is collected. </p>"""

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"

iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article)
iface.launch(enable_queue=True)