import base64
from io import BytesIO
from typing import Dict, Any
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch

class EndpointHandler:
    def __init__(self, path="./"):
        # Load the processor and model, moving the model to GPU when one is available
        self.processor = BlipProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = BlipForConditionalGeneration.from_pretrained(path).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                Includes the input data and the parameters for the inference.
                Expected shape: {"inputs": {"image_base64": "<base64-encoded image>",
                "prompt": "<optional text prompt>"}}.
        Return:
            A :obj:`dict` like {"caption": "Generated caption for the image"} containing:
                - "caption": The generated caption as a string.
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {"mode": "image"})  # extracted for completeness; not used below

        # Get base64 image data and prompt from the inputs
        image_base64 = inputs.get("image_base64")
        prompt = inputs.get("prompt", "")  # Optional prompt for conditional captioning

        # Ensure base64-encoded image is provided
        if not image_base64:
            raise ValueError("No image data provided. Please provide 'image_base64'.")

        # Decode base64 string and convert to RGB image
        image_data = BytesIO(base64.b64decode(image_base64))
        image = Image.open(image_data).convert("RGB")

        # Process inputs with or without a prompt
        if prompt:
            processed_inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
        else:
            processed_inputs = self.processor(image, return_tensors="pt").to(self.model.device)

        # Generate caption
        out = self.model.generate(**processed_inputs)
        caption = self.processor.decode(out[0], skip_special_tokens=True)

        # Return the generated caption
        return {"caption": caption}