"""FastAPI service exposing brain-tumor image classification and glioma staging."""

import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini

# ---- Constants ----
MODEL_DIR = "models"
BTD_FILENAME = "BTD_model.pth"
GLIO_FILENAME = "glioma_stages.pth"

# Class names indexed by the argmax of each model's output.
TUMOR_LABELS = ("glioma", "meningioma", "notumor", "pituitary")
GLIOMA_STAGES = ("Stage 1", "Stage 2", "Stage 3", "Stage 4")

# ---- App setup ----
app = FastAPI(title="Brain Tumor Detection API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # adjust in production
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---- Device & transforms ----
DEVICE = torch.device("cpu")
# Resize to 224x224, convert to a tensor, then normalize every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# ---- Load & init models ----
def load_model(cls, filename):
    """Instantiate ``cls`` and load its weights from ``MODEL_DIR/filename``.

    Args:
        cls: Zero-argument model class to instantiate.
        filename: Name of the state-dict file inside ``MODEL_DIR``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the weights file does not exist.
    """
    path = os.path.join(MODEL_DIR, filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    model = cls().to(DEVICE).eval()
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True where the installed PyTorch
    # supports it).
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    return model


# Default to None so that a failed load yields a clean 503 at request time
# instead of a NameError inside the handlers.
tumor_model = None
glioma_model = None
try:
    tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
    glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
except Exception as e:
    # Deliberately best-effort: the app still starts (so the health route
    # works) and the failure is reported on stdout.
    print(f"❌ Error loading model: {e}")


def _require_model(model):
    """Return ``model``, or raise HTTP 503 if it failed to load at startup."""
    if model is None:
        raise HTTPException(503, "Model is not available")
    return model


# ---- Routes ----
@app.get("/")
async def health():
    """Liveness probe; returns a static status payload."""
    return {"status": "ok", "message": "Brain Tumor API is live"}


@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image into one of ``TUMOR_LABELS``.

    Returns:
        For ``glioma``: the tumor type plus a ``next`` hint telling the client
        to submit mutation data. For every other label: the tumor type plus
        the result of ``get_precautions_from_gemini``.

    Raises:
        HTTPException: 503 if the tumor model is not loaded; 400 if the upload
            is not declared as an image or cannot be decoded as one.
    """
    model = _require_model(tumor_model)

    # content_type may be None when the client sends no Content-Type header.
    content_type = file.content_type or ""
    if content_type.split("/")[0] != "image":
        raise HTTPException(400, "Upload an image file")

    try:
        # convert() forces the pixel data to be decoded, so truncated or
        # corrupt files fail here rather than later in the transform.
        img = Image.open(file.file).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(400, "Uploaded file is not a valid image") from err

    tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = model(tensor)
        idx = torch.argmax(out, dim=1).item()

    tumor_type = TUMOR_LABELS[idx]
    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    # NOTE(review): this helper is called synchronously inside an async
    # handler; if it performs network I/O it blocks the event loop - confirm.
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }


class MutationInput(BaseModel):
    """Request body for glioma staging.

    ``gender`` is free text (a value starting with "m" is encoded as 0, anything
    else as 1). The remaining integer fields are passed to the model as-is;
    presumably they are 0/1 mutation flags - TODO confirm against training data.
    """

    gender: str
    age: float
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int


@app.post("/predict-glioma-stage/")
async def predict_glioma_stage(data: MutationInput):
    """Predict the glioma stage from gender, age and mutation fields.

    Returns:
        ``{"glioma_stage": <one of GLIOMA_STAGES>}``.

    Raises:
        HTTPException: 503 if the glioma model is not loaded.
    """
    model = _require_model(glioma_model)

    # Strip first so inputs such as " Male" are not mis-encoded as 1.
    gender_val = 0 if data.gender.strip().lower().startswith("m") else 1
    # Feature order must match what the model was trained on.
    features = [
        gender_val, data.age, data.idh1, data.tp53, data.atrx,
        data.pten, data.egfr, data.cic, data.pik3ca,
    ]
    tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = model(tensor)
        idx = torch.argmax(out, dim=1).item()
    return {"glioma_stage": GLIOMA_STAGES[idx]}