File size: 2,951 Bytes
2b2a664
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
from transformers.tools import PipelineTool
from transformers.tools.base import get_default_device
from transformers.utils import requires_backends

class InstructBLIPImageQuestionAnsweringTool(PipelineTool):
    """Tool that answers a natural-language question about an image.

    The pre-processor turns the image and question into tensors, the model
    generates token ids with beam search, and the pre-processor decodes the
    first generated sequence back into text.
    """

    # Alternative checkpoints that were previously listed here (commented out):
    #   "Salesforce/blip2-opt-2.7b"
    #   "Salesforce/instructblip-flan-t5-xl"
    #   "Salesforce/instructblip-vicuna-13b"
    default_checkpoint = "Salesforce/instructblip-vicuna-7b"

    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq
    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        """Check that the vision backend is available, then defer to `PipelineTool`."""
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.

        The model is loaded with `load_in_4bit=True` and `torch_dtype=torch.float16` by default; either value can
        be overridden through `self.model_kwargs` (or `self.hub_kwargs`). Also resolves `self.device` when it was
        not given explicitly.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            # Merge into a single dict so that a caller supplying `load_in_4bit` or `torch_dtype` overrides the
            # defaults instead of triggering "got multiple values for keyword argument".
            load_kwargs = {
                "load_in_4bit": True,
                "torch_dtype": torch.float16,
                **self.model_kwargs,
                **self.hub_kwargs,
            }
            self.model = self.model_class.from_pretrained(self.model, **load_kwargs)

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                # Use the device of the first entry in the model's device map.
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        # NOTE: the model is intentionally not moved with `self.model.to(self.device)` here (that call was
        # already disabled). Presumably this is because bitsandbytes-quantized models do not support `.to()`;
        # device placement is left to `from_pretrained`.
        self.is_initialized = True

    def encode(self, image, question: str):
        """Convert the image and question into model inputs.

        Args:
            image: The image to ask about, passed to the pre-processor as `images`.
            question: The question text, passed to the pre-processor as `text`.

        Returns:
            The pre-processor output as PyTorch tensors, moved to `self.device` and converted with
            `dtype=torch.float16`.
        """
        features = self.pre_processor(images=image, text=question, return_tensors="pt")
        # Use the device resolved in `setup` rather than a hard-coded "cuda", so inputs follow the same device
        # the rest of the tool uses and the tool does not fail on machines without CUDA.
        return features.to(device=self.device, dtype=torch.float16)

    def forward(self, inputs):
        """Generate answer token ids for the encoded inputs.

        Args:
            inputs: The mapping produced by `encode`; unpacked as keyword arguments to `generate`.

        Returns:
            Whatever `self.model.generate` returns (generated token ids).
        """
        # NOTE(review): `top_p` and `temperature` are sampling parameters and `do_sample` is not set here, so
        # they presumably have no effect on the beam search; they are kept to preserve the exact generate call.
        outputs = self.model.generate(
            **inputs,
            num_beams=5,
            max_new_tokens=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=1,
        )
        return outputs

    def decode(self, outputs):
        """Decode the first generated sequence into a stripped string, skipping special tokens."""
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()