import base64
import logging
from dataclasses import asdict, dataclass
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


# Dataclass holding the embeddings produced for a single image
@dataclass
class ImageEncodingResult:
    image_encoded: List[List[float]]  # Per-token embeddings from the last hidden state
    image_encoded_average: List[float]  # Mean of the embeddings over the token dimension


class EndpointHandler:
"""
A handler class for processing images and generating embeddings using a pre-trained model.
Attributes:
processor: The pre-trained image processor.
model: The pre-trained model for generating embeddings.
device: The device (CPU or CUDA) used to run model inference.
"""
def __init__(self, path: str = ""):
"""
Initializes the EndpointHandler with the model and processor from the current directory.
"""
# Initialize logging
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
# Determine the device (CPU or CUDA)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.logger.info(f"Using device: {self.device}")
        # Load the model and processor from the given path
        self.logger.info(f"Loading model and processor from '{path}'.")
try:
self.processor = AutoImageProcessor.from_pretrained(path)
self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(
self.device
)
self.logger.info("Model and processor loaded successfully.")
except Exception as e:
self.logger.error(f"Failed to load model or processor: {e}")
            raise

def _resize_image_if_large(
self, image: Image.Image, max_size: int = 1080
) -> Image.Image:
"""
Resizes an image if its dimensions exceed the specified maximum size.
Args:
image (Image.Image): Input image.
max_size (int): Maximum size for the image dimensions.
Returns:
Image.Image: Resized image.
"""
width, height = image.size
if width > max_size or height > max_size:
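            # Scale factor that maps the longer side to exactly max_size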
scale = max_size / max(width, height)
new_width = int(width * scale)
new_height = int(height * scale)
image = image.resize((new_width, new_height), resample=Image.BILINEAR)
        return image

def _encode_image(self, image: Image.Image) -> ImageEncodingResult:
"""
Encodes an image into embeddings using the model.
Args:
image (Image.Image): Input image.
Returns:
ImageEncodingResult: Dataclass containing the encoded embeddings and their average.
"""
try:
# Resize the image if necessary
image = self._resize_image_if_large(image)
# Process the image and generate embeddings
inputs = self.processor(image, return_tensors="pt").to(self.device)
with torch.inference_mode():
outputs = self.model(**inputs)
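            # last_hidden_state typically has shape (1, num_tokens, hidden_size) for
            # ViT-style backbones; squeeze() drops the batch dimension and mean(dim=1)
            # pools over the token dimension.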
last_hidden_state = outputs.last_hidden_state
image_encoded = last_hidden_state.squeeze().tolist()
image_encoded_average = last_hidden_state.mean(dim=1).squeeze().tolist()
return ImageEncodingResult(
image_encoded=image_encoded,
image_encoded_average=image_encoded_average,
)
except Exception as e:
self.logger.error(f"Error encoding image: {e}")
            raise

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Processes input data containing base64-encoded images and generates embeddings.
Args:
data (Dict[str, Any]): Dictionary containing input images.
Returns:
Dict[str, Any]: Dictionary containing encoded embeddings or error messages.
"""
images_data = data.get("inputs", [])
if not images_data:
return {"error": "No image data provided."}
results = []
for img_data in images_data:
if isinstance(img_data, str):
try:
# Decode the base64-encoded image
image_bytes = base64.b64decode(img_data)
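                    # convert("RGB") normalizes grayscale, palette, and alpha images
                    # to three channels before preprocessing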
image = Image.open(BytesIO(image_bytes)).convert("RGB")
# Encode the image
encoded_image = self._encode_image(image)
results.append(encoded_image)
except Exception as e:
self.logger.error(f"Invalid image data: {e}")
return {"error": f"Invalid image data: {e}"}
else:
self.logger.error("Images should be base64-encoded strings.")
return {"error": "Images should be base64-encoded strings."}
        # Convert the dataclasses to plain dictionaries for JSON serialization
        return {"results": [asdict(result) for result in results]}