"""FastAPI service that segments an uploaded image with LangSAM.

POST an image file plus a text prompt to ``/segment/`` and receive a PNG of
the input image with the predicted masks overlaid.
"""

import io
import os
import threading

# Cache locations must be configured before importing the libraries that read
# them (lang_sam and its model dependencies), which is why the third-party
# imports below come after this section.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["TORCH_HOME"] = "/tmp/torch"

# Create the cache directories in /tmp.
os.makedirs("/tmp/huggingface", exist_ok=True)
os.makedirs("/tmp/torch", exist_ok=True)

import numpy as np  # noqa: E402
import supervision as sv  # noqa: E402
import uvicorn  # noqa: E402
from fastapi import FastAPI, File, Form, HTTPException, UploadFile  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import Response  # noqa: E402
from lang_sam import LangSAM  # noqa: E402
from PIL import Image  # noqa: E402

app = FastAPI()

# Enable CORS for all origins.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins if this API ever relies on cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the segmentation model once at import time; all requests share it.
model = LangSAM()

# Inference runs in worker threads (see segment_image). This lock keeps
# predictions one at a time, as they were when predict() ran on the event loop.
# NOTE(review): LangSAM's thread-safety is not established here - confirm
# before removing this lock.
_predict_lock = threading.Lock()


@app.get("/")
async def root():
    """Health check: report that the API is up."""
    return {"message": "LangSAM API is running!"}


def draw_image(image_rgb, masks, xyxy, probs, labels):
    """Overlay segmentation masks on an image.

    Args:
        image_rgb: Image as a NumPy array; it is copied, never modified.
        masks: Per-detection masks (anything convertible to a NumPy array).
        xyxy: Per-detection bounding boxes in (x1, y1, x2, y2) order.
        probs: Per-detection confidence scores.
        labels: Per-detection label strings. Detections that share a label
            share a class id, which the annotator can use for colouring.

    Returns:
        A copy of ``image_rgb`` with the masks drawn on it, or an unmodified
        copy when there are no detections.
    """
    # Nothing to draw. Empty results are not guaranteed to be shaped arrays
    # (masks may be a plain list), which would break .astype / sv.Detections.
    if len(labels) == 0 or len(masks) == 0:
        return image_rgb.copy()

    # dict.fromkeys keeps first-seen order, so ids (and colours) are the same
    # on every run; iterating a set of str depends on per-process hash seeds.
    class_id_map = {label: idx for idx, label in enumerate(dict.fromkeys(labels))}
    class_id = np.array([class_id_map[label] for label in labels])

    detections = sv.Detections(
        xyxy=np.asarray(xyxy),
        mask=np.asarray(masks).astype(bool),
        confidence=np.asarray(probs),
        class_id=class_id,
    )
    mask_annotator = sv.MaskAnnotator()
    return mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)


def _decode_image(image_bytes):
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # convert() loads the pixel data, so decode errors surface here.
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _segment_to_png(image_bytes, text_prompt):
    """Decode, segment, annotate and PNG-encode one image (blocking)."""
    image_pil = _decode_image(image_bytes)

    with _predict_lock:
        results = model.predict([image_pil], [text_prompt])
    result = results[0]

    output_image = draw_image(
        np.asarray(image_pil),
        result["masks"],
        result["boxes"],
        result["scores"],
        result["labels"],
    )

    output_pil = Image.fromarray(np.uint8(output_image)).convert("RGB")
    buffer = io.BytesIO()
    output_pil.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/segment/")
async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(...)):
    """Segment what ``text_prompt`` describes in the uploaded image.

    Returns the input image with the predicted masks overlaid, as a PNG.
    Responds with 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    # Decoding, inference and encoding are blocking/CPU-heavy. Running them in
    # the threadpool keeps the event loop free to serve other requests.
    png_bytes = await run_in_threadpool(_segment_to_png, image_bytes, text_prompt)
    return Response(content=png_bytes, media_type="image/png")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)