mattmdjaga commited on
Commit
d364efb
·
1 Parent(s): d8e929e

added custom handler

Browse files
Files changed (1) hide show
  1. handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ import base64
6
+ import torch
7
+
8
+ class EndpointHandler():
9
+ def __init__(self, path="."):
10
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ self.model = CLIPModel.from_pretrained(path).to(self.device).eval()
13
+ self.processor = CLIPProcessor.from_pretrained(path)
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ """
17
+ data args:
18
+ images (:obj:`PIL.Image`)
19
+ candiates (:obj:`list`)
20
+ Return:
21
+ A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
22
+ """
23
+ inputs = data.pop("inputs", data)
24
+
25
+ # decode base64 image to PIL
26
+ image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
27
+ txt = inputs['text']
28
+ # preprocess image
29
+ txt = self.processor(text=txt, return_tensors="pt",padding=True).to(self.device)
30
+ image = self.processor(images=image, return_tensors="pt",padding=True).to(self.device)
31
+ with torch.no_grad():
32
+ txt_features = self.model.get_text_features(**txt)
33
+ image_features = self.model.get_image_features(**image)
34
+ img = image_features.tolist()
35
+ txt = txt_features.tolist()
36
+ pred = {"image": img, "text": txt}
37
+
38
+ return pred