File size: 2,798 Bytes
39333b1
 
 
a75fe7e
e8c5868
28addcf
a75fe7e
 
e4acaca
39333b1
e4acaca
e8c5868
 
 
39333b1
28addcf
e8c5868
39333b1
 
 
e8c5868
39333b1
 
 
28addcf
39333b1
e4acaca
28addcf
39333b1
e8c5868
e4acaca
 
e8c5868
 
39333b1
 
e8c5868
 
 
39333b1
 
 
 
 
e8c5868
28addcf
342c341
39333b1
e8c5868
39333b1
 
 
e8c5868
 
39333b1
e8c5868
39333b1
e8c5868
 
 
39333b1
 
 
e8c5868
 
 
 
39333b1
 
 
e8c5868
 
 
 
 
 
 
39333b1
 
 
 
e8c5868
 
 
 
39333b1
e8c5868
 
 
39333b1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch
import torchvision.transforms as transforms

from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini

# ——— Constants ———
# Weight files live in MODEL_DIR; the path is relative, so it is resolved
# against the process's current working directory.
MODEL_DIR = "models"
BTD_FILENAME = "BTD_model.pth"        # weights loaded into BrainTumorModel
GLIO_FILENAME = "glioma_stages.pth"   # weights loaded into GliomaStageModel

# ——— App setup ———
app = FastAPI(title="Brain Tumor Detection API")

# NOTE(review): CORS is fully open (any origin, method and header); consider
# restricting allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# All inference runs on the CPU.
DEVICE = torch.device("cpu")

# Image preprocessing: resize to 224x224, convert to a [0, 1] float tensor,
# then shift each of the three channels to [-1, 1] (mean 0.5, std 0.5).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

def load_model(cls, fname):
    """Build ``cls`` and load its saved state dict from ``MODEL_DIR/fname``.

    Args:
        cls: Model class; instantiated with no arguments.
        fname: File name of the saved state dict inside ``MODEL_DIR``.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``MODEL_DIR/fname`` is not an existing file.
    """
    weights_path = os.path.join(MODEL_DIR, fname)
    if os.path.isfile(weights_path):
        net = cls().to(DEVICE)
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted weight files (weights_only=True is an option on newer torch).
        state = torch.load(weights_path, map_location=DEVICE)
        net.load_state_dict(state)
        return net.eval()
    raise FileNotFoundError(f"Model file not found: {weights_path}")

# Load both models once at import time. Each model is loaded on its own so a
# failure in one does not prevent the other from loading. A model that fails
# to load is bound to None (and the error printed) rather than left undefined,
# so the service still starts and the failure is visible in the logs.
tumor_model = None
glioma_model = None

try:
    tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
except Exception as e:  # best-effort startup: report and keep serving
    print(f"❌ Error loading model {BTD_FILENAME}:", e)

try:
    glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
except Exception as e:  # best-effort startup: report and keep serving
    print(f"❌ Error loading model {GLIO_FILENAME}:", e)

@app.get("/")
async def health():
    """Health-check endpoint; always answers with a static OK payload."""
    payload = {"status": "ok", "message": "API is up"}
    return payload

@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded brain image into one of four tumor classes.

    Returns:
        ``{"tumor_type": "glioma", "next": "submit_mutation_data"}`` for a
        glioma (the ``next`` hint presumably points the client to
        ``/predict-glioma-stage/``); otherwise
        ``{"tumor_type": ..., "precaution": ...}`` where ``precaution`` comes
        from ``get_precautions_from_gemini``.

    Raises:
        HTTPException(400): The upload is not declared as an image, or its
            bytes cannot be decoded as one.
        HTTPException(503): The tumor classification model is not loaded.
    """
    # content_type is None when the client sends no Content-Type header;
    # treat that as "not an image" instead of crashing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "Upload an image")

    # Model loading at startup swallows errors, so the model name may be
    # unbound (or None); answer 503 rather than failing with a NameError.
    model = globals().get("tumor_model")
    if model is None:
        raise HTTPException(503, "Tumor detection model is not loaded")

    try:
        img = Image.open(file.file).convert("RGB")
    except OSError as exc:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for data
        # it cannot identify, and OSError for truncated/corrupt images.
        raise HTTPException(400, "Uploaded file is not a readable image") from exc

    # NOTE(review): inference below is synchronous CPU work inside an async
    # handler and blocks the event loop while it runs.
    batch = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = model(batch)
        idx = int(scores.argmax(1))

    # Assumes the model's output index order matches this list — confirm
    # against the training label encoding.
    labels = ["glioma", "meningioma", "notumor", "pituitary"]
    tumor_type = labels[idx]
    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }

class MutationInput(BaseModel):
    """Request body for ``POST /predict-glioma-stage/``.

    predict_glioma_stage feeds these values to the glioma-stage model as
    float32 features in the order: gender code, age, idh1, tp53, atrx, pten,
    egfr, cic, pik3ca.
    """
    # predict_glioma_stage reduces this to a 0/1 code: 0 when it starts with
    # "m" (case-insensitive), 1 for anything else.
    gender: str
    age:    float
    # NOTE(review): the mutation fields below are presumably 0/1 flags
    # (absent/present); no range is enforced here — confirm the encoding the
    # model was trained with.
    idh1:   int
    tp53:   int
    atrx:   int
    pten:   int
    egfr:   int
    cic:    int
    pik3ca: int

@app.post("/predict-glioma-stage/")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage from patient gender, age and mutation values.

    The model input is a single row of float32 features in the order:
    gender code, age, IDH1, TP53, ATRX, PTEN, EGFR, CIC, PIK3CA. Gender is
    encoded as 0 when it starts with "m" (case-insensitive, surrounding
    whitespace ignored) and 1 otherwise.

    Returns:
        ``{"glioma_stage": "Stage N"}`` with N in 1..4.

    Raises:
        HTTPException(503): The glioma stage model is not loaded.
    """
    # Model loading at startup swallows errors, so the model name may be
    # unbound (or None); answer 503 rather than failing with a NameError.
    model = globals().get("glioma_model")
    if model is None:
        raise HTTPException(503, "Glioma stage model is not loaded")

    # strip() so values such as " Male" from form inputs still encode as 0.
    gen = 0 if data.gender.strip().lower().startswith("m") else 1
    feats = [gen, data.age, data.idh1, data.tp53,
             data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
    row = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        idx = int(model(row).argmax(1))

    # Assumes output index i corresponds to "Stage i+1" — confirm against
    # the training label encoding.
    stages = ["Stage 1", "Stage 2", "Stage 3", "Stage 4"]
    return {"glioma_stage": stages[idx]}