File size: 2,108 Bytes
5e0225e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import io

import numpy as np
import supervision as sv
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from lang_sam import LangSAM
from PIL import Image

app = FastAPI()

# CORS: the API is open to every origin, method and header.
# Credentials (cookies / Authorization headers) are deliberately disabled:
# the FastAPI CORS documentation states that wildcard origins, methods and
# headers must not be combined with allow_credentials=True. This endpoint
# does not read cookies or auth headers, so nothing is lost.
# To support credentialed clients, replace the wildcards with explicit values
# first and only then set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Any origin may call the API (restrict for production)
    allow_credentials=False,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)
# Load the segmentation model once at import time; this single module-level
# instance is reused by every request handled by `segment_image`.
# NOTE(review): construction presumably loads model weights, which would make
# importing this module slow -- confirm against the lang_sam documentation.
model = LangSAM()
def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Overlay segmentation masks on a copy of an RGB image.

    Args:
        image_rgb: Image as a NumPy array; it is never modified in place.
        masks: One mask per detection (array-like, cast to ``bool``).
        xyxy: One bounding box per detection.
        probs: One confidence score per detection.
        labels: One label per detection; detections with the same label get
            the same class id.

    Returns:
        A new array with the masks drawn on top of ``image_rgb``. When there
        are no detections, an unannotated copy of ``image_rgb`` is returned.
    """
    # Nothing was detected: there is nothing to draw, and an empty ``masks``
    # may be a plain list without ``.astype``. Return an untouched copy.
    if len(labels) == 0 or len(masks) == 0:
        return image_rgb.copy()

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # label -> class id mapping is deterministic. Iterating a set of strings
    # can yield a different order on every interpreter run.
    class_id_map = {
        label: idx for idx, label in enumerate(dict.fromkeys(labels))
    }
    class_id = np.array([class_id_map[label] for label in labels], dtype=int)

    detections = sv.Detections(
        xyxy=xyxy,
        mask=np.asarray(masks).astype(bool),
        confidence=probs,
        class_id=class_id,
    )
    # Annotate a copy so the caller's array stays unchanged.
    return sv.MaskAnnotator().annotate(
        scene=image_rgb.copy(), detections=detections
    )
def _png_response(image):
    """Encode a PIL image as PNG and wrap it in an ``image/png`` response."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return Response(content=buffer.getvalue(), media_type="image/png")


@app.post("/segment/")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    """Segment the uploaded image according to a text prompt.

    Args:
        file: Uploaded image file (multipart form field).
        text_prompt: Text describing what to segment (form field).

    Returns:
        A PNG response containing the image with the predicted masks drawn
        on it, or the unmodified image when the model returns no masks.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    image_bytes = await file.read()
    try:
        image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as err:
        # Pillow reports unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass). That is a client error,
        # not a server failure.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # NOTE(review): predict() is a synchronous call inside an async handler,
    # so the event loop is blocked while it runs.
    # The model takes batches; send a batch of one and keep its single result.
    result = model.predict([image_pil], [text_prompt])[0]

    # No masks means nothing to annotate: return the input image unchanged
    # instead of attempting to draw.
    if len(result["masks"]) == 0:
        return _png_response(image_pil)

    annotated = draw_image(
        np.asarray(image_pil),
        result["masks"],
        result["boxes"],
        result["scores"],
        result["labels"],
    )
    return _png_response(Image.fromarray(np.uint8(annotated)).convert("RGB"))
|