iamrobotbear's picture
Create handler.py
0e7d253
raw
history blame
1.37 kB
import base64
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipTokenizer
class InstructBlipHandler:
    """Answer a text instruction about a base64-encoded image with an InstructBLIP model.

    A request (``input_data``) is a mapping with two entries:

    * ``"image"`` -- base64 text or bytes; presumably an encoded image file
      such as PNG/JPEG (confirm against the client).
    * ``"text"`` -- the instruction prompt.

    The decoded payload is encoded file bytes, not pixel data. Wrapping it in a
    tensor does not yield a valid ``pixel_values`` batch, so building real model
    inputs needs two caller-supplied collaborators:

    * ``image_loader(raw_bytes) -> image`` decodes the file (e.g. with Pillow).
    * ``processor(images=image, text=prompt, return_tensors="pt")`` returns the
      mapping of ``model.generate`` keyword arguments. This is presumably an
      ``InstructBlipProcessor``, whose output also carries the Q-Former inputs
      that InstructBLIP generation expects (confirm against the installed
      transformers version).
    """

    def __init__(self, model, tokenizer, *, processor=None, image_loader=None):
        """Store the model and the helpers used around generation.

        Args:
            model: Object exposing ``generate(**inputs)``.
            tokenizer: Object exposing ``batch_decode(ids, skip_special_tokens=...)``.
                It turns generated ids back into text.
            processor: Optional callable that builds every ``generate`` input
                from an image and a prompt (see the class docstring).
            image_loader: Optional callable that turns the raw decoded image-file
                bytes into the image type ``processor`` accepts.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.processor = processor
        self.image_loader = image_loader

    def __call__(self, input_data):
        """Handle one request: preprocess, run ``model.generate``, then decode.

        Returns:
            Whatever ``postprocess`` returns: the tokenizer's ``batch_decode``
            output for the generated ids.
        """
        inputs = self.preprocess(input_data)
        outputs = self.model.generate(**inputs)
        return self.postprocess(outputs)

    def preprocess(self, input_data):
        """Build the keyword arguments for ``model.generate`` from a request.

        Raises:
            KeyError: If ``input_data`` lacks ``"image"`` or ``"text"``.
            ValueError: If ``processor`` or ``image_loader`` was not configured.
                Without them there is no way to turn encoded image bytes into
                ``pixel_values``.
            binascii.Error: If ``"image"`` is not valid base64 (a ValueError
                subclass).
        """
        image_data = input_data["image"]
        text_prompt = input_data["text"]
        # Fail fast with a clear message. The raw file bytes cannot stand in
        # for a pixel tensor, and the model would otherwise fail deep inside
        # its vision encoder.
        if self.processor is None or self.image_loader is None:
            raise ValueError(
                "InstructBlipHandler needs both `processor` and `image_loader` "
                "to build model inputs: the base64 payload is an encoded image "
                "file, not a pixel_values tensor."
            )
        image_bytes = base64.b64decode(image_data)
        image = self.image_loader(image_bytes)
        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt")
        # Return a plain dict, as before, so callers can ``**``-unpack or
        # inspect it without depending on the processor's return type.
        return dict(inputs)

    def postprocess(self, outputs):
        """Decode generated token ids to strings, dropping special tokens."""
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Module-level wiring: load the pretrained checkpoint once, at import time,
# and expose a ready-to-call ``handler``.
# NOTE(review): the transformers InstructBLIP docs list an
# ``InstructBlipProcessor`` (image processor plus tokenizers) rather than an
# ``InstructBlipTokenizer``. Confirm that the import of ``InstructBlipTokenizer``
# resolves against the installed transformers version.
_CHECKPOINT = "Salesforce/instructblip-flan-t5-xl"

model = InstructBlipForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = InstructBlipTokenizer.from_pretrained(_CHECKPOINT)
handler = InstructBlipHandler(model, tokenizer)