Vidensogende committed on
Commit
692e318
·
1 Parent(s): 4134b65

updated handler to accept and process multiple images

Browse files
Files changed (1) hide show
  1. handler.py +36 -8
handler.py CHANGED
@@ -1,22 +1,50 @@
1
  import requests
2
  from PIL import Image
3
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
4
- from typing import Dict, List, Any
5
  import torch
 
6
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
- self.processor = Blip2Processor.from_pretrained("Salesforce/blip-image-captioning-large")
10
- self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
11
 
12
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
  self.model.to(self.device)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
- image = data.pop("inputs", data)
 
 
 
 
 
 
 
17
 
18
- processed = self.processor(images=image, return_tensors="pt").to(self.device)
 
 
 
19
 
20
- out = self.model.generate(**processed)
 
 
 
21
 
22
- return self.processor.decode(out[0], skip_special_tokens=True)
 
 
 
1
  import requests
2
  from PIL import Image
3
+ from transformers import BlipProcessor, BlipForConditionalGeneration
 
4
  import torch
5
+ from typing import Dict, List, Any
6
 
7
class EndpointHandler():
    """Inference endpoint handler that captions one or more images with BLIP.

    Expects a payload of the form::

        {"image_urls": "<url>" | ["<url>", ...],
         "texts": "<prompt>" | ["<prompt>", ...]}   # optional

    and returns a list of ``{"image_url": ..., "caption": ...}`` dicts
    (or a single-element ``[{"error": ...}]`` list on failure).
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is ignored; weights are always pulled from the
        # Hub by model id — confirm this is intended for the deployment setup.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def process_single_image(self, img_url, text=None):
        """Download one image and return its generated caption.

        Args:
            img_url: HTTP(S) URL of the image to caption.
            text: Optional prompt for conditional captioning; ``None`` (or any
                falsy value) selects unconditional captioning.

        Returns:
            The decoded caption string.
        """
        # Timeout prevents a stalled download from hanging the endpoint.
        response = requests.get(img_url, stream=True, timeout=30)
        raw_image = Image.open(response.raw).convert('RGB')
        if text:
            # Conditional image captioning (text prompt steers the caption).
            inputs = self.processor(raw_image, text, return_tensors="pt").to(self.device)
        else:
            # Unconditional image captioning.
            inputs = self.processor(raw_image, return_tensors="pt").to(self.device)

        out = self.model.generate(**inputs)
        return self.processor.decode(out[0], skip_special_tokens=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Caption every image in ``data`` and return one result dict each.

        Args:
            data: Request payload; must contain ``"image_urls"`` (str or
                list of str) and may contain ``"texts"`` (str or list of str).

        Returns:
            A list of ``{"image_url", "caption"}`` dicts, or
            ``[{"error": <message>}]`` if the request could not be processed.
        """
        try:
            img_urls = data.get("image_urls")
            if img_urls is None:
                # Explicit message instead of the opaque TypeError from
                # len(None) the original produced.
                return [{"error": "missing required key 'image_urls'"}]

            # BUGFIX: normalize a bare string to a one-element list BEFORE
            # building the default texts. The original computed
            # [None] * len(img_urls) on the raw string (one None per
            # character) and then wrapped that whole list as the single
            # image's conditioning text.
            if isinstance(img_urls, str):
                img_urls = [img_urls]

            texts = data.get("texts")  # Texts are optional for conditional captioning
            if texts is None:
                texts = [None] * len(img_urls)
            elif isinstance(texts, str):
                texts = [texts]
            # Pad missing prompts with None so a short `texts` list no longer
            # silently truncates the image list via zip().
            if len(texts) < len(img_urls):
                texts = texts + [None] * (len(img_urls) - len(texts))

            captions = []
            for img_url, text in zip(img_urls, texts):
                caption = self.process_single_image(img_url, text)
                captions.append({"image_url": img_url, "caption": caption})

            return captions
        except Exception as e:
            # Boundary handler: report the failure to the caller instead of
            # crashing the endpoint.
            print(f"Error processing data: {e}")
            return [{"error": str(e)}]
47
 
48
def get_pipeline(model_dir, task):
    """Factory hook for the inference toolkit.

    Builds and returns an :class:`EndpointHandler` for ``model_dir``;
    the ``task`` argument is accepted for interface compatibility but unused.
    """
    handler = EndpointHandler(model_dir)
    return handler