Codewithsalty commited on
Commit
119a255
·
verified ·
1 Parent(s): 4bcd70b

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +3 -11
newapi.py CHANGED
@@ -11,14 +11,12 @@ from huggingface_hub import hf_hub_download
11
  from models.TumorModel import TumorClassification, GliomaStageModel
12
  from utils import get_precautions_from_gemini
13
 
14
- # ✅ Create a writable cache directory inside the current working directory
15
- cache_dir = os.path.join(os.getcwd(), "cache")
16
  os.makedirs(cache_dir, exist_ok=True)
17
 
18
- # ✅ Initialize FastAPI app
19
  app = FastAPI(title="Brain Tumor Detection API")
20
 
21
- # ✅ Enable CORS for frontend requests
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"],
@@ -47,7 +45,6 @@ glioma_model = GliomaStageModel()
47
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
48
  glioma_model.eval()
49
 
50
- # ✅ Image preprocessing steps
51
  transform = transforms.Compose([
52
  transforms.Grayscale(),
53
  transforms.Resize((224, 224)),
@@ -55,15 +52,12 @@ transform = transforms.Compose([
55
  transforms.Normalize(mean=[0.5], std=[0.5]),
56
  ])
57
 
58
- # ✅ Health check route
59
  @app.get("/")
60
  async def root():
61
  return {"message": "Brain Tumor Detection API is running."}
62
 
63
- # ✅ Tumor labels
64
  labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
65
 
66
- # ✅ Predict tumor type
67
  @app.post("/predict-image")
68
  async def predict_image(file: UploadFile = File(...)):
69
  img_bytes = await file.read()
@@ -81,7 +75,6 @@ async def predict_image(file: UploadFile = File(...)):
81
  precautions = get_precautions_from_gemini(tumor_type)
82
  return {"tumor_type": tumor_type, "precaution": precautions}
83
 
84
- # ✅ Input format for glioma prediction
85
  class MutationInput(BaseModel):
86
  gender: str
87
  age: float
@@ -93,7 +86,6 @@ class MutationInput(BaseModel):
93
  cic: int
94
  pik3ca: int
95
 
96
- # ✅ Predict glioma stage
97
  @app.post("/predict-glioma-stage")
98
  async def predict_glioma_stage(data: MutationInput):
99
  gender_val = 0 if data.gender.lower() == 'm' else 1
@@ -109,7 +101,7 @@ async def predict_glioma_stage(data: MutationInput):
109
  stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
110
  return {"glioma_stage": stages[idx]}
111
 
112
- # ✅ Optional: Only used when running locally (ignored on Spaces)
113
  if __name__ == "__main__":
114
  import uvicorn
115
  uvicorn.run("newapi:app", host="0.0.0.0", port=10000)
 
11
  from models.TumorModel import TumorClassification, GliomaStageModel
12
  from utils import get_precautions_from_gemini
13
 
14
+ # ✅ Use /data as Hugging Face allows writing here only
15
+ cache_dir = os.path.join("/data", "cache")
16
  os.makedirs(cache_dir, exist_ok=True)
17
 
 
18
  app = FastAPI(title="Brain Tumor Detection API")
19
 
 
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
 
45
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
46
  glioma_model.eval()
47
 
 
48
  transform = transforms.Compose([
49
  transforms.Grayscale(),
50
  transforms.Resize((224, 224)),
 
52
  transforms.Normalize(mean=[0.5], std=[0.5]),
53
  ])
54
 
 
55
  @app.get("/")
56
  async def root():
57
  return {"message": "Brain Tumor Detection API is running."}
58
 
 
59
  labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
60
 
 
61
  @app.post("/predict-image")
62
  async def predict_image(file: UploadFile = File(...)):
63
  img_bytes = await file.read()
 
75
  precautions = get_precautions_from_gemini(tumor_type)
76
  return {"tumor_type": tumor_type, "precaution": precautions}
77
 
 
78
  class MutationInput(BaseModel):
79
  gender: str
80
  age: float
 
86
  cic: int
87
  pik3ca: int
88
 
 
89
  @app.post("/predict-glioma-stage")
90
  async def predict_glioma_stage(data: MutationInput):
91
  gender_val = 0 if data.gender.lower() == 'm' else 1
 
101
  stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
102
  return {"glioma_stage": stages[idx]}
103
 
104
+ # Only used when running locally
105
  if __name__ == "__main__":
106
  import uvicorn
107
  uvicorn.run("newapi:app", host="0.0.0.0", port=10000)