pdich2085 commited on
Commit
6e37f66
·
1 Parent(s): 1a5fe4e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -4
handler.py CHANGED
@@ -15,11 +15,13 @@ class EndpointHandler():
15
  self.model.eval()
16
  self.max_length = 16
17
  self.num_beams = 4
18
-
19
- def __call__(self, image_data: str) -> dict:
20
  try:
 
 
21
  # Convert base64 encoded image string to a PIL Image
22
- raw_image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
23
 
24
  # Ensure the image is in RGB mode
25
  if raw_image.mode != "RGB":
@@ -32,7 +34,7 @@ class EndpointHandler():
32
  gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
33
  output_ids = self.model.generate(pixel_values, **gen_kwargs)
34
 
35
- caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0]
36
 
37
  return {"caption": caption}
38
  except Exception as e:
 
15
  self.model.eval()
16
  self.max_length = 16
17
  self.num_beams = 4
18
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
+ # def __call__(self, image_data: str) -> dict:
20
  try:
21
+ image_bytes = data.get("inputs", None)
22
+
23
  # Convert base64 encoded image string to a PIL Image
24
+ raw_image = Image.open(BytesIO(image_bytes))
25
 
26
  # Ensure the image is in RGB mode
27
  if raw_image.mode != "RGB":
 
34
  gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
35
  output_ids = self.model.generate(pixel_values, **gen_kwargs)
36
 
37
+ caption = self.processor.batch_decode(output_ids[0], skip_special_tokens=True).strip()
38
 
39
  return {"caption": caption}
40
  except Exception as e: