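"""Custom handler for Hugging Face Inference Endpoints serving Qwen2-VL.

Accepts a JSON payload whose "inputs" carry an image (HTTP URL or
base64-encoded bytes) plus an optional text prompt, and returns the model's
output as {"generated_text": ...}.
"""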
from typing import Dict, Any
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import io
import base64
import requests
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EndpointHandler:
    def __init__(self, path: str = ""):
        # `path` points at the model repository checked out by the endpoint runtime.
        self.processor = AutoProcessor.from_pretrained(path)
        # device_map="auto" already dispatches the weights onto the available
        # device(s); a follow-up .to(device) on a dispatched model is redundant
        # and raises an error in recent transformers releases, so it is omitted.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path, device_map="auto"
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # The endpoint may wrap the payload in {"inputs": {...}} or send it bare;
        # a distinct name avoids shadowing the processor output built below.
        payload = data.pop("inputs", data)
        image_input = payload.get("image")
        text_input = payload.get("text", "Describe this image.")

        if not image_input:
            return {"error": "No image provided."}

        try:
            if image_input.startswith("http"):
                # Remote image: fetch with a timeout so a slow host cannot hang
                # the endpoint. Decoding response.content with PIL avoids the
                # content-encoding pitfalls of reading response.raw directly.
                response = requests.get(image_input, timeout=10)
                if response.status_code == 200:
                    image = Image.open(io.BytesIO(response.content)).convert("RGB")
                else:
                    return {"error": f"Failed to fetch image. Status code: {response.status_code}"}
            else:
                # Otherwise treat the input as a base64-encoded image.
                image_data = base64.b64decode(image_input)
                image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process the image. Details: {str(e)}"}

        try:
            # Build a single-turn chat: one image placeholder followed by the
            # user's text prompt, in the format Qwen2-VL's chat template expects.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": text_input},
                    ],
                }
            ]

            # Render the conversation into the model's prompt format, opening
            # the assistant turn so generation continues from there.
            text_prompt = self.processor.apply_chat_template(
                conversation, add_generation_prompt=True
            )

            # Tokenize the prompt and preprocess the image in one call, then
            # move the tensors to the same device as the model.
            inputs = self.processor(
                text=[text_prompt],
                images=[image],
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)

            output_ids = self.model.generate(
                **inputs, max_new_tokens=128
            )

            # generate() returns the prompt plus the completion; strip the
            # prompt tokens so only the newly generated text is decoded.
            generated_ids = [
                output_id[len(input_id):]
                for input_id, output_id in zip(inputs.input_ids, output_ids)
            ]

            output_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )[0]

            return {"generated_text": output_text}

        except Exception as e:
            return {"error": f"Failed during generation. Details: {str(e)}"}