File size: 3,982 Bytes
ff47bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89baa6d
ff47bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64f317d
ff47bc8
 
 
 
 
 
 
 
 
 
 
06eb103
 
 
ff47bc8
06eb103
 
 
 
 
 
 
d1cf6f0
ff47bc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, GenerationConfig
from transformers.image_utils import load_image

from typing import Any, Dict

import base64
import re
from copy import deepcopy


def is_base64(s: str) -> bool:
    """Return True when ``s`` is canonical base64 text.

    The string is decoded and re-encoded; it counts as base64 only if the
    round trip reproduces ``s`` exactly. Any decoding failure (bad padding,
    non-ASCII characters, a non-string argument, ...) yields False.
    """
    try:
        raw_bytes = base64.b64decode(s)
        round_trip = base64.b64encode(raw_bytes).decode()
    except Exception:
        return False
    return round_trip == s


def is_url(s: str) -> bool:
    """Return True when ``s`` begins with an ``http://`` or ``https://`` URL.

    Only the start of the string is inspected: the scheme must be followed by
    at least one host character (word chars, ``-``, ``.``) or percent-escape.
    Anything after that prefix is accepted without further checks.
    """
    found = re.match(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+", s)
    return found is not None


class EndpointHandler:
    """Inference handler that answers text prompts about one or more images.

    The processor, model and base generation config are loaded once in
    ``__init__``; each call then runs ``generate`` separately for every entry
    of ``data["inputs"]`` and returns the decoded texts.
    """

    # Output-length cap applied when a request does not set ``max_new_tokens``
    # or ``max_length`` itself.
    _DEFAULT_MAX_NEW_TOKENS = 128

    def __init__(
        self,
        model_dir: str = "HuggingFaceTB/SmolVLM-Instruct",
        **kwargs: Any,  # type: ignore
    ) -> None:
        """Load the processor, model and generation config from ``model_dir``.

        Args:
            model_dir: Local path or Hub id passed to ``from_pretrained``.
            **kwargs: Ignored.
        """
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            # Eager attention needs no extra packages; "flash_attention_2" is an
            # alternative that requires the optional flash-attn installation.
            _attn_implementation="eager",
            device_map="auto",
        ).eval()
        self.generation_config = GenerationConfig.from_pretrained(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate one prediction per request input.

        Args:
            data: Request body. ``data["inputs"]`` must be a list of dicts, each
                with ``"text"`` (the prompt), ``"images"`` (a list of strings,
                each an image URL or base64 encoding) and optionally
                ``"generation_parameters"`` (a dict of generation settings).

        Returns:
            ``{"predictions": [...]}`` with one decoded string per input, in the
            same order as ``data["inputs"]``.

        Raises:
            ValueError: If the request is malformed or an image cannot be loaded.
        """
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of inputs."
            )

        if not isinstance(data["inputs"], list):
            raise ValueError(
                "The request inputs must be a list of dictionaries with the keys 'text' and 'images', being a"
                " string with the prompt and a list with the image URLs or base64 encodings, respectively; and"
                " optionally including the key 'generation_parameters' key too."
            )

        predictions = []
        for item in data["inputs"]:
            self._validate_input(item)
            images = self._load_images(item["images"])
            generation_config = self._build_generation_config(
                item.get("generation_parameters")
            )

            # One image placeholder per loaded image, followed by the prompt text.
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "image"} for _ in images]
                    + [{"type": "text", "text": item["text"]}],
                },
            ]
            prompt = self.processor.apply_chat_template(
                messages, add_generation_prompt=True
            )
            processed_inputs = self.processor(
                text=prompt, images=images, return_tensors="pt"
            ).to(self.model.device)

            generated_ids = self.model.generate(
                **processed_inputs, generation_config=generation_config
            )
            # NOTE(review): ``generated_ids`` is decoded in full, so for a
            # decoder-only model the prediction presumably also contains the
            # rendered prompt — confirm whether callers expect that.
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            predictions.append(generated_texts[0])

        return {"predictions": predictions}

    @staticmethod
    def _validate_input(item: Any) -> None:
        """Raise ``ValueError`` unless ``item`` is a dict with ``text`` and a list of string ``images``."""
        if not isinstance(item, dict) or "text" not in item:
            raise ValueError(
                "The request input body must contain the key 'text' with the prompt to use."
            )

        images = item.get("images")
        # Both conditions must hold: it is a list AND every element is a string.
        if not isinstance(images, list) or not all(isinstance(i, str) for i in images):
            raise ValueError(
                "The request input body must contain the key 'images' with a list of strings,"
                " where each string corresponds to an image on either base64 encoding, or provided"
                " as a valid URL (needs to be publicly accessible and contain a valid image)."
            )

    @staticmethod
    def _load_images(sources: list) -> list:
        """Load every image reference with ``load_image``, wrapping failures in ``ValueError``."""
        images = []
        for image in sources:
            try:
                images.append(load_image(image))
            except Exception as e:
                raise ValueError(
                    f"Provided {image=} is not valid, please make sure that's either a base64 encoding"
                    f" of a valid image, or a publicly accesible URL to a valid image.\nFailed with {e=}."
                ) from e
        return images

    def _build_generation_config(self, parameters: Any) -> Any:
        """Return a copy of the base generation config updated with ``parameters``.

        ``None`` is treated as no overrides. ``max_new_tokens`` defaults to
        ``_DEFAULT_MAX_NEW_TOKENS`` unless the caller sets ``max_new_tokens``
        or ``max_length`` explicitly.

        Raises:
            ValueError: If ``parameters`` is neither ``None`` nor a dict.
        """
        if parameters is None:
            parameters = {}
        if not isinstance(parameters, dict):
            raise ValueError(
                "The optional key 'generation_parameters' must be a dictionary of generation settings."
            )

        # Copy so the caller's request dict is never mutated.
        overrides = dict(parameters)
        if "max_new_tokens" not in overrides and "max_length" not in overrides:
            overrides["max_new_tokens"] = EndpointHandler._DEFAULT_MAX_NEW_TOKENS

        # Deep-copy so per-request overrides never leak into the shared base config.
        generation_config = deepcopy(self.generation_config)
        generation_config.update(**overrides)
        return generation_config