Jasper Lu
commited on
Commit
•
0fcb299
1
Parent(s):
d448e3a
Add handler
Browse files- handler.py +30 -0
- requirements.txt +24 -0
handler.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
import torch
|
3 |
+
from transformers import AutoProcessor, Pix2StructVisionModel
|
4 |
+
from PIL import Image
|
5 |
+
import pdb
|
6 |
+
import requests
|
7 |
+
|
8 |
+
class EndpointHandler():
|
9 |
+
def __init__(self, path=""):
|
10 |
+
#self.processor = AutoProcessor.from_pretrained("jasper-lu/pix2struct_embedding")
|
11 |
+
#self.model = MarkupLMModel.from_pretrained("jasper-lu/pix2struct_embedding")
|
12 |
+
self.processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
|
13 |
+
self.model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
|
14 |
+
|
15 |
+
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
16 |
+
image = Image.open(requests.get(data["url"], stream=True).raw)
|
17 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
18 |
+
|
19 |
+
with torch.no_grad():
|
20 |
+
outputs = self.model(**inputs)
|
21 |
+
|
22 |
+
last_hidden_state = outputs['last_hidden_state']
|
23 |
+
embedding = torch.mean(last_hidden_state, dim=1).flatten().tolist()
|
24 |
+
return {"embedding": embedding}
|
25 |
+
|
26 |
+
"""
|
27 |
+
handler = EndpointHandler()
|
28 |
+
output = handler({"url": "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"})
|
29 |
+
print(output)
|
30 |
+
"""
|
requirements.txt
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
certifi==2023.7.22
|
2 |
+
charset-normalizer==3.3.2
|
3 |
+
filelock==3.13.1
|
4 |
+
fsspec==2023.10.0
|
5 |
+
huggingface-hub==0.17.3
|
6 |
+
idna==3.4
|
7 |
+
Jinja2==3.1.2
|
8 |
+
MarkupSafe==2.1.3
|
9 |
+
mpmath==1.3.0
|
10 |
+
networkx==3.2.1
|
11 |
+
numpy==1.26.2
|
12 |
+
packaging==23.2
|
13 |
+
Pillow==10.1.0
|
14 |
+
PyYAML==6.0.1
|
15 |
+
regex==2023.10.3
|
16 |
+
requests==2.31.0
|
17 |
+
safetensors==0.4.0
|
18 |
+
sympy==1.12
|
19 |
+
tokenizers==0.14.1
|
20 |
+
torch==2.1.0
|
21 |
+
tqdm==4.66.1
|
22 |
+
transformers==4.35.0
|
23 |
+
typing_extensions==4.8.0
|
24 |
+
urllib3==2.1.0
|