iamrobotbear committed on
Commit
e07636d
1 Parent(s): 6c0cf6b

Create handler.py


Add a handler.py so this model can be deployed to Hugging Face Inference Endpoints

Files changed (1)
  1. handler.py +48 -0
handler.py ADDED
@@ -0,0 +1,48 @@
+ from io import BytesIO
+ from typing import Any, Dict, List
+
+ import torch
+ from PIL import Image
+ from transformers import BlipForConditionalGeneration, BlipProcessor, AutoModelForSeq2SeqLM, AutoTokenizer
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # Load the BLIP captioning model and processor.
+         self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+         self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
+         self.blip_model.eval()
+
+         # Load the Flan model and tokenizer from the repository path.
+         self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device)
+         self.flan_tokenizer = AutoTokenizer.from_pretrained(path)
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         # Parse the request: a list of raw image bytes plus optional generation parameters.
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", {})
+
+         # Preprocess the images with the BLIP processor; force RGB so alpha channels don't break it.
+         raw_images = [Image.open(BytesIO(_img)).convert("RGB") for _img in inputs]
+         processed_image = self.blip_processor(images=raw_images, return_tensors="pt")
+         processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
+         processed_image = {**processed_image, **parameters}
+
+         # Generate captions with BLIP.
+         with torch.no_grad():
+             out = self.blip_model.generate(**processed_image)
+         captions = self.blip_processor.batch_decode(out, skip_special_tokens=True)
+
+         # Tokenize the captions for Flan, padding so a batch of captions lines up,
+         # and move the tensors to the same device as the model.
+         input_ids = self.flan_tokenizer(captions, return_tensors="pt", padding=True).input_ids.to(device)
+
+         # Generate text with Flan; `parameters` defaults to {}, so it can be unpacked unconditionally.
+         with torch.no_grad():
+             outputs = self.flan_model.generate(input_ids, **parameters)
+
+         # Decode every generated sequence: one result per input image.
+         predictions = self.flan_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+         return [{"generated_text": p} for p in predictions]
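
For context, a minimal local smoke test of the handler (not part of this commit). The values path="." and cat.jpg are placeholders: path should point at the directory holding the Flan weights, and the image can be any file on disk.

from handler import EndpointHandler

# Instantiate the handler the way Inference Endpoints would on startup,
# with `path` pointing at the model repository (placeholder value here).
handler = EndpointHandler(path=".")

# "inputs" is a list of raw image byte strings; "parameters" is forwarded
# to both generate() calls. max_new_tokens=50 is just an example value.
with open("cat.jpg", "rb") as f:
    payload = {"inputs": [f.read()], "parameters": {"max_new_tokens": 50}}

print(handler(payload))  # e.g. [{"generated_text": "..."}]

Once deployed, the endpoint should return one {"generated_text": ...} entry per image sent in "inputs".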