adasdimchom committed on
Commit
61b80cc
1 Parent(s): 3daa1cf

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -3
handler.py CHANGED
@@ -14,19 +14,47 @@ class EndpointHandler():
14
  self.processor = Blip2Processor.from_pretrained(path)
15
  self.generate_model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
16
  self.generate_model.to(self.device)
 
 
 
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  """
20
- data args:
21
  inputs (:obj: `str` | `PIL.Image` | `np.array`)
22
  kwargs
23
- Return:
24
  A :obj:`list` | `dict`: will be serialized and returned
25
  """
 
26
  inputs = data.pop("inputs", data)
27
  image_url = inputs['image_url']
 
 
 
 
 
 
 
 
 
28
  image = Image.open(requests.get(image_url, stream=True).raw)
29
  processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
30
  generated_ids = self.generate_model.generate(**processed_image)
31
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
32
- return image_url, generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  self.processor = Blip2Processor.from_pretrained(path)
15
  self.generate_model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
16
  self.generate_model.to(self.device)
17
+
18
+ self.feature_model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
19
+ self.feature_model.to(self.device)
20
 
21
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
22
  """
23
+ data args:
24
  inputs (:obj: `str` | `PIL.Image` | `np.array`)
25
  kwargs
26
+ Return:
27
  A :obj:`list` | `dict`: will be serialized and returned
28
  """
29
+ result = {}
30
  inputs = data.pop("inputs", data)
31
  image_url = inputs['image_url']
32
+ if "prompt" in inputs:
33
+ prompt = inputs["prompt"]
34
+ else:
35
+ prompt = None
36
+ if "extract_feature" in inputs:
37
+ extract_feature = inputs["extract_feature"]
38
+ else:
39
+ extract_feature = False
40
+
41
  image = Image.open(requests.get(image_url, stream=True).raw)
42
  processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
43
  generated_ids = self.generate_model.generate(**processed_image)
44
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
45
+ result["image_caption"] = generated_text
46
+ if extract_feature:
47
+ caption_feature = self.feature_model(**processed_image)
48
+ result["caption_feature"] = caption_feature
49
+
50
+ if prompt:
51
+ prompt_image_processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
52
+ generated_ids = self.generate_model.generate(**prompt_image_processed)
53
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
54
+ result["image_prompt"] = generated_text
55
+ pass
56
+ if extract_feature:
57
+ prompt_feature = self.feature_model(**prompt_image_processed)
58
+ result["prompt_feature"] = prompt_feature
59
+
60
+ return result