|
import base64
import io
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor, pipeline

|
class EndpointHandler:
|
    def __init__(self, path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(self.device)
        # Run the depth pipeline on the same device as the segmentation model.
        self.depth_pipe = pipeline(
            "depth-estimation",
            model="depth-anything/Depth-Anything-V2-Small-hf",
            device=self.device,
        )
|
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: dict with an "inputs" key containing:
                image (str): base64-encoded image.
                text (str | List[str]): one or more text prompts to segment.
        Return:
            A list of dicts, one per prompt, each carrying a base64-encoded
            PNG segmentation mask under "seg_image"; the list is serialized
            and returned. A usage sketch appears at the bottom of this file.
        """
|
if "inputs" not in data: |
|
return [{"error": "Missing 'inputs' key"}] |
|
|
|
inputs_data = data["inputs"] |
|
if "image" not in inputs_data or "text" not in inputs_data: |
|
return [{"error": "Missing 'image' or 'text' key in input data"}] |
|
|
|
        try:
            image = self.decode_image(inputs_data["image"])
            prompts = inputs_data["text"]
            # Accept a single prompt string as well as a list of prompts.
            if isinstance(prompts, str):
                prompts = [prompts]

            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Logits come back as (num_prompts, H, W); a single prompt can
            # arrive squeezed to (H, W), so restore the leading axis.
            masks = outputs.logits.cpu().numpy()
            if masks.ndim == 2:
                masks = masks[np.newaxis, ...]

            results = []
            for mask in masks:
                # Min-max normalize each mask to [0, 255] for PNG export.
                mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
                seg_image = Image.fromarray((mask * 255).astype(np.uint8))
                results.append({"seg_image": self.encode_image(seg_image)})
            return results

        except Exception as e:
            return [{"error": str(e)}]
|
    def decode_image(self, image_data: str) -> Image.Image:
        """Decodes a base64-encoded image into an RGB PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
    def encode_image(self, image: Image.Image) -> str:
        """Encodes a PIL image as a base64 PNG string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
    def process_depth(self, image):
        """Runs depth estimation and returns the depth map as a grayscale PIL image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])
        # Min-max normalize the depth map to [0, 255] for image export.
        depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-6)
        depth_map = (depth_map * 255).astype(np.uint8)
        return Image.fromarray(depth_map)
|
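
# --- Usage sketch (illustrative addition, not part of the original handler) ---
# A minimal local smoke test, assuming an example image "test.jpg" sits next to
# this file; the filename and prompts are placeholders, not fixed API values.
# It base64-encodes the image, builds the payload shape that __call__ expects,
# and reports what each returned entry contains.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("test.jpg", "rb") as f:  # hypothetical test image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": {"image": image_b64, "text": ["a cat", "a bottle"]}}
    for item in handler(payload):
        print(list(item.keys()))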