Jasper Lu commited on
Commit
5fe03c4
1 Parent(s): b34fba5
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -13,12 +13,12 @@ class EndpointHandler():
13
  #self.model = MarkupLMModel.from_pretrained("jasper-lu/pix2struct_embedding")
14
  self.processor = AutoProcessor.from_pretrained(MODEL)
15
  self.processor.image_processor.is_vqa = False
16
- self.model = Pix2StructVisionModel.from_pretrained(MODEL)
17
 
18
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
19
  url = data.pop("inputs", data)
20
  image = Image.open(requests.get(url, stream=True).raw)
21
- inputs = self.processor(images=image, return_tensors="pt")
22
 
23
  with torch.no_grad():
24
  outputs = self.model(**inputs)
 
13
  #self.model = MarkupLMModel.from_pretrained("jasper-lu/pix2struct_embedding")
14
  self.processor = AutoProcessor.from_pretrained(MODEL)
15
  self.processor.image_processor.is_vqa = False
16
+ self.model = Pix2StructVisionModel.from_pretrained(MODEL).cuda()
17
 
18
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
19
  url = data.pop("inputs", data)
20
  image = Image.open(requests.get(url, stream=True).raw)
21
+ inputs = self.processor(images=image, return_tensors="pt").cuda()
22
 
23
  with torch.no_grad():
24
  outputs = self.model(**inputs)