import base64
import os
import re
from copy import deepcopy
from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, GenerationConfig
from transformers.image_utils import load_image

# Compiled once at import time rather than on every is_url() call.
_URL_PATTERN = re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+")


def is_base64(s: str) -> bool:
    """Return True if *s* round-trips exactly through base64 decode/encode.

    Any exception raised while decoding (bad padding, non-ASCII input, ...)
    is treated as "not base64".
    """
    try:
        return base64.b64encode(base64.b64decode(s)).decode() == s
    except Exception:
        return False


def is_url(s: str) -> bool:
    """Return True if *s* begins with an ``http://`` or ``https://`` host-like prefix.

    Only the start of the string is checked (``re.match``); this is not a full
    URL validator.
    """
    return bool(_URL_PATTERN.match(s))


class EndpointHandler:
    """Inference handler that runs a vision-to-text model on prompt + image requests."""

    def __init__(
        self,
        model_dir: str = "HuggingFaceTB/SmolVLM-Instruct",
        **kwargs: Any,  # type: ignore
    ) -> None:
        """Load the processor, the model and its generation config from *model_dir*.

        Args:
            model_dir: Local path or Hub id passed to every ``from_pretrained`` call.
            **kwargs: Extra keyword arguments; accepted and ignored.
        """
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            _attn_implementation="eager",  # alternative: "flash_attention_2"
            device_map="auto",
        ).eval()
        # Base config; deep-copied per request so overrides never leak between calls.
        self.generation_config = GenerationConfig.from_pretrained(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run generation for every entry of ``data["inputs"]``, in order.

        Args:
            data: Request body. ``data["inputs"]`` must be a list of dicts, each
                with ``"text"`` (the prompt), ``"images"`` (a list of image URLs
                or base64 strings) and optionally ``"generation_parameters"``
                (a dict of generation-config overrides; defaults to
                ``{"max_new_tokens": 128}`` when the key is absent).

        Returns:
            ``{"predictions": [...]}`` with one decoded string per input.

        Raises:
            ValueError: If the request is malformed or an image cannot be loaded.
        """
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of inputs."
            )
        if not isinstance(data["inputs"], list):
            raise ValueError(
                "The request inputs must be a list of dictionaries with the keys 'text' and 'images', being a"
                " string with the prompt and a list with the image URLs or base64 encodings, respectively; and"
                " optionally including the key 'generation_parameters' key too."
            )

        predictions = [self._predict(item) for item in data["inputs"]]
        return {"predictions": predictions}

    def _predict(self, item: Any) -> str:
        """Validate a single request input, run generation on it and return the decoded text."""
        self._validate_input(item)
        images = [self._load_image(image) for image in item["images"]]

        generation_parameters = item.get(
            "generation_parameters", {"max_new_tokens": 128}
        )
        if not isinstance(generation_parameters, dict):
            raise ValueError(
                "The optional key 'generation_parameters' must be a dictionary of generation"
                " parameters (e.g. {'max_new_tokens': 128})."
            )
        generation_config = deepcopy(self.generation_config)
        generation_config.update(**generation_parameters)

        # One image placeholder per loaded image, followed by the text prompt.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"} for _ in images]
                + [{"type": "text", "text": item["text"]}],
            },
        ]
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        processed_inputs = self.processor(
            text=prompt, images=images, return_tensors="pt"
        ).to(self.model.device)

        generated_ids = self.model.generate(
            **processed_inputs, generation_config=generation_config
        )
        # NOTE(review): the full generated sequence is decoded; it presumably still
        # contains the prompt tokens, so the text may echo the chat prompt. Slice off
        # processed_inputs["input_ids"].shape[1] tokens if only the answer is wanted.
        generated_texts = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )
        return generated_texts[0]

    @staticmethod
    def _validate_input(item: Any) -> None:
        """Raise ValueError unless *item* is a dict with 'text' and a list of string 'images'."""
        if not isinstance(item, dict):
            raise ValueError(
                "Each element of the request inputs must be a dictionary with the keys"
                " 'text' and 'images'."
            )
        if "text" not in item:
            raise ValueError(
                "The request input body must contain the key 'text' with the prompt to use."
            )
        images: List[Any] = item.get("images")
        # Reject a missing/non-list value OR any non-string element (the previous
        # `and` let lists of non-strings through and crashed on non-iterables).
        if not isinstance(images, list) or not all(isinstance(i, str) for i in images):
            raise ValueError(
                "The request input body must contain the key 'images' with a list of strings,"
                " where each string corresponds to an image on either base64 encoding, or provided"
                " as a valid URL (needs to be publicly accessible and contain a valid image)."
            )

    @staticmethod
    def _load_image(image: str) -> Any:
        """Load one image given as a URL or base64 string, wrapping failures in ValueError."""
        # transformers' load_image also opens paths on the server's filesystem; only
        # URLs and base64 are part of this endpoint's contract, so refuse local files
        # rather than letting untrusted requests read them.
        if not is_url(image) and os.path.isfile(image):
            raise ValueError(
                f"Provided {image=} is not valid: local file paths are not accepted, please"
                " provide either a base64 encoding of a valid image or a publicly accessible URL."
            )
        try:
            return load_image(image)
        except Exception as e:
            raise ValueError(
                f"Provided {image=} is not valid, please make sure that's either a base64 encoding"
                f" of a valid image, or a publicly accesible URL to a valid image.\nFailed with {e=}."
            ) from e