Codewithsalty commited on
Commit
a4a23df
·
verified ·
1 Parent(s): bc39385

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +13 -9
newapi.py CHANGED
@@ -11,8 +11,11 @@ from huggingface_hub import hf_hub_download
11
  from models.TumorModel import TumorClassification, GliomaStageModel
12
  from utils import get_precautions_from_gemini
13
 
14
- # Use Hugging Face's writable cache directory
15
- cache_dir = os.getenv("HF_HOME", "/data")
 
 
 
16
 
17
  # Initialize FastAPI app
18
  app = FastAPI(title="Brain Tumor Detection API")
@@ -26,7 +29,7 @@ app.add_middleware(
26
  allow_headers=["*"],
27
  )
28
 
29
- # Load Tumor Classification Model from Hugging Face
30
  btd_model_path = hf_hub_download(
31
  repo_id="Codewithsalty/brain-tumor-models",
32
  filename="BTD_model.pth",
@@ -36,7 +39,7 @@ tumor_model = TumorClassification()
36
  tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
37
  tumor_model.eval()
38
 
39
- # Load Glioma Stage Prediction Model from Hugging Face
40
  glioma_model_path = hf_hub_download(
41
  repo_id="Codewithsalty/brain-tumor-models",
42
  filename="glioma_stages.pth",
@@ -54,7 +57,7 @@ transform = transforms.Compose([
54
  transforms.Normalize(mean=[0.5], std=[0.5]),
55
  ])
56
 
57
- # Health check endpoint
58
  @app.get("/")
59
  async def root():
60
  return {"message": "Brain Tumor Detection API is running."}
@@ -62,11 +65,11 @@ async def root():
62
  # Tumor type labels
63
  labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
64
 
65
- # Predict tumor type from uploaded image
66
  @app.post("/predict-image")
67
  async def predict_image(file: UploadFile = File(...)):
68
  img_bytes = await file.read()
69
- img = Image.open(io.BytesIO(img_bytes)).convert("L") # Ensure grayscale
70
  x = transform(img).unsqueeze(0)
71
 
72
  with torch.no_grad():
@@ -80,7 +83,7 @@ async def predict_image(file: UploadFile = File(...)):
80
  precautions = get_precautions_from_gemini(tumor_type)
81
  return {"tumor_type": tumor_type, "precaution": precautions}
82
 
83
- # Input model for glioma mutation data
84
  class MutationInput(BaseModel):
85
  gender: str
86
  age: float
@@ -92,7 +95,7 @@ class MutationInput(BaseModel):
92
  cic: int
93
  pik3ca: int
94
 
95
- # Predict glioma stage based on mutations
96
  @app.post("/predict-glioma-stage")
97
  async def predict_glioma_stage(data: MutationInput):
98
  gender_val = 0 if data.gender.lower() == 'm' else 1
@@ -108,6 +111,7 @@ async def predict_glioma_stage(data: MutationInput):
108
  stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
109
  return {"glioma_stage": stages[idx]}
110
 
 
111
  if __name__ == "__main__":
112
  import uvicorn
113
  uvicorn.run("newapi:app", host="0.0.0.0", port=10000)
 
11
  from models.TumorModel import TumorClassification, GliomaStageModel
12
  from utils import get_precautions_from_gemini
13
 
14
+ # Writable cache directory inside project
15
+ cache_dir = "./hf_cache"
16
+
17
+ # ✅ Create the directory if it doesn't exist
18
+ os.makedirs(cache_dir, exist_ok=True)
19
 
20
  # Initialize FastAPI app
21
  app = FastAPI(title="Brain Tumor Detection API")
 
29
  allow_headers=["*"],
30
  )
31
 
32
+ # Load Tumor Classification Model
33
  btd_model_path = hf_hub_download(
34
  repo_id="Codewithsalty/brain-tumor-models",
35
  filename="BTD_model.pth",
 
39
  tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
40
  tumor_model.eval()
41
 
42
+ # Load Glioma Stage Prediction Model
43
  glioma_model_path = hf_hub_download(
44
  repo_id="Codewithsalty/brain-tumor-models",
45
  filename="glioma_stages.pth",
 
57
  transforms.Normalize(mean=[0.5], std=[0.5]),
58
  ])
59
 
60
+ # Health check
61
  @app.get("/")
62
  async def root():
63
  return {"message": "Brain Tumor Detection API is running."}
 
65
  # Tumor type labels
66
  labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
67
 
68
+ # Predict tumor type from image
69
  @app.post("/predict-image")
70
  async def predict_image(file: UploadFile = File(...)):
71
  img_bytes = await file.read()
72
+ img = Image.open(io.BytesIO(img_bytes)).convert("L")
73
  x = transform(img).unsqueeze(0)
74
 
75
  with torch.no_grad():
 
83
  precautions = get_precautions_from_gemini(tumor_type)
84
  return {"tumor_type": tumor_type, "precaution": precautions}
85
 
86
+ # Mutation input model
87
  class MutationInput(BaseModel):
88
  gender: str
89
  age: float
 
95
  cic: int
96
  pik3ca: int
97
 
98
+ # Predict glioma stage
99
  @app.post("/predict-glioma-stage")
100
  async def predict_glioma_stage(data: MutationInput):
101
  gender_val = 0 if data.gender.lower() == 'm' else 1
 
111
  stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
112
  return {"glioma_stage": stages[idx]}
113
 
114
+ # Run locally (ignored on Spaces, used only for dev/testing)
115
  if __name__ == "__main__":
116
  import uvicorn
117
  uvicorn.run("newapi:app", host="0.0.0.0", port=10000)