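"""FastAPI service exposing two image-segmentation endpoints.

POST /segment/sam2    -- point-prompted segmentation with SAM2
POST /segment/langsam -- text-prompted segmentation with LangSAM

Both endpoints accept a multipart image upload and return a PNG overlay.
"""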
import os
# Point the Hugging Face and Torch caches at /tmp (a writable location).
# These variables are set before the ML libraries below are imported so that
# they take effect when those libraries initialize their cache paths.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["TORCH_HOME"] = "/tmp/torch"

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
import uvicorn
from PIL import Image
import io
import numpy as np
from lang_sam import LangSAM
import supervision as sv
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import cv2

app = FastAPI()

# Enable CORS for all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
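# Note: per the CORS spec, browsers reject a wildcard origin on credentialed
# requests; list explicit origins instead of "*" if credentials matter.
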
# Create cache directories in /tmp
os.makedirs("/tmp/huggingface", exist_ok=True)
os.makedirs("/tmp/torch", exist_ok=True)

# Load the LangSAM model (text-prompted segmentation)
langsam_model = LangSAM()

# Load the SAM2 model (point-prompted segmentation)
sam2_checkpoint = "sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
device = torch.device("cpu")
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)
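# Inference runs on CPU here for portability; build_sam2 also accepts a CUDA
# device when a GPU is available.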

@app.get("/")
async def root():
    """Health-check endpoint."""
    return {"message": "LangSAM API is running!"}

def apply_mask(image, mask):
    """Overlay a binary mask on an image as a translucent blue region with a white outline."""
    mask = mask.astype(np.uint8) * 255  # Scale mask to 0-255
    mask_colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
    mask_colored[mask > 0] = [30, 144, 255]  # Dodger-blue fill for the masked region
    # Trace the mask boundary with a white contour
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(mask_colored, contours, -1, (255, 255, 255), thickness=2)
    # Blend: 70% original image, 30% colored mask
    overlay = cv2.addWeighted(image, 0.7, mask_colored, 0.3, 0)
    return overlay

def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Annotate an RGB image with segmentation masks via supervision."""
    mask_annotator = sv.MaskAnnotator()
    # Assign an integer class id to each unique label
    unique_labels = list(set(labels))
    class_id_map = {label: idx for idx, label in enumerate(unique_labels)}
    class_id = [class_id_map[label] for label in labels]
    # Bundle boxes, masks, confidences, and class ids into a Detections object
    detections = sv.Detections(
        xyxy=xyxy,
        mask=masks.astype(bool),
        confidence=probs,
        class_id=np.array(class_id),
    )
    annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
    return annotated_image
@app.post("/segment/sam2")
async def segment_image(
file: UploadFile = File(...),
x: int = Form(...),
y: int = Form(...)
):
"""Segment image using SAM2 with a single input point."""
image_bytes = await file.read()
image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
image_array = np.array(image_pil)
predictor.set_image(image_array)
input_point = np.array([[x, y]])
input_label = np.array([1]) # Foreground point
# Run SAM2 model
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
# Get top mask
top_mask = masks[np.argmax(scores)]
# Apply mask overlay
output_image = apply_mask(image_array, top_mask)
# Convert to PNG
output_pil = Image.fromarray(output_image)
img_io = io.BytesIO()
output_pil.save(img_io, format="PNG")
img_io.seek(0)
return Response(content=img_io.getvalue(), media_type="image/png")
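
# Illustrative client call for the SAM2 endpoint (file name and coordinates
# are placeholders):
#   curl -X POST http://localhost:7860/segment/sam2 \
#     -F "file=@photo.jpg" -F "x=320" -F "y=240" -o overlay.png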
@app.post("/segment/langsam")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
image_bytes = await file.read()
image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# Run segmentation
results = langsam_model.predict([image_pil], [text_prompt])
# Convert to NumPy array
image_array = np.asarray(image_pil)
output_image = draw_image(
image_array,
results[0]["masks"],
results[0]["boxes"],
results[0]["scores"],
results[0]["labels"],
)
# Convert back to PIL Image
output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")
# Save to byte stream
img_io = io.BytesIO()
output_pil.save(img_io, format="PNG")
img_io.seek(0)
return Response(content=img_io.getvalue(), media_type="image/png")
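
# Illustrative client call for the LangSAM endpoint (file name and prompt are
# placeholders):
#   curl -X POST http://localhost:7860/segment/langsam \
#     -F "file=@photo.jpg" -F "text_prompt=dog" -o annotated.png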

if __name__ == "__main__":
    # Listen on all interfaces; 7860 is the default Hugging Face Spaces port
    uvicorn.run(app, host="0.0.0.0", port=7860)