sakshee05 commited on
Commit
71b8343
·
verified ·
1 Parent(s): 378a602

Fix: Use /tmp for caching

Browse files
Files changed (1) hide show
  1. main.py +14 -9
main.py CHANGED
@@ -1,15 +1,12 @@
1
  import os
2
- os.environ["TORCH_HOME"] = "/app/cache" # Set a writable directory
3
 
4
- from fastapi import FastAPI, File, UploadFile, Form
 
 
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
- from fastapi.responses import Response
7
  import uvicorn
8
- from PIL import Image
9
- import io
10
- import numpy as np
11
  from lang_sam import LangSAM
12
- import supervision as sv
13
 
14
  app = FastAPI()
15
 
@@ -22,12 +19,16 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Create a cache directory in the app space
26
- os.makedirs("/app/cache", exist_ok=True)
27
 
28
  # Load the segmentation model
29
  model = LangSAM()
30
 
 
 
 
 
31
  def draw_image(image_rgb, masks, xyxy, probs, labels):
32
  mask_annotator = sv.MaskAnnotator()
33
  # Create class_id for each unique label
@@ -72,3 +73,7 @@ async def segment_image(file: UploadFile = File(...), text_prompt: str = Form(..
72
  img_io.seek(0)
73
 
74
  return Response(content=img_io.getvalue(), media_type="image/png")
 
 
 
 
 
1
  import os
 
2
 
3
+ # Set the cache directory to /tmp (which is writable in Hugging Face Spaces)
4
+ os.environ["TORCH_HOME"] = "/tmp/cache"
5
+
6
+ from fastapi import FastAPI
7
  from fastapi.middleware.cors import CORSMiddleware
 
8
  import uvicorn
 
 
 
9
  from lang_sam import LangSAM
 
10
 
11
  app = FastAPI()
12
 
 
19
  allow_headers=["*"],
20
  )
21
 
22
+ # Create a cache directory in /tmp
23
+ os.makedirs("/tmp/cache", exist_ok=True)
24
 
25
  # Load the segmentation model
26
  model = LangSAM()
27
 
28
+ @app.get("/")
29
+ async def root():
30
+ return {"message": "LangSAM API is running!"}
31
+
32
  def draw_image(image_rgb, masks, xyxy, probs, labels):
33
  mask_annotator = sv.MaskAnnotator()
34
  # Create class_id for each unique label
 
73
  img_io.seek(0)
74
 
75
  return Response(content=img_io.getvalue(), media_type="image/png")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ uvicorn.run(app, host="0.0.0.0", port=7860)