Codewithsalty commited on
Commit
28fd1d7
·
verified ·
1 Parent(s): 68f05cc

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +14 -19
newapi.py CHANGED
@@ -5,20 +5,17 @@ import torch
5
  from torchvision import transforms
6
  from PIL import Image
7
  import io
8
- import os
9
  from huggingface_hub import hf_hub_download
10
 
11
  from models.TumorModel import TumorClassification, GliomaStageModel
12
  from utils import get_precautions_from_gemini
13
 
14
- # ✅ Use a safe local cache dir
15
- cache_dir = "./hf_cache"
16
- os.makedirs(cache_dir, exist_ok=True) # create if it doesn't exist
17
 
18
  # Initialize FastAPI app
19
  app = FastAPI(title="Brain Tumor Detection API")
20
 
21
- # Enable CORS
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"],
@@ -27,27 +24,25 @@ app.add_middleware(
27
  allow_headers=["*"],
28
  )
29
 
30
- # Load tumor classification model
31
  btd_model_path = hf_hub_download(
32
  repo_id="Codewithsalty/brain-tumor-models",
33
- filename="BTD_model.pth",
34
- cache_dir=cache_dir
35
  )
36
  tumor_model = TumorClassification()
37
  tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
38
  tumor_model.eval()
39
 
40
- # Load glioma stage model
41
  glioma_model_path = hf_hub_download(
42
  repo_id="Codewithsalty/brain-tumor-models",
43
- filename="glioma_stages.pth",
44
- cache_dir=cache_dir
45
  )
46
  glioma_model = GliomaStageModel()
47
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
48
  glioma_model.eval()
49
 
50
- # Image preprocessing
51
  transform = transforms.Compose([
52
  transforms.Grayscale(),
53
  transforms.Resize((224, 224)),
@@ -55,17 +50,19 @@ transform = transforms.Compose([
55
  transforms.Normalize(mean=[0.5], std=[0.5]),
56
  ])
57
 
 
58
  @app.get("/")
59
  async def root():
60
  return {"message": "Brain Tumor Detection API is running."}
61
 
62
- # Labels
63
  labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
64
 
 
65
  @app.post("/predict-image")
66
  async def predict_image(file: UploadFile = File(...)):
67
  img_bytes = await file.read()
68
- img = Image.open(io.BytesIO(img_bytes)).convert("L")
69
  x = transform(img).unsqueeze(0)
70
 
71
  with torch.no_grad():
@@ -79,7 +76,7 @@ async def predict_image(file: UploadFile = File(...)):
79
  precautions = get_precautions_from_gemini(tumor_type)
80
  return {"tumor_type": tumor_type, "precaution": precautions}
81
 
82
- # Mutation input
83
  class MutationInput(BaseModel):
84
  gender: str
85
  age: float
@@ -91,6 +88,7 @@ class MutationInput(BaseModel):
91
  cic: int
92
  pik3ca: int
93
 
 
94
  @app.post("/predict-glioma-stage")
95
  async def predict_glioma_stage(data: MutationInput):
96
  gender_val = 0 if data.gender.lower() == 'm' else 1
@@ -106,7 +104,4 @@ async def predict_glioma_stage(data: MutationInput):
106
  stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
107
  return {"glioma_stage": stages[idx]}
108
 
109
- # Only needed for local development, not in Hugging Face Spaces
110
- if __name__ == "__main__":
111
- import uvicorn
112
- uvicorn.run("newapi:app", host="0.0.0.0", port=10000)
 
5
  from torchvision import transforms
6
  from PIL import Image
7
  import io
 
8
  from huggingface_hub import hf_hub_download
9
 
10
  from models.TumorModel import TumorClassification, GliomaStageModel
11
  from utils import get_precautions_from_gemini
12
 
13
+ # ✅ Let Hugging Face handle cache automatically — DO NOT manually create any folders
 
 
14
 
15
  # Initialize FastAPI app
16
  app = FastAPI(title="Brain Tumor Detection API")
17
 
18
+ # Enable CORS for all origins
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
 
24
  allow_headers=["*"],
25
  )
26
 
27
+ # Load Tumor Classification Model from Hugging Face
28
  btd_model_path = hf_hub_download(
29
  repo_id="Codewithsalty/brain-tumor-models",
30
+ filename="BTD_model.pth"
 
31
  )
32
  tumor_model = TumorClassification()
33
  tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
34
  tumor_model.eval()
35
 
36
+ # Load Glioma Stage Prediction Model from Hugging Face
37
  glioma_model_path = hf_hub_download(
38
  repo_id="Codewithsalty/brain-tumor-models",
39
+ filename="glioma_stages.pth"
 
40
  )
41
  glioma_model = GliomaStageModel()
42
  glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
43
  glioma_model.eval()
44
 
45
+ # Image preprocessing pipeline
46
  transform = transforms.Compose([
47
  transforms.Grayscale(),
48
  transforms.Resize((224, 224)),
 
50
  transforms.Normalize(mean=[0.5], std=[0.5]),
51
  ])
52
 
53
+ # Health check endpoint
54
  @app.get("/")
55
  async def root():
56
  return {"message": "Brain Tumor Detection API is running."}
57
 
58
+ # Tumor type labels
59
  labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
60
 
61
+ # Predict tumor type from uploaded image
62
  @app.post("/predict-image")
63
  async def predict_image(file: UploadFile = File(...)):
64
  img_bytes = await file.read()
65
+ img = Image.open(io.BytesIO(img_bytes)).convert("L") # Ensure grayscale
66
  x = transform(img).unsqueeze(0)
67
 
68
  with torch.no_grad():
 
76
  precautions = get_precautions_from_gemini(tumor_type)
77
  return {"tumor_type": tumor_type, "precaution": precautions}
78
 
79
+ # Input model for glioma mutation data
80
  class MutationInput(BaseModel):
81
  gender: str
82
  age: float
 
88
  cic: int
89
  pik3ca: int
90
 
91
+ # Predict glioma stage based on mutations
92
  @app.post("/predict-glioma-stage")
93
  async def predict_glioma_stage(data: MutationInput):
94
  gender_val = 0 if data.gender.lower() == 'm' else 1
 
104
  stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
105
  return {"glioma_stage": stages[idx]}
106
 
107
+ # No need to run uvicorn manually in Hugging Face Spaces