File size: 2,108 Bytes
5e0225e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import io
import threading

import numpy as np
import supervision as sv
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from lang_sam import LangSAM
from PIL import Image

app = FastAPI()

# CORS policy: every origin, HTTP method and header is accepted, and
# credentials are allowed. This is wide open; restrict the origins before
# exposing the service publicly.
# NOTE(review): FastAPI's CORS guide cautions against combining wildcard
# values with allow_credentials=True. Confirm whether credentials are
# actually needed by any client of this service.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# One LangSAM instance, built at import time and shared by the request handler.
model = LangSAM()

def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Overlay segmentation masks on an image and return the result.

    Every distinct label is assigned its own ``class_id`` so that
    ``sv.MaskAnnotator`` can tell detections of different labels apart.

    Args:
        image_rgb: Image as a NumPy array. It is copied, never modified.
        masks: Per-detection masks; anything ``np.asarray`` can turn into a
            boolean array (presumably shaped (n, H, W) as ``sv.Detections``
            expects; TODO confirm against the model output). May be empty.
        xyxy: Bounding boxes, one row per detection.
        probs: Confidence score per detection.
        labels: Text label per detection.

    Returns:
        A new annotated image array. When there are no detections, an
        unmodified copy of ``image_rgb``.
    """
    # Nothing detected means nothing to draw. An empty result also cannot
    # form a valid Detections object, and a plain empty list has no
    # ``.astype``, so bail out before touching supervision at all.
    if len(masks) == 0 or len(labels) == 0:
        return image_rgb.copy()

    # Label -> class id in first-appearance order. Ordering by
    # ``list(set(labels))`` depended on per-process string hash randomisation,
    # so the same labels could get different ids (and presumably colours)
    # after every restart.
    class_id_map = {label: idx for idx, label in enumerate(dict.fromkeys(labels))}
    class_ids = np.array([class_id_map[label] for label in labels], dtype=int)

    detections = sv.Detections(
        xyxy=xyxy,
        mask=np.asarray(masks, dtype=bool),
        confidence=probs,
        class_id=class_ids,
    )
    mask_annotator = sv.MaskAnnotator()
    return mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)

# Serialises calls into the shared model. Before, predictions ran directly on
# the event loop and so never overlapped. Now that the work runs on worker
# threads, this lock keeps that one-at-a-time behaviour, because nothing here
# establishes that LangSAM is safe to call concurrently.
_MODEL_LOCK = threading.Lock()


def _decode_upload(image_bytes: bytes) -> "Image.Image":
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as uploaded:
            return uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL's UnidentifiedImageError and truncated files.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err


def _encode_png(image_array) -> bytes:
    """Encode an image array (cast to uint8) as RGB PNG bytes."""
    buffer = io.BytesIO()
    Image.fromarray(np.uint8(image_array)).convert("RGB").save(buffer, format="PNG")
    return buffer.getvalue()


def _segment_to_png(image_bytes: bytes, text_prompt: str) -> bytes:
    """Decode, segment, annotate and PNG-encode one upload (blocking)."""
    image_pil = _decode_upload(image_bytes)
    with _MODEL_LOCK:
        results = model.predict([image_pil], [text_prompt])
    result = results[0]  # one image was submitted, so take its result
    annotated = draw_image(
        np.asarray(image_pil),
        result["masks"],
        result["boxes"],
        result["scores"],
        result["labels"],
    )
    return _encode_png(annotated)


@app.post("/segment/")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    """Segment what ``text_prompt`` describes and return the overlay as PNG.

    Form fields:
        file: The image to segment.
        text_prompt: Text passed to the LangSAM model as the prompt.

    Returns:
        An ``image/png`` response containing the annotated image.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    # Decoding, inference, drawing and encoding are all blocking work. Running
    # them on a worker thread keeps the event loop free for other requests.
    png_bytes = await run_in_threadpool(_segment_to_png, image_bytes, text_prompt)
    return Response(content=png_bytes, media_type="image/png")