# langSAM / api.py — FastAPI service exposing LangSAM text-prompted image segmentation.
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
import uvicorn
from PIL import Image
import io
import numpy as np
from lang_sam import LangSAM
import supervision as sv
# FastAPI application instance; all routes below are registered on it.
app = FastAPI()
# Enable CORS for all origins (Adjust as needed)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow requests from any origin (Change this for security)
allow_credentials=True,
allow_methods=["*"], # Allow all HTTP methods
allow_headers=["*"], # Allow all headers
)
# Load the segmentation model once at import time so it is shared by all requests.
# NOTE(review): LangSAM() presumably loads model weights here, so startup may be
# slow and memory-heavy — confirm against lang_sam documentation.
model = LangSAM()
def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Overlay segmentation masks onto a copy of *image_rgb* and return it.

    Args:
        image_rgb: Source image as a NumPy array (the caller's array is not
            modified — annotation happens on a copy).
        masks: Per-detection mask array; cast to bool before annotating.
        xyxy: Per-detection bounding boxes in (x1, y1, x2, y2) form.
        probs: Per-detection confidence scores.
        labels: Per-detection text labels; each distinct label is assigned an
            integer class id for this call (used for mask coloring).
    """
    # Assign an integer class id to each distinct label seen in this call.
    unique_labels = list(set(labels))
    label_to_id = {name: idx for idx, name in enumerate(unique_labels)}
    ids = [label_to_id[name] for name in labels]

    detections = sv.Detections(
        xyxy=xyxy,
        mask=masks.astype(bool),
        confidence=probs,
        class_id=np.array(ids),
    )

    # Annotate a copy so the input image stays untouched.
    annotator = sv.MaskAnnotator()
    return annotator.annotate(scene=image_rgb.copy(), detections=detections)
@app.post("/segment/")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    """Segment an uploaded image according to *text_prompt* and return a PNG.

    The uploaded bytes are decoded to RGB, run through the LangSAM model as a
    batch of one, the resulting masks are drawn onto the image, and the
    annotated result is streamed back as a PNG response.
    """
    raw = await file.read()
    pil_image = Image.open(io.BytesIO(raw)).convert("RGB")

    # The model accepts batched inputs; we send a single-item batch and read
    # back the first (and only) result.
    predictions = model.predict([pil_image], [text_prompt])
    first = predictions[0]

    annotated = draw_image(
        np.asarray(pil_image),
        first["masks"],
        first["boxes"],
        first["scores"],
        first["labels"],
    )

    # Re-encode the annotated array as PNG bytes for the HTTP response.
    output_pil = Image.fromarray(np.uint8(annotated)).convert("RGB")
    buffer = io.BytesIO()
    output_pil.save(buffer, format="PNG")
    buffer.seek(0)
    return Response(content=buffer.getvalue(), media_type="image/png")