Jasper Lu
commited on
Commit
•
3455ede
1
Parent(s):
0fcb299
Needs to be inputs
Browse files- handler.py +3 -4
handler.py
CHANGED
@@ -13,7 +13,8 @@ class EndpointHandler():
|
|
13 |
self.model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
|
14 |
|
15 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
16 |
-
|
|
|
17 |
inputs = self.processor(images=image, return_tensors="pt")
|
18 |
|
19 |
with torch.no_grad():
|
@@ -23,8 +24,6 @@ class EndpointHandler():
|
|
23 |
embedding = torch.mean(last_hidden_state, dim=1).flatten().tolist()
|
24 |
return {"embedding": embedding}
|
25 |
|
26 |
-
"""
|
27 |
handler = EndpointHandler()
|
28 |
-
output = handler({"
|
29 |
print(output)
|
30 |
-
"""
|
|
|
13 |
self.model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
|
14 |
|
15 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
16 |
+
url = data.pop("inputs", data)
|
17 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
18 |
inputs = self.processor(images=image, return_tensors="pt")
|
19 |
|
20 |
with torch.no_grad():
|
|
|
24 |
embedding = torch.mean(last_hidden_state, dim=1).flatten().tolist()
|
25 |
return {"embedding": embedding}
|
26 |
|
|
|
27 |
handler = EndpointHandler()
|
28 |
+
output = handler({"inputs": "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"})
|
29 |
print(output)
|
|