sakshee05 commited on
Commit
5e0225e
·
verified ·
1 Parent(s): 9e5bbe8

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +71 -0
api.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import Response
4
+ import uvicorn
5
+ from PIL import Image
6
+ import io
7
+ import numpy as np
8
+ from lang_sam import LangSAM
9
+ import supervision as sv
10
+
11
+ app = FastAPI()
12
+
13
+ # Enable CORS for all origins (Adjust as needed)
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"], # Allow requests from any origin (Change this for security)
17
+ allow_credentials=True,
18
+ allow_methods=["*"], # Allow all HTTP methods
19
+ allow_headers=["*"], # Allow all headers
20
+ )
21
+
22
+ # Load the segmentation model
23
+ model = LangSAM()
24
+
25
+ def draw_image(image_rgb, masks, xyxy, probs, labels):
26
+ mask_annotator = sv.MaskAnnotator()
27
+ # Create class_id for each unique label
28
+ unique_labels = list(set(labels))
29
+ class_id_map = {label: idx for idx, label in enumerate(unique_labels)}
30
+ class_id = [class_id_map[label] for label in labels]
31
+
32
+ # Add class_id to the Detections object
33
+ detections = sv.Detections(
34
+ xyxy=xyxy,
35
+ mask=masks.astype(bool),
36
+ confidence=probs,
37
+ class_id=np.array(class_id),
38
+ )
39
+ annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
40
+ return annotated_image
41
+
42
+ @app.post("/segment/")
43
+ async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
44
+ image_bytes = await file.read()
45
+ image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
46
+
47
+ # Run segmentation
48
+ results = model.predict([image_pil], [text_prompt])
49
+
50
+ # Convert to NumPy array
51
+ image_array = np.asarray(image_pil)
52
+ output_image = draw_image(
53
+ image_array,
54
+ results[0]["masks"],
55
+ results[0]["boxes"],
56
+ results[0]["scores"],
57
+ results[0]["labels"],
58
+ )
59
+
60
+ # Convert back to PIL Image
61
+ output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")
62
+
63
+ # Save to byte stream
64
+ img_io = io.BytesIO()
65
+ output_pil.save(img_io, format="PNG")
66
+ img_io.seek(0)
67
+
68
+ return Response(content=img_io.getvalue(), media_type="image/png")
69
+
70
+ if __name__ == "__main__":
71
+ uvicorn.run(app, host="0.0.0.0", port=8000)