Files changed (1) hide show
  1. handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import torch
3
+ from transformers import InstructBlipForConditionalGeneration, InstructBlipTokenizer
4
+
5
+ class InstructBlipHandler:
6
+ def __init__(self, model, tokenizer):
7
+ self.model = model
8
+ self.tokenizer = tokenizer
9
+
10
+ def __call__(self, input_data):
11
+ # Preprocess the input data
12
+ inputs = self.preprocess(input_data)
13
+ # Generate the output using the model
14
+ outputs = self.model.generate(**inputs)
15
+ # Postprocess the output
16
+ result = self.postprocess(outputs)
17
+ return result
18
+
19
+ def preprocess(self, input_data):
20
+ image_data = input_data["image"]
21
+ text_prompt = input_data["text"]
22
+
23
+ image = torch.tensor(base64.b64decode(image_data)).unsqueeze(0)
24
+ text_inputs = self.tokenizer(text_prompt, return_tensors="pt")
25
+
26
+ inputs = {
27
+ "input_ids": text_inputs["input_ids"],
28
+ "attention_mask": text_inputs["attention_mask"],
29
+ "pixel_values": image
30
+ }
31
+ return inputs
32
+
33
+ def postprocess(self, outputs):
34
+ return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
35
+
36
+ model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl")
37
+ tokenizer = InstructBlipTokenizer.from_pretrained("Salesforce/instructblip-flan-t5-xl")
38
+ handler = InstructBlipHandler(model, tokenizer)