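"""Custom handler for a Hugging Face Inference Endpoint.

Loads a multimodal chat model with ``trust_remote_code`` (the ``model.chat``
call below follows the MiniCPM-V-style API) and answers requests of the form
``{"inputs": {"image": <base64-encoded image>, "text": <prompt>}}``.
"""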
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import AutoModel, AutoTokenizer

class EndpointHandler:
    def __init__(self, path="/repository"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load the model with its custom modeling code; bfloat16 halves GPU
        # memory, while CPU falls back to float32 for broad compatibility.
        self.model = AutoModel.from_pretrained(
            path,
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)
        self.model.eval()
        
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )

    def __call__(self, data):
        # Expected payload shape: {"inputs": {"image": <base64 string>, "text": <prompt>}}
        image_data = data.get("inputs", {}).get("image", "")
        text_prompt = data.get("inputs", {}).get("text", "")

        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}

        # Decode the base64 payload into a PIL image; convert to RGB since the
        # vision encoder expects three channels.
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}

        # Single user turn containing both the image and the prompt
        # (the MiniCPM-V-style chat API accepts mixed content lists).
        msgs = [{'role': 'user', 'content': [image, text_prompt]}]

        # Run generation without tracking gradients. Note that max_length caps
        # prompt plus generated tokens combined; if the model's chat API
        # exposes it, max_new_tokens is the less surprising limit.
        with torch.no_grad():
            res = self.model.chat(
                image=None,  # the image travels inside msgs instead
                msgs=msgs,
                tokenizer=self.tokenizer,
                sampling=True,  # nucleus sampling rather than greedy decoding
                temperature=0.7,
                top_p=0.95,
                max_length=2000,
            )
        
        # model.chat returns the generated text directly
        return {"generated_text": res}