File size: 3,483 Bytes
11b246c d48d98c 11b246c d48d98c 11b246c d48d98c 11b246c d48d98c 11b246c 1f0d721 11b246c 1f0d721 11b246c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from typing import Dict, List, Any
from transformers import pipeline,CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import base64
import io
import numpy as np
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for text-prompted image segmentation.

    Segmentation is done with CLIPSeg. A Depth-Anything-V2 depth-estimation
    pipeline is also loaded and exposed through :meth:`process_depth`; it is
    not used by :meth:`__call__`.
    """

    SEG_MODEL_ID = "CIDAS/clipseg-rd64-refined"
    DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"

    def __init__(self, path=""):
        """Load processor and models once so every request can reuse them.

        Args:
            path: Repository path passed in by the endpoint runtime. Unused;
                models are loaded by hub id instead.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained(self.SEG_MODEL_ID)
        self.model = CLIPSegForImageSegmentation.from_pretrained(self.SEG_MODEL_ID).to(self.device)
        # Place the depth model on the same device as the segmentation model.
        self.depth_pipe = pipeline("depth-estimation", model=self.DEPTH_MODEL_ID, device=self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the regions of an image that match one or more text prompts.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64 str>, "text": <str | list of str>}}``.

        Returns:
            One ``{"seg_image": <base64 PNG>}`` dict per prompt, in prompt
            order, or a single-element ``[{"error": <message>}]`` list when
            the request is malformed or inference fails.
        """
        if "inputs" not in data:
            return [{"error": "Missing 'inputs' key"}]
        inputs_data = data["inputs"]
        # Without this check a string payload would pass the key test below
        # via substring matching.
        if not isinstance(inputs_data, dict):
            return [{"error": "'inputs' must be an object with 'image' and 'text' keys"}]
        if "image" not in inputs_data or "text" not in inputs_data:
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        prompts = inputs_data["text"]
        # A bare string would otherwise be batched character by character.
        if isinstance(prompts, str):
            prompts = [prompts]
        if not prompts:
            return [{"error": "'text' must contain at least one prompt"}]

        try:
            image = self.decode_image(inputs_data["image"])
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),  # the same image paired with each prompt
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)  # must match the device the model was moved to
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits.cpu().numpy()
            # NOTE(review): logits may come back as (H, W) for a single prompt
            # or (N, H, W); reshape so there is always one 2-D mask per prompt.
            masks = logits.reshape(-1, *logits.shape[-2:])
            return [
                {"seg_image": self.encode_image(Image.fromarray(self._to_uint8(mask)))}
                for mask in masks
            ]
        except Exception as e:
            # Top-level request boundary: report the failure to the client
            # instead of crashing the endpoint.
            return [{"error": str(e)}]

    # helper functions

    @staticmethod
    def _to_uint8(arr: np.ndarray) -> np.ndarray:
        """Min-max scale ``arr`` onto the full 0-255 range as ``uint8``.

        The minimum maps to 0 and the maximum to 255. A constant array, or
        one whose range is NaN, maps to all zeros.
        """
        arr = np.asarray(arr, dtype=np.float64)
        lo = arr.min()
        span = arr.max() - lo
        if not span > 0:  # also catches a NaN span
            return np.zeros(arr.shape, dtype=np.uint8)
        return np.rint((arr - lo) / span * 255.0).astype(np.uint8)

    def decode_image(self, image_data: str) -> "Image.Image":
        """Decode a base64-encoded image into an RGB PIL image.

        A leading ``data:<mime>;base64,`` URI prefix, if present, is stripped
        before decoding.
        """
        if isinstance(image_data, str) and image_data.startswith("data:"):
            # b64decode drops ':', ';' and ',' but keeps the letters and '/'
            # of the prefix, which would corrupt the payload.
            _, _, image_data = image_data.partition(",")
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: "Image.Image") -> str:
        """Encode a PIL image as a base64 PNG string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def process_depth(self, image) -> "Image.Image":
        """Estimate depth for ``image`` and return it as an 8-bit PIL image.

        Args:
            image: A PIL image, or a NumPy array that is cast to ``uint8``
                and converted with ``Image.fromarray`` first.

        Returns:
            The pipeline's ``"depth"`` output, min-max scaled to 0-255.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])
        return Image.fromarray(self._to_uint8(depth_map))
|