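"""FastAPI service exposing two segmentation endpoints: point-prompted
segmentation with SAM2 and text-prompted segmentation with LangSAM."""
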
import os

# Point the Hugging Face and Torch caches at /tmp; this must happen before
# the model libraries are imported. TRANSFORMERS_CACHE is the legacy name,
# kept for compatibility with older transformers releases.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["TORCH_HOME"] = "/tmp/torch"

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
import uvicorn
from PIL import Image
import io
import numpy as np
from lang_sam import LangSAM
import supervision as sv
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import cv2

app = FastAPI()

# Enable CORS for all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create cache directories in /tmp
os.makedirs("/tmp/huggingface", exist_ok=True)
os.makedirs("/tmp/torch", exist_ok=True)

# Load the LangSAM model
langsam_model = LangSAM()

# Load the SAM2 model (checkpoint expected in the working directory)
sam2_checkpoint = "sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
device = torch.device("cpu")

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)
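# The predictor is shared across requests; each request calls set_image() to
# embed the uploaded image before predicting from point prompts.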

@app.get("/")
async def root():
    return {"message": "LangSAM API is running!"}

def apply_mask(image, mask):
    """Overlay mask on image."""
    mask = mask.astype(np.uint8) * 255  # Convert mask to 0-255 scale
    mask_colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
    mask_colored[mask > 0] = [30, 144, 255]  # Blue color for the mask
    
    # Add contour
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(mask_colored, contours, -1, (255, 255, 255), thickness=2)
    
    # Blend with original image
    overlay = cv2.addWeighted(image, 0.7, mask_colored, 0.3, 0)
    return overlay


def draw_image(image_rgb, masks, xyxy, probs, labels):
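    """Annotate the image with colored instance masks, mapping each unique
    label to a class id so supervision can assign distinct colors."""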
    mask_annotator = sv.MaskAnnotator()
    # Create class_id for each unique label
    unique_labels = list(set(labels))
    class_id_map = {label: idx for idx, label in enumerate(unique_labels)}
    class_id = [class_id_map[label] for label in labels]

    # Add class_id to the Detections object
    detections = sv.Detections(
        xyxy=xyxy,
        mask=masks.astype(bool),
        confidence=probs,
        class_id=np.array(class_id),
    )
    annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
    return annotated_image

@app.post("/segment/sam2")
async def segment_image(
    file: UploadFile = File(...), 
    x: int = Form(...), 
    y: int = Form(...)
):
    """Segment image using SAM2 with a single input point."""
    image_bytes = await file.read()
    image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    image_array = np.array(image_pil)
    
    predictor.set_image(image_array)
    
    input_point = np.array([[x, y]])
    input_label = np.array([1])  # Foreground point
    
    # Run SAM2 model
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Get top mask
    top_mask = masks[np.argmax(scores)]

    # Apply mask overlay
    output_image = apply_mask(image_array, top_mask)

    # Convert to PNG
    output_pil = Image.fromarray(output_image)
    img_io = io.BytesIO()
    output_pil.save(img_io, format="PNG")
    img_io.seek(0)

    return Response(content=img_io.getvalue(), media_type="image/png")


@app.post("/segment/langsam")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    image_bytes = await file.read()
    image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    
    # Run segmentation
    results = langsam_model.predict([image_pil], [text_prompt])

    # Convert to NumPy array
    image_array = np.asarray(image_pil)

    # If nothing matched the prompt, return the original image unannotated
    if len(results[0]["masks"]) == 0:
        output_image = image_array
    else:
        output_image = draw_image(
            image_array,
            results[0]["masks"],
            results[0]["boxes"],
            results[0]["scores"],
            results[0]["labels"],
        )
    
    # Convert back to PIL Image
    output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")
    
    # Save to byte stream
    img_io = io.BytesIO()
    output_pil.save(img_io, format="PNG")
    img_io.seek(0)
    
    return Response(content=img_io.getvalue(), media_type="image/png")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
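
# Example requests (a sketch; assumes the server is reachable at
# localhost:7860 and that image.jpg is a hypothetical local file):
#   curl -X POST http://localhost:7860/segment/sam2 \
#        -F "file=@image.jpg" -F "x=100" -F "y=150" -o sam2_out.png
#   curl -X POST http://localhost:7860/segment/langsam \
#        -F "file=@image.jpg" -F "text_prompt=a dog" -o langsam_out.png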