Codewithsalty commited on
Commit
39333b1
·
verified ·
1 Parent(s): 63e5aaf

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +86 -37
newapi.py CHANGED
@@ -1,51 +1,100 @@
1
- from fastapi import FastAPI, UploadFile, File
 
 
2
  from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import torch
5
  import torchvision.transforms as transforms
6
- from utils import BrainTumorModel, get_precautions_from_gemini
7
 
8
- app = FastAPI()
9
 
10
- # Load the model
11
- btd_model = BrainTumorModel()
12
- btd_model_path = "brain_tumor_model.pth"
 
13
 
14
- try:
15
- btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
16
- btd_model.eval()
17
- except Exception as e:
18
- print(f"❌ Error loading model: {e}")
 
 
 
 
19
 
20
- # Define image transform
 
21
  transform = transforms.Compose([
22
  transforms.Resize((224, 224)),
23
- transforms.ToTensor()
 
24
  ])
25
 
26
- # Class labels (adjust if your model uses different labels)
27
- classes = ['glioma', 'meningioma', 'notumor', 'pituitary']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  @app.get("/")
30
- def read_root():
31
- return {"message": "Brain Tumor Detection API is running 🚀"}
32
-
33
- @app.post("/predict")
34
- async def predict(file: UploadFile = File(...)):
35
- try:
36
- image = Image.open(file.file).convert("RGB")
37
- image = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
38
-
39
- with torch.no_grad():
40
- outputs = btd_model(image)
41
- _, predicted = torch.max(outputs.data, 1)
42
- predicted_class = classes[predicted.item()]
43
- precautions = get_precautions_from_gemini(predicted_class)
44
-
45
- return JSONResponse(content={
46
- "prediction": predicted_class,
47
- "precautions": precautions
48
- })
49
-
50
- except Exception as e:
51
- return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, File, UploadFile, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
5
  from PIL import Image
6
  import torch
7
  import torchvision.transforms as transforms
 
8
 
9
+ from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
10
 
11
+ # ---- Constants ----
12
+ MODEL_DIR = "models"
13
+ BTD_FILENAME = "BTD_model.pth"
14
+ GLIO_FILENAME = "glioma_stages.pth"
15
 
16
+ # ---- App setup ----
17
+ app = FastAPI(title="Brain Tumor Detection API")
18
+
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"], # adjust in production
22
+ allow_methods=["*"],
23
+ allow_headers=["*"],
24
+ )
25
 
26
+ # ---- Device & transforms ----
27
+ DEVICE = torch.device("cpu")
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
32
  ])
33
 
34
+ # ---- Load & init models ----
35
+ def load_model(cls, filename):
36
+ path = os.path.join(MODEL_DIR, filename)
37
+ if not os.path.isfile(path):
38
+ raise FileNotFoundError(f"Model file not found: {path}")
39
+ model = cls().to(DEVICE).eval()
40
+ model.load_state_dict(torch.load(path, map_location=DEVICE))
41
+ return model
42
+
43
+ try:
44
+ tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
45
+ glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
46
+ except Exception as e:
47
+ # During startup, any exception here will show in logs
48
+ print(f"❌ Error loading model: {e}")
49
+
50
+ # ---- Routes ----
51
 
52
  @app.get("/")
53
+ async def health():
54
+ return {"status": "ok", "message": "Brain Tumor API is live"}
55
+
56
+ @app.post("/predict-image/")
57
+ async def predict_image(file: UploadFile = File(...)):
58
+ if file.content_type.split("/")[0] != "image":
59
+ raise HTTPException(400, "Upload an image file")
60
+ img = Image.open(file.file).convert("RGB")
61
+ tensor = transform(img).unsqueeze(0).to(DEVICE)
62
+ with torch.no_grad():
63
+ out = tumor_model(tensor)
64
+ idx = torch.argmax(out, dim=1).item()
65
+
66
+ labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
67
+ tumor_type = labels[idx]
68
+
69
+ if tumor_type == "glioma":
70
+ return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
71
+ else:
72
+ return {
73
+ "tumor_type": tumor_type,
74
+ "precaution": get_precautions_from_gemini(tumor_type)
75
+ }
76
+
77
+ class MutationInput(BaseModel):
78
+ gender: str
79
+ age: float
80
+ idh1: int
81
+ tp53: int
82
+ atrx: int
83
+ pten: int
84
+ egfr: int
85
+ cic: int
86
+ pik3ca: int
87
+
88
+ @app.post("/predict-glioma-stage/")
89
+ async def predict_glioma_stage(data: MutationInput):
90
+ gender_val = 0 if data.gender.lower().startswith('m') else 1
91
+ features = [gender_val, data.age, data.idh1, data.tp53,
92
+ data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
93
+ tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
94
+
95
+ with torch.no_grad():
96
+ out = glioma_model(tensor)
97
+ idx = torch.argmax(out, dim=1).item()
98
+
99
+ stages = ['Stage 1','Stage 2','Stage 3','Stage 4']
100
+ return {"glioma_stage": stages[idx]}