anvilum committed
Commit d5da980 · 1 Parent(s): 0603c91

modify handler to receive a list of URLs instead of image bytes

Files changed (1)
  1. handler.py +34 -14
handler.py CHANGED
@@ -2,15 +2,16 @@ import requests
 from typing import Dict, Any
 from PIL import Image
 import torch
-import base64
 from io import BytesIO
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+
 class EndpointHandler():
     def __init__(self, path=""):
-        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+        self.processor = BlipProcessor.from_pretrained(
+            "Salesforce/blip-image-captioning-large")
         self.model = BlipForConditionalGeneration.from_pretrained(
             "Salesforce/blip-image-captioning-large"
         ).to(device)
@@ -18,29 +19,48 @@ class EndpointHandler():
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         input_data = data.get("inputs", {})
-        encoded_images = input_data.get("images")
-
-        if not encoded_images:
-            return {"captions": [], "error": "No images provided"}
+        image_urls = input_data.get("image_urls", [])
+
+        if not image_urls:
+            return {"captions": [], "error": "No images provided"}
+
+        texts = input_data.get(
+            "texts", [""] * len(image_urls))
+
+        if len(image_urls) != len(texts):
+            return {
+                "captions": [],
+                "error": "Texts and images should have the same length"
+            }
 
-        texts = input_data.get("texts", ["a photography of"] * len(encoded_images))
+        images_data = [requests.get(url).content for url in image_urls]
 
         try:
-            raw_images = [Image.open(BytesIO(base64.b64decode(img))).convert("RGB") for img in encoded_images]
+            raw_images = [
+                Image.open(BytesIO((img))).convert("RGB")
+                for img in images_data]
             processed_inputs = [
-                self.processor(image, text, return_tensors="pt") for image, text in zip(raw_images, texts)
+                self.processor(image, text, return_tensors="pt")
+                for image, text in zip(raw_images, texts)
             ]
             processed_inputs = {
-                "pixel_values": torch.cat([inp["pixel_values"] for inp in processed_inputs], dim=0).to(device),
-                "input_ids": torch.cat([inp["input_ids"] for inp in processed_inputs], dim=0).to(device),
-                "attention_mask": torch.cat([inp["attention_mask"] for inp in processed_inputs], dim=0).to(device)
+                "pixel_values": torch.cat(
+                    [inp["pixel_values"]
+                     for inp in processed_inputs], dim=0).to(device),
+                "input_ids": torch.cat(
+                    [inp["input_ids"]
+                     for inp in processed_inputs], dim=0).to(device),
+                "attention_mask": torch.cat(
+                    [inp["attention_mask"]
+                     for inp in processed_inputs], dim=0).to(device)
            }
 
             with torch.no_grad():
                 out = self.model.generate(**processed_inputs)
 
-            captions = self.processor.batch_decode(out, skip_special_tokens=True)
+            captions = self.processor.batch_decode(
+                out, skip_special_tokens=True)
             return {"captions": captions}
         except Exception as e:
             print(f"Error during processing: {str(e)}")
-            return {"captions": [], "error": str(e)}
+            return {"captions": [], "error": str(e)}
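After this commit, clients pass image URLs rather than base64-encoded bytes, and the default text prompt changes from "a photography of" to an empty string. A minimal sketch of exercising the updated handler locally (the example URL is a placeholder; only the inputs, image_urls, and texts keys come from the diff above):

# Sketch: local smoke test of the updated EndpointHandler.
from handler import EndpointHandler

handler = EndpointHandler(path="")

payload = {
    "inputs": {
        # New schema: a list of image URLs instead of base64 image bytes.
        "image_urls": ["https://example.com/cat.jpg"],  # placeholder URL
        # Optional; when provided, must match image_urls in length.
        "texts": ["a photography of"],
    }
}

result = handler(payload)
print(result)  # {"captions": [...]} on success, or {"captions": [], "error": "..."}

Note that the per-URL download (requests.get(url).content) happens before the try block, so an unreachable URL raises rather than returning the error payload.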