Codewithsalty committed on
Commit 4bcd70b · verified · 1 Parent(s): 28fd1d7

Update newapi.py

Files changed (1)
  1. newapi.py +23 -15
newapi.py CHANGED
@@ -5,17 +5,20 @@ import torch
 from torchvision import transforms
 from PIL import Image
 import io
+import os
 from huggingface_hub import hf_hub_download

 from models.TumorModel import TumorClassification, GliomaStageModel
 from utils import get_precautions_from_gemini

-# ✅ Let Hugging Face handle cache automatically DO NOT manually create any folders
+# ✅ Create a writable cache directory inside the current working directory
+cache_dir = os.path.join(os.getcwd(), "cache")
+os.makedirs(cache_dir, exist_ok=True)

-# Initialize FastAPI app
+# Initialize FastAPI app
 app = FastAPI(title="Brain Tumor Detection API")

-# Enable CORS for all origins
+# Enable CORS for frontend requests
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -24,25 +27,27 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# ✅ Load Tumor Classification Model from Hugging Face
+# ✅ Load Tumor Classification Model
 btd_model_path = hf_hub_download(
     repo_id="Codewithsalty/brain-tumor-models",
-    filename="BTD_model.pth"
+    filename="BTD_model.pth",
+    cache_dir=cache_dir
 )
 tumor_model = TumorClassification()
 tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
 tumor_model.eval()

-# ✅ Load Glioma Stage Prediction Model from Hugging Face
+# ✅ Load Glioma Stage Model
 glioma_model_path = hf_hub_download(
     repo_id="Codewithsalty/brain-tumor-models",
-    filename="glioma_stages.pth"
+    filename="glioma_stages.pth",
+    cache_dir=cache_dir
 )
 glioma_model = GliomaStageModel()
 glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
 glioma_model.eval()

-# Image preprocessing pipeline
+# Image preprocessing steps
 transform = transforms.Compose([
     transforms.Grayscale(),
     transforms.Resize((224, 224)),
@@ -50,19 +55,19 @@ transform = transforms.Compose([
     transforms.Normalize(mean=[0.5], std=[0.5]),
 ])

-# Health check endpoint
+# Health check route
 @app.get("/")
 async def root():
     return {"message": "Brain Tumor Detection API is running."}

-# Tumor type labels
+# Tumor labels
 labels = ['glioma', 'meningioma', 'notumor', 'pituitary']

-# Predict tumor type from uploaded image
+# Predict tumor type
 @app.post("/predict-image")
 async def predict_image(file: UploadFile = File(...)):
     img_bytes = await file.read()
-    img = Image.open(io.BytesIO(img_bytes)).convert("L")  # Ensure grayscale
+    img = Image.open(io.BytesIO(img_bytes)).convert("L")
     x = transform(img).unsqueeze(0)

     with torch.no_grad():
@@ -76,7 +81,7 @@ async def predict_image(file: UploadFile = File(...)):
     precautions = get_precautions_from_gemini(tumor_type)
     return {"tumor_type": tumor_type, "precaution": precautions}

-# Input model for glioma mutation data
+# Input format for glioma prediction
 class MutationInput(BaseModel):
     gender: str
     age: float
@@ -88,7 +93,7 @@ class MutationInput(BaseModel):
     cic: int
     pik3ca: int

-# Predict glioma stage based on mutations
+# Predict glioma stage
 @app.post("/predict-glioma-stage")
 async def predict_glioma_stage(data: MutationInput):
     gender_val = 0 if data.gender.lower() == 'm' else 1
@@ -104,4 +109,7 @@ async def predict_glioma_stage(data: MutationInput):
     stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
     return {"glioma_stage": stages[idx]}

-# ✅ No need to run uvicorn manually in Hugging Face Spaces
+# ✅ Optional: Only used when running locally (ignored on Spaces)
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("newapi:app", host="0.0.0.0", port=10000)