import os
# Point the Hugging Face and Torch caches at /tmp. This must run before
# importing lang_sam (which pulls in transformers and torch) so the
# writable cache locations take effect.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["TORCH_HOME"] = "/tmp/torch"
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
import uvicorn
from PIL import Image
import io
import numpy as np
from lang_sam import LangSAM
import supervision as sv
app = FastAPI()
# Enable CORS for all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Create cache directories in /tmp
os.makedirs("/tmp/huggingface", exist_ok=True)
os.makedirs("/tmp/torch", exist_ok=True)
# Load the segmentation model
model = LangSAM()
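# Note: constructing LangSAM triggers a one-time download of the GroundingDINO
# and SAM checkpoints into the cache directories configured above, so the
# first startup is slow.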
@app.get("/")
async def root():
    return {"message": "LangSAM API is running!"}
def draw_image(image_rgb, masks, xyxy, probs, labels):
    # Expected shapes for N detections on an HxW image:
    #   masks: (N, H, W), xyxy: (N, 4), probs: (N,), labels: list of N strings
    mask_annotator = sv.MaskAnnotator()
    # Map each unique label to an integer class id, as supervision expects
    unique_labels = list(set(labels))
    class_id_map = {label: idx for idx, label in enumerate(unique_labels)}
    class_id = [class_id_map[label] for label in labels]
    # Build the Detections object with boolean masks and per-detection class ids
    detections = sv.Detections(
        xyxy=xyxy,
        mask=masks.astype(bool),
        confidence=probs,
        class_id=np.array(class_id),
    )
    annotated_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)
    return annotated_image
@app.post("/segment/")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    image_bytes = await file.read()
    image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Run segmentation
    results = model.predict([image_pil], [text_prompt])
    # Convert to NumPy array
    image_array = np.asarray(image_pil)
    # If the prompt matched nothing, return the original image unannotated
    if len(results[0]["masks"]) == 0:
        output_image = image_array
    else:
        output_image = draw_image(
            image_array,
            results[0]["masks"],
            results[0]["boxes"],
            results[0]["scores"],
            results[0]["labels"],
        )
    # Convert back to PIL Image
    output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")
    # Save to byte stream
    img_io = io.BytesIO()
    output_pil.save(img_io, format="PNG")
    img_io.seek(0)
    return Response(content=img_io.getvalue(), media_type="image/png")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
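# Example client usage (a sketch; assumes the `requests` package and a local
# image file named "example.jpg" -- adjust host/port to your deployment):
#
#   import requests
#   with open("example.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/segment/",
#           files={"file": f},
#           data={"text_prompt": "dog"},
#       )
#   with open("segmented.png", "wb") as out:
#       out.write(resp.content)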