Jasper Lu commited on
Commit
0fcb299
1 Parent(s): d448e3a

Add handler

Browse files
Files changed (2) hide show
  1. handler.py +30 -0
  2. requirements.txt +24 -0
handler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from transformers import AutoProcessor, Pix2StructVisionModel
4
+ from PIL import Image
5
+ import pdb
6
+ import requests
7
+
8
+ class EndpointHandler():
9
+ def __init__(self, path=""):
10
+ #self.processor = AutoProcessor.from_pretrained("jasper-lu/pix2struct_embedding")
11
+ #self.model = MarkupLMModel.from_pretrained("jasper-lu/pix2struct_embedding")
12
+ self.processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
13
+ self.model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
14
+
15
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
16
+ image = Image.open(requests.get(data["url"], stream=True).raw)
17
+ inputs = self.processor(images=image, return_tensors="pt")
18
+
19
+ with torch.no_grad():
20
+ outputs = self.model(**inputs)
21
+
22
+ last_hidden_state = outputs['last_hidden_state']
23
+ embedding = torch.mean(last_hidden_state, dim=1).flatten().tolist()
24
+ return {"embedding": embedding}
25
+
26
+ """
27
+ handler = EndpointHandler()
28
+ output = handler({"url": "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"})
29
+ print(output)
30
+ """
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ certifi==2023.7.22
2
+ charset-normalizer==3.3.2
3
+ filelock==3.13.1
4
+ fsspec==2023.10.0
5
+ huggingface-hub==0.17.3
6
+ idna==3.4
7
+ Jinja2==3.1.2
8
+ MarkupSafe==2.1.3
9
+ mpmath==1.3.0
10
+ networkx==3.2.1
11
+ numpy==1.26.2
12
+ packaging==23.2
13
+ Pillow==10.1.0
14
+ PyYAML==6.0.1
15
+ regex==2023.10.3
16
+ requests==2.31.0
17
+ safetensors==0.4.0
18
+ sympy==1.12
19
+ tokenizers==0.14.1
20
+ torch==2.1.0
21
+ tqdm==4.66.1
22
+ transformers==4.35.0
23
+ typing_extensions==4.8.0
24
+ urllib3==2.1.0