from typing import Dict, List, Any
from transformers import pipeline, CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import base64
import io
import numpy as np

class EndpointHandler():
    def __init__(self, path=""):
        # Preload everything needed at inference time: CLIPSeg for text-prompted
        # segmentation and Depth Anything V2 for depth estimation.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(self.device)
        self.depth_pipe = pipeline(
            "depth-estimation",
            model="depth-anything/Depth-Anything-V2-Small-hf",
            device=self.device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
       data args:
            inputs (:obj: `str` | `PIL.Image` | `np.array`)
            kwargs
      Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        if "inputs" not in data:
            return [{"error": "Missing 'inputs' key"}]

        inputs_data = data["inputs"]
        if "image" not in inputs_data or "text" not in inputs_data:
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        try:
            # Decode the base64 image and normalize prompts to a list
            image = self.decode_image(inputs_data["image"])
            prompts = inputs_data["text"]
            if isinstance(prompts, str):
                prompts = [prompts]
        
            # Preprocess input
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt"
            ).to(self.device)

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)
            
            # CLIPSeg returns one low-resolution mask per prompt; logits shape is
            # (num_prompts, 352, 352), or (352, 352) when only one prompt is given
            masks = outputs.logits.cpu().numpy()
            if masks.ndim == 2:
                masks = masks[np.newaxis, ...]

            results = []
            for mask in masks:
                # Normalize each mask to 0-255 so it can be encoded as a PNG
                mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
                mask = (mask * 255).astype(np.uint8)
                results.append({"seg_image": self.encode_image(Image.fromarray(mask))})

            return results
        
        except Exception as e:
            return [{"error": str(e)}]

    # helper functions
    def decode_image(self, image_data: str) -> Image.Image:
        """Decodes a base64-encoded image into a PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: Image.Image) -> str:
        """Encodes a PIL image to a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    
    def process_depth(self, image):
        """Runs monocular depth estimation and returns the depth map as a grayscale PIL image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])

        # Normalize to 0-255
        depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-6)
        depth_map = (depth_map * 255).astype(np.uint8)

        return Image.fromarray(depth_map)
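

# --- Local smoke test (not part of the Inference Endpoints contract) ---
# A minimal sketch of how the handler could be exercised locally, assuming a
# test image at "test.jpg"; the file path and prompts below are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Encode a local image to base64, mirroring what a client would send.
    with open("test.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {"inputs": {"image": image_b64, "text": ["a cat", "a dog"]}}
    response = handler(payload)

    # Each entry holds one base64-encoded PNG mask, in prompt order.
    for i, item in enumerate(response):
        if "seg_image" in item:
            Image.open(io.BytesIO(base64.b64decode(item["seg_image"]))).save(f"mask_{i}.png")
        else:
            print(item)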