from PIL import Image

import requests
import json
import gradio as gr


from io import BytesIO

def encode_image(image):
    # JPEG has no alpha channel, so convert to RGB before saving
    buffered = BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")
    # rewind the buffer so requests reads it from the beginning
    buffered.seek(0)

    return buffered


def query_api(image, prompt, decoding_method):
    # address of the hosted BLIP-2 inference server
    url = "http://34.132.142.70:5000/api/generate"

    data = {"prompt": prompt, "use_nucleus_sampling": decoding_method == "Nucleus sampling"}
    
    image = encode_image(image)
    files = {"image": image}

    response = requests.post(url, data=data, files=files)

    if response.status_code == 200:
        return response.json()
    else:
        # wrap the error in a list so callers can index it like a normal response
        return ["Error: " + response.text]
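
# Assumed request/response contract of the /api/generate endpoint (not verified
# against the server code): form fields are "prompt" (the dialogue so far) and
# "use_nucleus_sampling" (bool), the image is uploaded as a multipart file under
# "image", and the JSON response is a list whose first element is the answer string.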


def prepend_question(text):
    text = text.strip().lower()

    return "question: " + text
    

def prepend_answer(text):
    text = text.strip().lower()

    return "answer: " + text


def get_prompt_from_history(history):
    prompts = []

    for i in range(len(history)):
        if i % 2 == 0:
            prompts.append(prepend_question(history[i]))
        else:
            prompts.append(prepend_answer(history[i]))
    
    return "\n".join(prompts)
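
# Example of the prompt built from a three-entry history (illustrative):
#   question: what is in the image?
#   answer: a dog
#   question: what breed is it?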


def postp_answer(text):
    # strip any "answer: " / "a: " prefix the model may echo back
    if text.startswith("answer: "):
        return text[len("answer: "):]
    elif text.startswith("a: "):
        return text[len("a: "):]
    else:
        return text


def prep_question(text):
    # strip any "question: " / "q: " prefix the user may have typed
    if text.startswith("question: "):
        text = text[len("question: "):]
    elif text.startswith("q: "):
        text = text[len("q: "):]

    if not text.endswith("?"):
        text += "?"

    return text


def inference(image, text_input, decoding_method, history=[]):
    # "history" is the Gradio session state: alternating question/answer strings
    text_input = prep_question(text_input)
    history.append(text_input)

    prompt = get_prompt_from_history(history)

    output = query_api(image, prompt, decoding_method)
    output = [postp_answer(output[0])]
    history += output

    # pair up history entries into (question, answer) tuples for the chatbot component
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    return chat, history
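
# Illustrative shape of the returned values after two turns (not from a real run):
#   history = ["what is in the image?", "a dog", "what breed is it?", "a corgi"]
#   chat    = [("what is in the image?", "a dog"), ("what breed is it?", "a corgi")]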


inputs = [gr.inputs.Image(type='pil'),
          gr.inputs.Textbox(lines=2, label="Text input"),
          gr.inputs.Radio(choices=['Nucleus sampling','Beam search'], type="value", default="Nucleus sampling", label="Text Decoding Method"),
          "state",
         ]

outputs = ["chatbot", "state"]
           
title = "BLIP-2"
description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p> 
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data, including but not limited to text and images, is collected. </p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"

iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article)
iface.launch(enable_queue=True)