File size: 3,050 Bytes
39333b1
 
 
a75fe7e
28addcf
a75fe7e
 
e4acaca
39333b1
e4acaca
39333b1
 
 
 
28addcf
39333b1
 
 
 
 
 
 
 
 
28addcf
39333b1
 
e4acaca
28addcf
39333b1
 
e4acaca
 
39333b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28addcf
342c341
39333b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini

# ---- Constants ----
# Relative path, so it is resolved against the process working directory.
MODEL_DIR = "models"
# State-dict files (loaded with load_state_dict) found inside MODEL_DIR.
BTD_FILENAME = "BTD_model.pth"       # tumor-type classifier weights
GLIO_FILENAME = "glioma_stages.pth"  # glioma-stage classifier weights

# ---- App setup ----
app = FastAPI(title="Brain Tumor Detection API")

# Allowed CORS origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable. When it is unset or empty, the permissive default
# "*" is kept, so existing deployments behave as before; set it explicitly
# in production.
_cors_origins = [
    origin.strip()
    for origin in (os.environ.get("CORS_ALLOW_ORIGINS") or "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---- Device & transforms ----
# All inference in this module runs on the CPU.
DEVICE = torch.device("cpu")

# Per-channel normalization constants: (x - 0.5) / 0.5 for each RGB channel.
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]

# Preprocessing pipeline applied to every uploaded image:
# resize to 224x224, convert to a tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# ---- Load & init models ----
def load_model(cls, filename):
    """Build ``cls()`` and load its weights from ``MODEL_DIR/filename``.

    The fresh model is moved to ``DEVICE`` and switched to eval mode before
    the state dict is loaded into it.

    Args:
        cls: model class, instantiated with no arguments.
        filename: name of the state-dict file inside ``MODEL_DIR``.

    Returns:
        The model instance with its weights loaded.

    Raises:
        FileNotFoundError: if ``MODEL_DIR/filename`` is not an existing file.
    """
    weights_path = os.path.join(MODEL_DIR, filename)
    if os.path.isfile(weights_path):
        network = cls().to(DEVICE).eval()
        # NOTE(review): torch.load unpickles the file; only load trusted weights.
        state = torch.load(weights_path, map_location=DEVICE)
        network.load_state_dict(state)
        return network
    raise FileNotFoundError(f"Model file not found: {weights_path}")

def _load_or_none(cls, filename):
    """Load one model with ``load_model``; on any failure, report it and return None.

    Each model is attempted independently, so one bad weights file does not
    stop the other model from loading. The module-level name is also always
    bound (to ``None`` on failure) instead of being left undefined.
    """
    try:
        return load_model(cls, filename)
    except Exception as e:
        # Startup boundary: surface the failure in the logs (with the file
        # that failed) but keep the module importable.
        print(f"❌ Error loading model {filename}: {e}")
        return None


tumor_model = _load_or_none(BrainTumorModel, BTD_FILENAME)
glioma_model = _load_or_none(GliomaStageModel, GLIO_FILENAME)

# ---- Routes ----

@app.get("/")
async def health():
    """Health-check endpoint: report that the API is up and serving."""
    payload = {"status": "ok", "message": "Brain Tumor API is live"}
    return payload

@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image into one of four tumor labels.

    Returns:
        ``{"tumor_type": ..., "next": "submit_mutation_data"}`` when the label
        is ``"glioma"``. Otherwise ``{"tumor_type": ..., "precaution": ...}``,
        where the precaution text comes from ``get_precautions_from_gemini``.

    Raises:
        HTTPException(400): the upload is not declared as an image, or its
            bytes cannot be decoded as an image.
    """
    # content_type is None when the client omits the Content-Type header;
    # treat that as "not an image" instead of crashing on None.split().
    media_type = file.content_type or ""
    if media_type.split("/")[0] != "image":
        raise HTTPException(400, "Upload an image file")

    try:
        img = Image.open(file.file).convert("RGB")
    except OSError as exc:
        # PIL raises UnidentifiedImageError (an OSError subclass) for unknown
        # formats, and OSError for truncated data: a client error, not a 500.
        raise HTTPException(400, "Could not decode the uploaded image") from exc

    tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = tumor_model(tensor)
        idx = torch.argmax(out, dim=1).item()

    # NOTE(review): label order must match the classifier's output indices
    # used at training time — confirm against the training code.
    labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        # Stage prediction needs mutation data; the client is told to send it.
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }

class MutationInput(BaseModel):
    """Request body for ``/predict-glioma-stage/``.

    The fields are turned, in this declaration order, into the feature vector
    fed to the glioma-stage model by ``predict_glioma_stage``.
    """
    # Encoded as 0 if it starts with "m" (case-insensitive), otherwise 1.
    gender: str
    # NOTE(review): unit presumably years — confirm.
    age: float
    # Gene-mutation features; presumably 0 = not mutated, 1 = mutated — TODO confirm.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int

@app.post("/predict-glioma-stage/")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage from demographic and mutation features.

    A single feature row is built in the order gender, age, IDH1, TP53, ATRX,
    PTEN, EGFR, CIC, PIK3CA. It is run through the glioma-stage model as a
    batch of one, and the arg-max class is mapped to a stage label.

    Returns:
        ``{"glioma_stage": "Stage N"}`` with N in 1..4.
    """
    # Binary gender encoding: starts with "m"/"M" -> 0, anything else -> 1.
    is_male = data.gender.lower().startswith('m')
    row = [
        0 if is_male else 1,
        data.age,
        data.idh1,
        data.tp53,
        data.atrx,
        data.pten,
        data.egfr,
        data.cic,
        data.pik3ca,
    ]
    # Wrap the row in a list to get a (1, 9) batch.
    batch = torch.tensor([row], dtype=torch.float32).to(DEVICE)

    with torch.no_grad():
        logits = glioma_model(batch)
        best = torch.argmax(logits, dim=1).item()

    stage_names = ('Stage 1', 'Stage 2', 'Stage 3', 'Stage 4')
    return {"glioma_stage": stage_names[best]}