adasdimchom commited on
Commit
3daa1cf
·
1 Parent(s): b2d86ed

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import Blip2Processor, Blip2Model
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  from transformers import pipeline
@@ -12,8 +12,8 @@ class EndpointHandler():
12
  """
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  self.processor = Blip2Processor.from_pretrained(path)
15
- self.model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
16
- self.model.to(self.device)
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  """
@@ -27,6 +27,6 @@ class EndpointHandler():
27
  image_url = inputs['image_url']
28
  image = Image.open(requests.get(image_url, stream=True).raw)
29
  processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
30
- generated_ids = self.model.generate(**processed_image)
31
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
32
  return image_url, generated_text
 
1
+ from transformers import Blip2Processor, Blip2Model, Blip2ForConditionalGeneration
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  from transformers import pipeline
 
12
  """
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  self.processor = Blip2Processor.from_pretrained(path)
15
+ self.generate_model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
16
+ self.generate_model.to(self.device)
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  """
 
27
  image_url = inputs['image_url']
28
  image = Image.open(requests.get(image_url, stream=True).raw)
29
  processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
30
+ generated_ids = self.generate_model.generate(**processed_image)
31
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
32
  return image_url, generated_text