# Meng Chen / update handler / 1f0d721
# (page-header residue kept as a comment; presumably author, commit message and commit hash)
from typing import Dict, List, Any
from transformers import pipeline,CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import base64
import io
import numpy as np
class EndpointHandler():
    """Callable handler that serves text-prompted image segmentation.

    A CLIPSeg processor/model pair and a depth-estimation pipeline are loaded
    once in ``__init__``. Calling the instance with a request payload returns
    base64-encoded PNG segmentation masks (see ``__call__``). Depth estimation
    is available through ``process_depth`` but is not invoked by ``__call__``.
    """

    def __init__(self, path=""):
        """Load every model needed at inference time.

        Args:
            path: Accepted for interface compatibility; not used here because
                all models are loaded from fixed Hugging Face Hub identifiers.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(self.device)
        # The pipeline is created without an explicit device, i.e. with the
        # library default placement.
        self.depth_pipe = pipeline(
            "depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment an image according to one or more text prompts.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64 image>, "text": <prompt(s)>}}``.
                ``text`` may be a single string or a non-empty list/tuple of
                strings.

        Returns:
            On success, a list with one ``{"seg_image": <base64 PNG>}`` entry
            per prompt, in prompt order. Each mask is the model's logit map
            min-max scaled to 8-bit grayscale at the resolution the model
            emits (it is not resized to the input image).
            On failure, ``[{"error": <message>}]``; exceptions raised during
            decoding or inference are reported this way rather than propagated.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            return [{"error": "Missing 'inputs' key"}]

        inputs_data = data["inputs"]
        # `in` on a str is a substring test, so a non-dict payload must be
        # rejected explicitly before the key checks.
        if (
            not isinstance(inputs_data, dict)
            or "image" not in inputs_data
            or "text" not in inputs_data
        ):
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        prompts = inputs_data["text"]
        if isinstance(prompts, str):
            # A bare string is one prompt; without this, len(prompts) below
            # would count characters and replicate the image once per letter.
            prompts = [prompts]
        if (
            not isinstance(prompts, (list, tuple))
            or not prompts
            or not all(isinstance(prompt, str) for prompt in prompts)
        ):
            return [{"error": "'text' must be a non-empty string or list of strings"}]
        prompts = list(prompts)

        try:
            image = self.decode_image(inputs_data["image"])

            # The same image is paired with every prompt. Tensors go to the
            # device the model was placed on in __init__ (not a hard-coded
            # "cuda", which fails on CPU-only hosts).
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits.cpu().numpy()
            # NOTE(review): for a single prompt the batch axis may already be
            # squeezed away by the model; restore it so iteration below always
            # yields one 2-D map per prompt -- confirm against the installed
            # transformers version.
            if logits.ndim == 2:
                logits = logits[np.newaxis, ...]

            results = []
            for prompt_logits in logits:
                mask = self._normalize_to_uint8(np.squeeze(prompt_logits))
                results.append({"seg_image": self.encode_image(Image.fromarray(mask))})
            return results
        except Exception as e:
            # Top-level request boundary: report the failure to the client.
            return [{"error": str(e)}]

    # helper functions
    @staticmethod
    def _normalize_to_uint8(values: np.ndarray) -> np.ndarray:
        """Min-max scale ``values`` to [0, 1), multiply by 255, cast to uint8.

        The 1e-6 term keeps the division defined for constant input (which
        then maps to all zeros).
        """
        lowest = values.min()
        span = values.max() - lowest + 1e-6
        return ((values - lowest) / span * 255).astype(np.uint8)

    def decode_image(self, image_data: str) -> "Image.Image":
        """Decode a base64-encoded image into an RGB PIL image.

        Args:
            image_data: Base64 text, optionally prefixed with a data-URL
                header such as ``data:image/png;base64,``.

        Returns:
            The decoded image converted to mode ``"RGB"``.
        """
        if isinstance(image_data, str) and image_data.startswith("data:"):
            # Drop the "data:<mediatype>;base64," header; only the part after
            # the first comma is payload.
            _header, separator, payload = image_data.partition(",")
            if separator:
                image_data = payload
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: "Image.Image") -> str:
        """Serialize a PIL image as PNG and return it as a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def process_depth(self, image):
        """Estimate depth for ``image`` and return it as an 8-bit PIL image.

        Args:
            image: A PIL image, or a NumPy array (converted to a uint8 PIL
                image before inference).

        Returns:
            A PIL image built from the pipeline's ``"depth"`` output after
            min-max scaling to uint8.
        """
        print("Processing depth")
        print(type(image))
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])
        # Normalize to 0-255
        return Image.fromarray(self._normalize_to_uint8(depth_map))