|
import os |
|
|
|
|
|
# Route the Hugging Face and Torch caches to directories under /tmp.
# This runs ahead of the model-library imports below, presumably so those
# libraries see the values when they are imported — confirm before moving it.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface",
        "TORCH_HOME": "/tmp/torch",
    }
)
|
|
|
from fastapi import FastAPI, File, UploadFile, Form |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import Response |
|
import uvicorn |
|
from PIL import Image |
|
import io |
|
import numpy as np |
|
from lang_sam import LangSAM |
|
import supervision as sv |
|
from sam2.build_sam import build_sam2 |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
import torch |
|
import cv2 |
|
|
|
app = FastAPI()

# Let browser front-ends on any origin call this API.
#
# Credentials are deliberately disabled: FastAPI's CORS documentation states
# that allow_origins / allow_methods / allow_headers cannot be ["*"] when
# allow_credentials is True, and reflecting every origin for credentialed
# requests would let any website act with a visitor's cookies. The handlers
# visible in this file only read the uploaded file and form fields, so no
# cookie or auth support is needed.
# NOTE(review): if a front-end really must send credentials, list its exact
# origins in allow_origins instead of re-enabling the wildcard combination.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Create the cache directories named in the environment variables above.
for _cache_dir in ("/tmp/huggingface", "/tmp/torch"):
    os.makedirs(_cache_dir, exist_ok=True)

# Text-prompted segmentation model, built once at import time so every
# request reuses the same instance.
langsam_model = LangSAM()

# SAM 2.1 point-prompted model. The checkpoint path is relative, so the
# weights file is presumably expected in the working directory — confirm.
sam2_checkpoint = "sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
device = torch.device("cpu")  # SAM2 is explicitly built on the CPU

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)
|
|
|
@app.get("/")
async def root():
    """Return a fixed status message showing that the service is up."""
    status = {"message": "LangSAM API is running!"}
    return status
|
|
|
def apply_mask(image, mask):
    """Blend a tinted, outlined rendering of ``mask`` onto ``image``.

    Args:
        image: Image array passed to ``cv2.addWeighted``; it must match the
            tint layer built here (mask shape plus 3 uint8 channels).
        mask: 2-D mask array; non-zero entries (after the uint8 cast) are
            treated as the selected region.

    Returns:
        The blend of 70% ``image`` and 30% tint layer. The blend is applied
        to the whole frame, so pixels outside the mask are darkened as well.
    """
    # Scale the mask to 0/255 so OpenCV can trace its contours.
    binary = mask.astype(np.uint8) * 255

    # Colour layer: filled where the mask is set, black everywhere else.
    tint = np.zeros(binary.shape + (3,), dtype=np.uint8)
    tint[binary > 0] = (30, 144, 255)

    # Trace only the outer boundaries and draw them in white on the layer.
    outlines, _hierarchy = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(tint, outlines, -1, (255, 255, 255), thickness=2)

    return cv2.addWeighted(image, 0.7, tint, 0.3, 0)
|
|
|
|
|
def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Paint detection masks onto a copy of ``image_rgb``.

    Args:
        image_rgb: Image array to annotate; it is never modified in place.
        masks: One mask per detection (presumably shaped (N, H, W) — confirm
            against the model output). Converted to a boolean array.
        xyxy: Bounding boxes handed to ``sv.Detections`` unchanged.
        probs: Per-detection confidences handed to ``sv.Detections``.
        labels: One label per detection; equal labels share a class id.

    Returns:
        An annotated copy of ``image_rgb``. When ``labels`` is empty the
        copy is returned without any annotation.
    """
    # Nothing was detected. The caller may then pass empty placeholders
    # (e.g. a plain list for ``masks``) that cannot be turned into
    # detections, so return the untouched image instead of raising.
    # ``len`` is used rather than truthiness so array inputs work too.
    if len(labels) == 0:
        return image_rgb.copy()

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # label -> class id mapping is the same on every run. Iterating a set of
    # strings is not: its order depends on per-process hash randomisation.
    class_id_map = {
        label: idx for idx, label in enumerate(dict.fromkeys(labels))
    }
    class_id = np.array([class_id_map[label] for label in labels])

    detections = sv.Detections(
        xyxy=xyxy,
        mask=np.asarray(masks).astype(bool),
        confidence=probs,
        class_id=class_id,
    )
    return sv.MaskAnnotator().annotate(
        scene=image_rgb.copy(), detections=detections
    )
|
|
|
@app.post("/segment/sam2")
async def segment_image(
    file: UploadFile = File(...),
    x: int = Form(...),
    y: int = Form(...),
):
    """Segment image using SAM2 with a single input point.

    Args:
        file: Uploaded image; it is decoded with Pillow and converted to RGB.
        x: First coordinate of the prompt point (SAM2 takes points as
            (x, y) pixel positions).
        y: Second coordinate of the prompt point.

    Returns:
        A PNG response containing the input image with the best-scoring
        mask overlaid by ``apply_mask``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # NOTE(review): the /segment/langsam handler later in this file is also
    # named ``segment_image`` and rebinds the module-level name. Both routes
    # still work because each decorator registers its function on definition.

    # Imported locally because the file-level fastapi import omits it.
    from fastapi import HTTPException

    image_bytes = await file.read()
    try:
        image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow reports unreadable or truncated data with OSError subclasses;
        # that is a client error, not a server failure.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc
    image_array = np.array(image_pil)

    predictor.set_image(image_array)

    # A single prompt point with label 1 (SAM2's foreground label).
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )

    # Several candidate masks are returned; keep the highest-scoring one.
    top_mask = masks[np.argmax(scores)]

    output_image = apply_mask(image_array, top_mask)

    img_io = io.BytesIO()
    Image.fromarray(output_image).save(img_io, format="PNG")

    # getvalue() returns the whole buffer regardless of the stream position.
    return Response(content=img_io.getvalue(), media_type="image/png")
|
|
|
|
|
@app.post("/segment/langsam")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    """Segment the regions of an image that match a text prompt with LangSAM.

    Args:
        file: Uploaded image; it is decoded with Pillow and converted to RGB.
        text_prompt: Text passed to the LangSAM model to select what to
            segment.

    Returns:
        A PNG response containing the input image annotated by
        ``draw_image`` with the predicted masks.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # NOTE(review): this reuses the name of the /segment/sam2 handler defined
    # earlier in this file. Both routes stay registered, but the module-level
    # name ``segment_image`` ends up bound to this function only.

    # Imported locally because the file-level fastapi import omits it.
    from fastapi import HTTPException

    image_bytes = await file.read()
    try:
        image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow reports unreadable or truncated data with OSError subclasses;
        # that is a client error, not a server failure.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # The model takes batches; send a batch of one and read its only result.
    results = langsam_model.predict([image_pil], [text_prompt])
    prediction = results[0]

    image_array = np.asarray(image_pil)
    output_image = draw_image(
        image_array,
        prediction["masks"],
        prediction["boxes"],
        prediction["scores"],
        prediction["labels"],
    )

    output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")

    img_io = io.BytesIO()
    output_pil.save(img_io, format="PNG")

    # getvalue() returns the whole buffer regardless of the stream position.
    return Response(content=img_io.getvalue(), media_type="image/png")
|
|
|
|
|
if __name__ == "__main__":
    # Start the server only when run as a script: bind every interface
    # on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|