import torch

from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers.tools import PipelineTool
from transformers.tools.base import get_default_device
from transformers.utils import requires_backends


class InstructBLIPImageQuestionAnsweringTool(PipelineTool):
    # Alternative checkpoints, kept for reference:
    # default_checkpoint = "Salesforce/blip2-opt-2.7b"
    # default_checkpoint = "Salesforce/instructblip-flan-t5-xl"
    default_checkpoint = "Salesforce/instructblip-vicuna-7b"
    # default_checkpoint = "Salesforce/instructblip-vicuna-13b"
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq
    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        # Image handling needs the "vision" backend (Pillow).
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.

        Overrides `PipelineTool.setup` so the checkpoint is loaded in 4-bit
        precision (via bitsandbytes) with float16 compute, which greatly reduces
        the GPU memory needed for the Vicuna-7B backbone.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(
                self.model,
                **self.model_kwargs,
                **self.hub_kwargs,
                load_in_4bit=True,
                torch_dtype=torch.float16,
            )

        if self.post_processor is None:
            # Reuse the processor for decoding when no dedicated post-processor is set.
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                # With a device map, send inputs to the device of the first module.
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        self.is_initialized = True

    def encode(self, image, question: str):
        # `BatchFeature.to` moves all tensors to the target device and casts only
        # the floating-point ones (the pixel values) to float16. Using
        # `self.device` instead of a hard-coded "cuda" keeps the tool usable on
        # whatever device `setup` resolved.
        return self.pre_processor(images=image, text=question, return_tensors="pt").to(
            device=self.device, dtype=torch.float16
        )

    def forward(self, inputs):
        # Beam-search decoding. Note that `top_p` and `temperature` only take
        # effect when `do_sample=True`; with pure beam search the generation
        # below is deterministic.
        outputs = self.model.generate(
            **inputs,
            num_beams=5,
            max_new_tokens=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=0.7,
        )
        return outputs

    def decode(self, outputs):
        # The processor's tokenizer maps the generated ids back to text; only the
        # first (and only) sequence in the batch is returned.
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
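

# A minimal usage sketch (not part of the original Space). It assumes a CUDA
# GPU and the `accelerate` and `bitsandbytes` packages that `load_in_4bit`
# requires; the image URL below is a placeholder, not a real asset.
if __name__ == "__main__":
    import requests
    from PIL import Image

    tool = InstructBLIPImageQuestionAnsweringTool()
    # Any RGB PIL image works here; calling the tool triggers `setup` lazily.
    image = Image.open(requests.get("https://example.com/demo.jpg", stream=True).raw).convert("RGB")
    print(tool(image, "What is shown in this image?"))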