import logging
from io import BytesIO
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
# Configure logging once at module load.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the BLIP captioning processor and model once at startup.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        input_data = data.get("inputs", {})
        image_url = input_data.get("url")

        # Validate the input before making any network calls.
        if not image_url:
            logger.warning("No image URL provided in the request payload.")
            return {"captions": [], "error": "No images provided"}

        try:
            # Fetch the image with a bounded timeout and decode it
            # straight into RGB for the processor.
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            raw_images = [Image.open(BytesIO(response.content)).convert("RGB")]

            # Preprocess each image and batch the pixel values for the model.
            processed_inputs = [
                self.processor(image, return_tensors="pt") for image in raw_images
            ]
            pixel_values = torch.cat(
                [inp["pixel_values"] for inp in processed_inputs], dim=0
            ).to(device)

            with torch.no_grad():
                out = self.model.generate(pixel_values=pixel_values, max_new_tokens=40)

            captions = self.processor.batch_decode(out, skip_special_tokens=True)
            logger.info("Generated captions: %s", captions)
            return {"captions": captions}
        except Exception as e:
            logger.error("Error during processing: %s", e)
            return {"captions": [], "error": str(e)}