File size: 3,460 Bytes
732325f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr 
from lavis.models import load_model_and_preprocess
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
# Always a torch.device (the original mixed torch.device and a plain str).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# LAVIS registry name and checkpoint variant (InstructBLIP with a Flan-T5 XL LLM).
model_name = "blip2_t5_instruct"
model_type = "flant5xl"
# The third returned element (the text processors) is discarded: the raw prompt
# string is passed straight to model.generate() in `infer`.
# NOTE: the original passed `args.model_name` / `args.model_type`, but no `args`
# exists in this script, which made the module fail with NameError on import.
model, vis_processors, _ = load_model_and_preprocess(
    name=model_name,
    model_type=model_type,
    is_eval=True,
    device=device,
)

def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
    """Generate a text response for ``image`` conditioned on ``prompt``.

    Called by the "Run" button; arguments arrive in the order of the
    ``inputs`` list wired up in the Blocks UI.

    Args:
        image: PIL image from ``gr.Image(type="pil")``, or ``None`` when the
            user has not uploaded anything.
        prompt: Instruction / question text sent to the model.
        min_len: Minimum generated length (slider value).
        max_len: Maximum generated length (slider value).
        beam_size: Number of beams (slider value).
        len_penalty: Length penalty forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_p: Cumulative-probability cutoff used when sampling.
        decoding_method: ``"Beam search"`` or ``"Nucleus sampling"`` (radio value).

    Returns:
        The first string produced by ``model.generate``.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a stack trace from the image processor.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    use_nucleus_sampling = decoding_method == "Nucleus sampling"
    print(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, use_nucleus_sampling)

    # Defensive RGB conversion (LAVIS examples do the same) so RGBA or greyscale
    # inputs never reach the processor with an unexpected channel count.
    pixels = vis_processors["eval"](image.convert("RGB")).unsqueeze(0).to(device)

    samples = {
        "image": pixels,
        "prompt": prompt,
    }

    output = model.generate(
        samples,
        length_penalty=float(len_penalty),
        repetition_penalty=float(repetition_penalty),
        # Slider values may arrive as floats; beam counts and lengths must be ints.
        num_beams=int(beam_size),
        max_length=int(max_len),
        min_length=int(min_len),
        top_p=float(top_p),
        use_nucleus_sampling=use_nucleus_sampling,
    )

    return output[0]
    
# Page styling: Gradio's Monochrome base theme recoloured with indigo/blue
# accents, small corner radii, and Open Sans with system sans-serif fallbacks.
_font_stack = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_font_stack,
)
# Extra CSS: hide elements carrying Gradio's `generating` class
# (presumably the in-progress indicator on outputs; verify against the Gradio version in use).
css = ".generating {visibility: hidden}"

with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    # Column `scale` only sets relative width among columns inside the same Row.
    # Without this Row the two panels would stack vertically and `scale` would
    # have no effect.
    with gr.Row():
        # Main panel: image, prompt, result and the run button.
        with gr.Column(scale=3):
            image_input = gr.Image(type="pil")
            prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)
            output = gr.Textbox(label="Output")
            submit = gr.Button("Run", variant="primary")

        # Side panel: generation settings forwarded to `infer`.
        with gr.Column(scale=1):
            min_len = gr.Slider(
                minimum=1,
                maximum=50,
                value=1,
                step=1,
                interactive=True,
                label="Min Length",
            )

            max_len = gr.Slider(
                minimum=10,
                maximum=500,
                value=250,
                step=5,
                interactive=True,
                label="Max Length",
            )

            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )

            top_p = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top p",
            )

            beam_size = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                interactive=True,
                label="Beam Size",
            )

            len_penalty = gr.Slider(
                minimum=-1,
                maximum=2,
                value=1,
                step=0.2,
                interactive=True,
                label="Length Penalty",
            )

            # transformers' generate() rejects repetition penalties <= 0, so the
            # range starts at the first positive step. (It was -1, which allowed
            # invalid values.)
            repetition_penalty = gr.Slider(
                minimum=0.2,
                maximum=3,
                value=1,
                step=0.2,
                interactive=True,
                label="Repetition Penalty",
            )

    # The order of `inputs` must match infer()'s positional parameters.
    submit.click(
        infer,
        inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling],
        outputs=[output],
    )

# Start the server only when run as a script, so importing this module (e.g. for
# tests or Gradio's reload mode) does not block on launch().
# `concurrency_count` is the Gradio 3.x queue argument for the number of
# concurrent workers.
if __name__ == "__main__":
    demo.queue(concurrency_count=16).launch(debug=True)