|
import torch |
|
from transformers import AutoProcessor, AutoModelForVision2Seq, GenerationConfig |
|
from transformers.image_utils import load_image |
|
|
|
from typing import Any, Dict |
|
|
|
import base64 |
|
import re |
|
from copy import deepcopy |
|
|
|
|
|
def is_base64(s: str) -> bool:
    """Report whether *s* survives a base64 decode/encode round trip unchanged.

    Any exception raised while decoding or re-encoding (bad padding, a
    non-string argument, ...) is treated as "not base64".
    """
    try:
        raw = base64.b64decode(s)
        reencoded = base64.b64encode(raw).decode()
    except Exception:
        return False
    return reencoded == s
|
|
|
|
|
def is_url(s: str) -> bool:
    """Report whether *s* begins with an ``http://`` or ``https://`` URL prefix.

    Only the start of the string is checked (``re.match`` semantics): after
    the scheme there must be at least one host-like character (word
    characters, ``-``, ``.``) or a ``%XX`` escape. Whatever follows that
    prefix is not validated, and the scheme match is case-sensitive.
    """
    return re.match(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+", s) is not None
|
|
|
|
|
class EndpointHandler:
    """Inference handler that answers text prompts about one or more images.

    Each request item carries a text prompt and a list of images (URLs or
    base64 strings). The handler renders them through the processor's chat
    template, runs generation, and returns only the newly generated text.
    """

    # Used when the request does not set "max_new_tokens" itself.
    _DEFAULT_MAX_NEW_TOKENS = 128

    def __init__(
        self,
        model_dir: str = "HuggingFaceTB/SmolVLM-Instruct",
        **kwargs: Any,
    ) -> None:
        """Load the processor, model and base generation config.

        Args:
            model_dir: Local path or Hub repo id to load everything from.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        self.processor = AutoProcessor.from_pretrained(model_dir)
        # NOTE(review): flash_attention_2 relies on the optional flash-attn
        # package and a compatible GPU; consider making this configurable.
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            _attn_implementation="flash_attention_2",
            device_map="auto",
        ).eval()
        # Base config; each request applies its overrides to a deep copy.
        self.generation_config = GenerationConfig.from_pretrained(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one prediction per item in ``data["inputs"]``.

        Args:
            data: Request body. ``data["inputs"]`` must be a list of dicts,
                each with a ``"text"`` prompt, an ``"images"`` list of image
                URLs / base64 strings, and optionally a
                ``"generation_parameters"`` dict of generation overrides.

        Returns:
            ``{"predictions": [...]}`` with one generated string per input,
            in request order.

        Raises:
            ValueError: If the request is malformed or an image cannot be
                loaded.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of inputs."
            )

        inputs = data["inputs"]
        if not isinstance(inputs, list):
            raise ValueError(
                "The request inputs must be a list of dictionaries with the keys 'text' and 'images', being a"
                " string with the prompt and a list with the image URLs or base64 encodings, respectively; and"
                " optionally including the key 'generation_parameters' too."
            )

        # Validate every item up front so a malformed later item fails the
        # request before any expensive generation has run.
        for item in inputs:
            self._validate_input(item)

        return {"predictions": [self._predict(item) for item in inputs]}

    @staticmethod
    def _validate_input(item: Any) -> None:
        """Raise ValueError unless *item* is a well-formed request input dict."""
        if not isinstance(item, dict):
            raise ValueError(
                "Each element of the request inputs must be a dictionary with the keys 'text' and 'images'."
            )
        if "text" not in item:
            raise ValueError(
                "The request input body must contain the key 'text' with the prompt to use."
            )

        images = item.get("images")
        # Accept only a list whose elements are all strings.
        if not isinstance(images, list) or not all(isinstance(i, str) for i in images):
            raise ValueError(
                "The request input body must contain the key 'images' with a list of strings,"
                " where each string corresponds to an image in either base64 encoding, or provided"
                " as a valid URL (needs to be publicly accessible and contain a valid image)."
            )

        params = item.get("generation_parameters")
        if params is not None and not isinstance(params, dict):
            raise ValueError(
                "The optional key 'generation_parameters' must be a dictionary of generation arguments."
            )

    @staticmethod
    def _load_images(image_refs: Any) -> Any:
        """Load each image reference, wrapping any failure in a ValueError."""
        images = []
        for image in image_refs:
            try:
                images.append(load_image(image))
            except Exception as e:
                raise ValueError(
                    f"Provided {image=} is not valid, please make sure that's either a base64 encoding"
                    f" of a valid image, or a publicly accessible URL to a valid image.\nFailed with {e=}."
                ) from e
        return images

    def _predict(self, item: Dict[str, Any]) -> str:
        """Generate the model's reply for one already-validated input item."""
        images = self._load_images(item["images"])

        generation_config = deepcopy(self.generation_config)
        # Merge so the default token budget still applies when the caller
        # sends other parameters without "max_new_tokens".
        overrides = {
            "max_new_tokens": self._DEFAULT_MAX_NEW_TOKENS,
            **(item.get("generation_parameters") or {}),
        }
        generation_config.update(**overrides)

        # One image placeholder per supplied image, followed by the prompt.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"} for _ in images]
                + [{"type": "text", "text": item["text"]}],
            },
        ]
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        # Pass None (the processor's default) rather than an empty list when
        # the request contains no images.
        processed_inputs = self.processor(
            text=prompt, images=images or None, return_tensors="pt"
        ).to(self.model.device)

        generated_ids = self.model.generate(
            **processed_inputs, generation_config=generation_config
        )

        # generate() echoes the prompt token ids before the new tokens, so
        # drop them by token count. Slicing the decoded string by len(prompt)
        # is wrong because skip_special_tokens removes the prompt's special
        # and image tokens, shifting every character offset.
        prompt_length = processed_inputs["input_ids"].shape[-1]
        return self.processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=True,
        )[0]