import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path

from blip_vqa import blip_vqa
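# Note: blip_vqa is the VQA model builder from the Salesforce BLIP code base
# (https://github.com/salesforce/BLIP); blip_vqa.py is assumed to be importable
# from the same directory as this script.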

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384

class App:
    def __init__(self):
        self.selected_model = 0

        # Load BLIP for visual question answering; the 'blip_vqa.pth'
        # checkpoint is expected next to this script.
        print("Loading BLIP for question answering")
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)
        
        # Build the Gradio interface.
        with gr.Blocks() as demo:
            gr.Markdown(
                "# BLIP image question answering\n"
                "Ask questions about an image and get answers from BLIP.\n"
                "This can be used to caption images for Stable Diffusion fine-tuning, among many other applications.\n"
                "Brought to Gradio by @ParisNeo, based on the original BLIP code: [https://github.com/salesforce/BLIP](https://github.com/salesforce/BLIP)\n"
                "The model is described in this paper: [https://arxiv.org/abs/2201.12086](https://arxiv.org/abs/2201.12086)"
            )
            with gr.Row():
                self.image_source = gr.Image(shape=(448, 448))
                with gr.Tabs():
                    with gr.Tab("Question/Answer"):
                        self.question = gr.Textbox(label="Custom question (if applicable)", value="Describe this image")
                        self.answer = gr.Button("Ask")
                        self.lbl_caption = gr.Label(label="Caption")
                        self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)

        # Launch the interface
        demo.launch()

    def answer_question_image(self, img, custom_question="Describe this image"):
        # Preprocess the image: Gradio passes a numpy array, which is converted to a
        # PIL image, resized, and normalized with the statistics BLIP expects.
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

        # Run the model in generation mode to answer the question.
        with torch.no_grad():
            answer = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')

        # The model returns a list of answers; return the first one as a string.
        return answer[0]

if __name__ == "__main__":
    app = App()
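
# Example (not part of the original app): a minimal, headless sketch of how the same
# checkpoint could be queried without launching the Gradio UI, e.g. for batch captioning.
# 'example.jpg' is a placeholder path; the preprocessing and the generate call simply
# mirror answer_question_image() above. Uncomment to try.
#
# _model = blip_vqa(pretrained=str(Path(__file__).parent / 'blip_vqa.pth'),
#                   image_size=image_size, vit='base').eval().to(device)
# _preprocess = transforms.Compose([
#     transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
#     transforms.ToTensor(),
#     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# ])
# _img = _preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0).to(device)
# with torch.no_grad():
#     print(_model(_img, "Describe this image", train=False, inference='generate')[0])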