File size: 2,457 Bytes
378a602
0a1a8aa
 
 
 
 
 
585855c
 
 
 
 
 
 
 
 
378a602
5e0225e
 
378a602
5e0225e
 
378a602
5e0225e
378a602
 
5e0225e
 
3af6820
 
 
378a602
5e0225e
 
 
71b8343
 
 
 
5e0225e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71b8343
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os

# Cache directories are redirected to /tmp BEFORE importing torch / lang_sam,
# because those libraries read these environment variables at import time
# (typical for read-only-filesystem deployments such as HF Spaces).
# Set Hugging Face cache directory to /tmp
os.environ["HF_HOME"] = "/tmp/huggingface"
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases in favor of HF_HOME; kept here for older versions — confirm.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["TORCH_HOME"] = "/tmp/torch"

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
import uvicorn
from PIL import Image
import io
import numpy as np
from lang_sam import LangSAM
import supervision as sv

app = FastAPI()

# Enable CORS for all origins
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers for credentialed requests — verify this combination
# is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create cache directories in /tmp
os.makedirs("/tmp/huggingface", exist_ok=True)
os.makedirs("/tmp/torch", exist_ok=True)

# Load the segmentation model once at startup (downloads weights on first run).
model = LangSAM()

@app.get("/")
async def root():
    """Health-check endpoint: confirms the service is up and reachable."""
    message = "LangSAM API is running!"
    return {"message": message}

def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Overlay segmentation masks on a copy of *image_rgb* and return it.

    Each distinct label is mapped to an integer class id so that the
    supervision annotator can colour masks per class.
    """
    # Build label -> class-id mapping (same `list(set(...))` scheme as the
    # rest of the project so colours stay consistent within one call).
    distinct_labels = list(set(labels))
    id_of_label = dict(zip(distinct_labels, range(len(distinct_labels))))
    class_ids = np.array([id_of_label[lbl] for lbl in labels])

    detections = sv.Detections(
        xyxy=xyxy,
        mask=masks.astype(bool),
        confidence=probs,
        class_id=class_ids,
    )

    # Annotate a copy so the caller's array is left untouched.
    return sv.MaskAnnotator().annotate(scene=image_rgb.copy(), detections=detections)

@app.post("/segment/")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    """Segment objects matching *text_prompt* in the uploaded image.

    Parameters (multipart form):
        file: the image to segment (any format PIL can open).
        text_prompt: free-text description of the objects to find.

    Returns:
        A PNG response with the detected masks overlaid. If the model finds
        nothing, the original image is returned unchanged instead of failing
        while building an empty Detections object.
    """
    image_bytes = await file.read()
    image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Run segmentation (batched API: one image, one prompt).
    result = model.predict([image_pil], [text_prompt])[0]

    # Guard against zero detections: annotating with empty masks/boxes would
    # raise, so fall back to the unannotated input.
    if len(result["masks"]) == 0:
        output_pil = image_pil
    else:
        image_array = np.asarray(image_pil)
        output_image = draw_image(
            image_array,
            result["masks"],
            result["boxes"],
            result["scores"],
            result["labels"],
        )
        output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")

    # Serialize as PNG for the HTTP response.
    img_io = io.BytesIO()
    output_pil.save(img_io, format="PNG")
    img_io.seek(0)

    return Response(content=img_io.getvalue(), media_type="image/png")


if __name__ == "__main__":
    # Run the ASGI app directly (port 7860 is the HF Spaces convention).
    uvicorn.run(app, host="0.0.0.0", port=7860)