import base64
import logging
from dataclasses import dataclass, asdict
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


# Dataclass holding the embeddings produced for a single image.
@dataclass
class ImageEncodingResult:
    image_encoded: List[List[float]]  # Full per-token embeddings
    image_encoded_average: List[float]  # Mean-pooled embedding


class EndpointHandler:
    """
    A handler class for processing images and generating embeddings using a
    pre-trained model.

    Attributes:
        processor: The pre-trained image processor.
        model: The pre-trained model for generating embeddings.
        device: The device (CPU or CUDA) used to run model inference.
    """

    def __init__(self, path: str = ""):
        """
        Initializes the EndpointHandler with the model and processor loaded
        from the given path.
        """
        # Initialize logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Determine the device (CPU or CUDA)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger.info(f"Using device: {self.device}")

        # Load the model and processor
        self.logger.info(f"Loading model and processor from: {path or 'current directory'}")
        try:
            self.processor = AutoImageProcessor.from_pretrained(path)
            self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(
                self.device
            )
            self.logger.info("Model and processor loaded successfully.")
        except Exception as e:
            self.logger.error(f"Failed to load model or processor: {e}")
            raise

    def _resize_image_if_large(
        self, image: Image.Image, max_size: int = 1080
    ) -> Image.Image:
        """
        Resizes an image if either of its dimensions exceeds the specified
        maximum size, preserving the aspect ratio.

        Args:
            image (Image.Image): Input image.
            max_size (int): Maximum size for the larger image dimension.

        Returns:
            Image.Image: Resized image.
        """
        width, height = image.size
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            image = image.resize((new_width, new_height), resample=Image.BILINEAR)
        return image

    def _encode_image(self, image: Image.Image) -> ImageEncodingResult:
        """
        Encodes an image into embeddings using the model.

        Args:
            image (Image.Image): Input image.

        Returns:
            ImageEncodingResult: Dataclass containing the encoded embeddings
            and their average.
        """
        try:
            # Resize the image if necessary
            image = self._resize_image_if_large(image)

            # Process the image and generate embeddings
            inputs = self.processor(image, return_tensors="pt").to(self.device)
            with torch.inference_mode():
                outputs = self.model(**inputs)
            last_hidden_state = outputs.last_hidden_state
            image_encoded = last_hidden_state.squeeze().tolist()
            image_encoded_average = last_hidden_state.mean(dim=1).squeeze().tolist()
            return ImageEncodingResult(
                image_encoded=image_encoded,
                image_encoded_average=image_encoded_average,
            )
        except Exception as e:
            self.logger.error(f"Error encoding image: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images and generates
        embeddings.

        Args:
            data (Dict[str, Any]): Dictionary containing input images.

        Returns:
            Dict[str, Any]: Dictionary containing encoded embeddings or error
            messages.
""" images_data = data.get("inputs", []) if not images_data: return {"error": "No image data provided."} results = [] for img_data in images_data: if isinstance(img_data, str): try: # Decode the base64-encoded image image_bytes = base64.b64decode(img_data) image = Image.open(BytesIO(image_bytes)).convert("RGB") # Encode the image encoded_image = self._encode_image(image) results.append(encoded_image) except Exception as e: self.logger.error(f"Invalid image data: {e}") return {"error": f"Invalid image data: {e}"} else: self.logger.error("Images should be base64-encoded strings.") return {"error": "Images should be base64-encoded strings."} # Convert the results to a dictionary for JSON serialization return {"results": [result.__dict__ for result in results]}