import logging
import os
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch
import torchvision.transforms as transforms
from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini

logger = logging.getLogger(__name__)

# ——— Constants ———
MODEL_DIR = "models"
BTD_FILENAME = "BTD_model.pth"
GLIO_FILENAME = "glioma_stages.pth"

# Class names indexed by the argmax of each model's output.
TUMOR_LABELS = ["glioma", "meningioma", "notumor", "pituitary"]
GLIOMA_STAGES = ["Stage 1", "Stage 2", "Stage 3", "Stage 4"]

# ——— App setup ———
app = FastAPI(title="Brain Tumor Detection API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

DEVICE = torch.device("cpu")

# Preprocessing for the tumor classifier: resize to 224x224, convert to a
# tensor, then normalise each of the 3 channels with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


def load_model(cls, fname):
    """Instantiate ``cls``, load weights from ``MODEL_DIR/fname``, return it in eval mode.

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    path = os.path.join(MODEL_DIR, fname)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    m = cls().to(DEVICE)
    # NOTE(review): torch.load unpickles the file, so only trusted weight files
    # may live in MODEL_DIR. Consider weights_only=True if the installed torch
    # version supports it.
    m.load_state_dict(torch.load(path, map_location=DEVICE))
    return m.eval()


def _try_load(cls, fname):
    """Load one model; log and return None on failure so the API can still start."""
    try:
        return load_model(cls, fname)
    except Exception:
        logger.exception("❌ Error loading model: %s", fname)
        return None


# Each model is loaded independently so a failure in one does not prevent the
# other from loading. Previously a failure left these names unbound, which made
# every request crash with NameError; now endpoints answer 503 instead.
tumor_model = _try_load(BrainTumorModel, BTD_FILENAME)
glioma_model = _try_load(GliomaStageModel, GLIO_FILENAME)


def _require_model(model, name):
    """Return ``model``, or raise HTTP 503 if it failed to load at startup."""
    if model is None:
        raise HTTPException(503, f"{name} model is not loaded")
    return model


@app.get("/")
async def health():
    """Liveness check."""
    return {"status": "ok", "message": "API is up"}


# The prediction endpoints are plain ``def`` (not ``async def``): they do
# blocking CPU-bound inference and call get_precautions_from_gemini
# synchronously, and FastAPI runs ``def`` endpoints in a threadpool so this
# work does not stall the event loop.
@app.post("/predict-image/")
def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image into one of TUMOR_LABELS.

    For "glioma" the response asks the client to submit mutation data next;
    otherwise it includes precautions from ``get_precautions_from_gemini``.

    Raises:
        HTTPException 400: upload is not declared as an image or cannot be decoded.
        HTTPException 503: the tumor model failed to load at startup.
    """
    # content_type may be None when the client omits the part's Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "Upload an image")
    model = _require_model(tumor_model, "Tumor")
    try:
        # PIL's UnidentifiedImageError (and truncated-file errors) are OSErrors.
        with Image.open(file.file) as raw:
            img = raw.convert("RGB")
    except OSError as err:
        raise HTTPException(400, "Could not decode the uploaded image") from err
    t = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = model(t)
    tumor_type = TUMOR_LABELS[int(out.argmax(1))]
    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }


class MutationInput(BaseModel):
    """Request body for /predict-glioma-stage/: patient gender, age and gene flags."""

    gender: str
    age: float
    # NOTE(review): the gene fields below are presumably 0/1 mutation flags —
    # confirm against the glioma model's training data.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int


@app.post("/predict-glioma-stage/")
def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage (one of GLIOMA_STAGES) from mutation data.

    Raises:
        HTTPException 503: the glioma-stage model failed to load at startup.
    """
    model = _require_model(glioma_model, "Glioma stage")
    # Gender encoding: 0 if the value starts with "m" (ignoring case and
    # surrounding whitespace), else 1.
    gen = 0 if data.gender.strip().lower().startswith("m") else 1
    # Feature order must match the order the model was trained with.
    feats = [
        gen, data.age, data.idh1, data.tp53, data.atrx,
        data.pten, data.egfr, data.cic, data.pik3ca,
    ]
    t = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = model(t)
    return {"glioma_stage": GLIOMA_STAGES[int(out.argmax(1))]}