Codewithsalty commited on
Commit
e8c5868
·
verified ·
1 Parent(s): 72e8105

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +38 -48
newapi.py CHANGED
@@ -2,99 +2,89 @@ import os
2
  from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
 
5
  from PIL import Image
6
  import torch
7
  import torchvision.transforms as transforms
8
 
9
  from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
10
 
11
- # ---- Constants ----
12
- MODEL_DIR = "models"
13
- BTD_FILENAME = "BTD_model.pth"
14
  GLIO_FILENAME = "glioma_stages.pth"
15
 
16
- # ---- App setup ----
17
  app = FastAPI(title="Brain Tumor Detection API")
18
-
19
  app.add_middleware(
20
  CORSMiddleware,
21
- allow_origins=["*"], # adjust in production
22
  allow_methods=["*"],
23
  allow_headers=["*"],
24
  )
25
 
26
- # ---- Device & transforms ----
27
  DEVICE = torch.device("cpu")
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
- transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
32
  ])
33
 
34
- # ---- Load & init models ----
35
- def load_model(cls, filename):
36
- path = os.path.join(MODEL_DIR, filename)
37
  if not os.path.isfile(path):
38
  raise FileNotFoundError(f"Model file not found: {path}")
39
- model = cls().to(DEVICE).eval()
40
- model.load_state_dict(torch.load(path, map_location=DEVICE))
41
- return model
42
 
43
  try:
44
  tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
45
  glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
46
  except Exception as e:
47
- # During startup, any exception here will show in logs
48
- print(f"❌ Error loading model: {e}")
49
-
50
- # ---- Routes ----
51
 
52
  @app.get("/")
53
  async def health():
54
- return {"status": "ok", "message": "Brain Tumor API is live"}
55
 
56
  @app.post("/predict-image/")
57
  async def predict_image(file: UploadFile = File(...)):
58
- if file.content_type.split("/")[0] != "image":
59
- raise HTTPException(400, "Upload an image file")
60
  img = Image.open(file.file).convert("RGB")
61
- tensor = transform(img).unsqueeze(0).to(DEVICE)
62
  with torch.no_grad():
63
- out = tumor_model(tensor)
64
- idx = torch.argmax(out, dim=1).item()
65
-
66
- labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
67
  tumor_type = labels[idx]
68
-
69
  if tumor_type == "glioma":
70
  return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
71
- else:
72
- return {
73
- "tumor_type": tumor_type,
74
- "precaution": get_precautions_from_gemini(tumor_type)
75
- }
76
 
77
  class MutationInput(BaseModel):
78
  gender: str
79
- age: float
80
- idh1: int
81
- tp53: int
82
- atrx: int
83
- pten: int
84
- egfr: int
85
- cic: int
86
  pik3ca: int
87
 
88
  @app.post("/predict-glioma-stage/")
89
  async def predict_glioma_stage(data: MutationInput):
90
- gender_val = 0 if data.gender.lower().startswith('m') else 1
91
- features = [gender_val, data.age, data.idh1, data.tp53,
92
- data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
93
- tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
94
-
95
  with torch.no_grad():
96
- out = glioma_model(tensor)
97
- idx = torch.argmax(out, dim=1).item()
98
-
99
- stages = ['Stage 1','Stage 2','Stage 3','Stage 4']
100
  return {"glioma_stage": stages[idx]}
 
2
  from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
  from PIL import Image
7
  import torch
8
  import torchvision.transforms as transforms
9
 
10
  from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
11
 
12
+ # ——— Constants ———
13
+ MODEL_DIR = "models"
14
+ BTD_FILENAME = "BTD_model.pth"
15
  GLIO_FILENAME = "glioma_stages.pth"
16
 
17
+ # ——— App setup ———
18
  app = FastAPI(title="Brain Tumor Detection API")
 
19
  app.add_middleware(
20
  CORSMiddleware,
21
+ allow_origins=["*"],
22
  allow_methods=["*"],
23
  allow_headers=["*"],
24
  )
25
 
 
26
  DEVICE = torch.device("cpu")
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
30
+ transforms.Normalize([0.5]*3, [0.5]*3),
31
  ])
32
 
33
+ def load_model(cls, fname):
34
+ path = os.path.join(MODEL_DIR, fname)
 
35
  if not os.path.isfile(path):
36
  raise FileNotFoundError(f"Model file not found: {path}")
37
+ m = cls().to(DEVICE)
38
+ m.load_state_dict(torch.load(path, map_location=DEVICE))
39
+ return m.eval()
40
 
41
  try:
42
  tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
43
  glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
44
  except Exception as e:
45
+ print("❌ Error loading model:", e)
 
 
 
46
 
47
  @app.get("/")
48
  async def health():
49
+ return {"status": "ok", "message": "API is up"}
50
 
51
  @app.post("/predict-image/")
52
  async def predict_image(file: UploadFile = File(...)):
53
+ if not file.content_type.startswith("image/"):
54
+ raise HTTPException(400, "Upload an image")
55
  img = Image.open(file.file).convert("RGB")
56
+ t = transform(img).unsqueeze(0).to(DEVICE)
57
  with torch.no_grad():
58
+ out = tumor_model(t)
59
+ idx = int(out.argmax(1))
60
+ labels = ["glioma","meningioma","notumor","pituitary"]
 
61
  tumor_type = labels[idx]
 
62
  if tumor_type == "glioma":
63
  return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
64
+ return {
65
+ "tumor_type": tumor_type,
66
+ "precaution": get_precautions_from_gemini(tumor_type)
67
+ }
 
68
 
69
  class MutationInput(BaseModel):
70
  gender: str
71
+ age: float
72
+ idh1: int
73
+ tp53: int
74
+ atrx: int
75
+ pten: int
76
+ egfr: int
77
+ cic: int
78
  pik3ca: int
79
 
80
  @app.post("/predict-glioma-stage/")
81
  async def predict_glioma_stage(data: MutationInput):
82
+ gen = 0 if data.gender.lower().startswith("m") else 1
83
+ feats = [gen, data.age, data.idh1, data.tp53,
84
+ data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
85
+ t = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(DEVICE)
 
86
  with torch.no_grad():
87
+ out = glioma_model(t)
88
+ idx = int(out.argmax(1))
89
+ stages = ["Stage 1","Stage 2","Stage 3","Stage 4"]
 
90
  return {"glioma_stage": stages[idx]}