Jasper Lu commited on
Commit
3455ede
1 Parent(s): 0fcb299

Needs to be inputs

Browse files
Files changed (1) hide show
  1. handler.py +3 -4
handler.py CHANGED
@@ -13,7 +13,8 @@ class EndpointHandler():
13
  self.model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
14
 
15
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
16
- image = Image.open(requests.get(data["url"], stream=True).raw)
 
17
  inputs = self.processor(images=image, return_tensors="pt")
18
 
19
  with torch.no_grad():
@@ -23,8 +24,6 @@ class EndpointHandler():
23
  embedding = torch.mean(last_hidden_state, dim=1).flatten().tolist()
24
  return {"embedding": embedding}
25
 
26
- """
27
  handler = EndpointHandler()
28
- output = handler({"url": "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"})
29
  print(output)
30
- """
 
13
  self.model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
14
 
15
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
16
+ url = data.pop("inputs", data)
17
+ image = Image.open(requests.get(url, stream=True).raw)
18
  inputs = self.processor(images=image, return_tensors="pt")
19
 
20
  with torch.no_grad():
 
24
  embedding = torch.mean(last_hidden_state, dim=1).flatten().tolist()
25
  return {"embedding": embedding}
26
 
 
27
  handler = EndpointHandler()
28
+ output = handler({"inputs": "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"})
29
  print(output)