ckandemir committed
Commit 895781a · 1 Parent(s): f287aea

Create handler.py

Files changed (1)
  1. handler.py +34 -0
handler.py ADDED
@@ -0,0 +1,34 @@
+ import requests
+ import torch
+ from PIL import Image
+ from typing import Any, Dict
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+
+ # Use the GPU when available, otherwise fall back to the CPU.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class EndpointHandler:
+     def __init__(self, model_dir: str = "Salesforce/blip-image-captioning-large"):
+         # Load the BLIP captioning model and its processor once at startup.
+         self.model = BlipForConditionalGeneration.from_pretrained(model_dir).to(device).eval()
+         self.processor = BlipProcessor.from_pretrained(model_dir)
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         img_url = data.get("img_url")
+         text_prompt = data.get("text")
+
+         # Fetch the image over HTTP and normalize it to RGB for the processor.
+         raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+
+         # An optional text prompt conditions the generated caption.
+         if text_prompt:
+             inputs = self.processor(raw_image, text_prompt, return_tensors="pt").to(device)
+         else:
+             inputs = self.processor(raw_image, return_tensors="pt").to(device)
+
+         # Generate up to 150 new tokens with gradient tracking disabled.
+         with torch.no_grad():
+             generated_ids = self.model.generate(**inputs, max_new_tokens=150)
+         generated_text = self.processor.decode(generated_ids[0], skip_special_tokens=True)
+
+         return {"responses": generated_text}
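
For a quick smoke test, the handler can be invoked directly once this file is in place. A minimal sketch, assuming the file is importable as handler; the image URL below is a placeholder, not part of the repository:

# Minimal local smoke test for EndpointHandler (sketch; the image URL
# is a placeholder and must point to any publicly reachable image).
from handler import EndpointHandler

handler = EndpointHandler()

# Unconditional captioning: only an image URL is supplied.
print(handler({"img_url": "https://example.com/photo.jpg"}))

# Conditional captioning: the caption continues the supplied prompt.
print(handler({"img_url": "https://example.com/photo.jpg", "text": "a photography of"}))

Both calls return a dict of the form {"responses": "<caption>"}, matching the handler's return value.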