# brain-tumor-api / newapi.py
# Last commit: "Update newapi.py" (e8c5868, verified) by Codewithsalty; file size 2.8 kB.
import os
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch
import torchvision.transforms as transforms
from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
# ——— Constants ———
# Checkpoint locations. MODEL_DIR is a relative path, so it is resolved
# against the process's current working directory.
MODEL_DIR = "models"
BTD_FILENAME = "BTD_model.pth"       # weights for BrainTumorModel
GLIO_FILENAME = "glioma_stages.pth"  # weights for GliomaStageModel

# ——— App setup ———
app = FastAPI(title="Brain Tumor Detection API")

# CORS is fully open: every origin, HTTP method and request header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Inference device: models and input tensors are moved here (CPU only).
DEVICE = torch.device("cpu")

# Image preprocessing: resize to 224x224, convert the PIL image to a tensor,
# then normalise each of the 3 channels as (x - 0.5) / 0.5.
# NOTE(review): these values must match the preprocessing used when the
# checkpoints were trained — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
def load_model(cls, fname):
    """Build a model of class ``cls`` and load its weights from ``MODEL_DIR/fname``.

    Args:
        cls: Model class that can be constructed with no arguments and
            supports ``.to()``, ``.load_state_dict()`` and ``.eval()``
            (e.g. ``BrainTumorModel``).
        fname: Checkpoint file name inside ``MODEL_DIR``; it is loaded as a
            ``state_dict``.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        Errors raised by ``torch.load`` or ``load_state_dict`` (e.g. a
        key/shape mismatch) propagate unchanged.
    """
    path = os.path.join(MODEL_DIR, fname)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Model file not found: {path}")

    model = cls().to(DEVICE)
    # A state_dict only contains tensors in plain containers, so restrict
    # unpickling to those (weights_only=True) instead of allowing the
    # checkpoint to run arbitrary pickled code.
    state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    return model.eval()
# ——— Model loading ———
# Both globals always exist after import: they hold the loaded model, or None
# if that model failed to load. Each model is loaded independently, so a
# missing or corrupt checkpoint for one does not prevent the other from
# loading.
tumor_model = None
glioma_model = None

try:
    tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
except Exception as e:  # broad on purpose: keep the API process importable
    print("❌ Error loading tumor model:", e)

try:
    glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
except Exception as e:  # broad on purpose: keep the API process importable
    print("❌ Error loading glioma model:", e)
@app.get("/")
async def health():
    """Health check: report that the API process is running."""
    payload = {"status": "ok", "message": "API is up"}
    return payload
@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image into one of four tumour classes.

    Returns:
        For a glioma: ``{"tumor_type": "glioma", "next": "submit_mutation_data"}``,
        telling the client to follow up with mutation data.
        Otherwise: ``{"tumor_type": <label>, "precaution": <text>}``, where the
        precaution comes from ``get_precautions_from_gemini``.

    Raises:
        HTTPException(400): If the upload is not declared as an image, or its
            bytes cannot be decoded as an image.
    """
    # The content type is client-supplied and may be missing (None).
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(400, "Upload an image")

    # Corrupt or unsupported data makes PIL raise OSError
    # (UnidentifiedImageError is a subclass). That is a client error, not a
    # server crash.
    try:
        with Image.open(file.file) as raw:
            # convert() returns a new, fully loaded image, so it remains
            # usable after the context manager releases `raw`.
            img = raw.convert("RGB")
    except OSError as exc:
        raise HTTPException(400, "Could not decode the uploaded image") from exc

    # NOTE(review): inference and the precaution lookup run synchronously
    # inside this async handler, so they block the event loop while running.
    batch = transform(img).unsqueeze(0).to(DEVICE)  # add a batch dim of 1
    with torch.no_grad():
        logits = tumor_model(batch)
    idx = int(logits.argmax(1))

    # Assumed to match the class-index order used when training — confirm.
    labels = ["glioma", "meningioma", "notumor", "pituitary"]
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }
class MutationInput(BaseModel):
    """Request body for the glioma stage prediction endpoint.

    The field order mirrors the feature order fed to the glioma model:
    gender, age, then the seven marker fields.
    """

    gender: str  # free text; the stage endpoint encodes "m..." (e.g. "male") as 0, anything else as 1
    age: float  # passed to the model as-is; presumably years — TODO confirm
    idh1: int  # IDH1 marker; presumably 0 = not mutated, 1 = mutated — TODO confirm
    tp53: int  # TP53 marker (same assumed 0/1 encoding)
    atrx: int  # ATRX marker (same assumed 0/1 encoding)
    pten: int  # PTEN marker (same assumed 0/1 encoding)
    egfr: int  # EGFR marker (same assumed 0/1 encoding)
    cic: int  # CIC marker (same assumed 0/1 encoding)
    pik3ca: int  # PIK3CA marker (same assumed 0/1 encoding)
@app.post("/predict-glioma-stage/")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage from gender, age and mutation markers.

    The model receives a 1x9 float32 tensor in this order:
    gender, age, IDH1, TP53, ATRX, PTEN, EGFR, CIC, PIK3CA.
    Gender is encoded as 0 when it begins with "m" (case-insensitive, ignoring
    surrounding whitespace, e.g. "male", " M") and as 1 otherwise.

    Returns:
        ``{"glioma_stage": "Stage N"}`` with N in 1..4.
    """
    # Strip first: " male" would otherwise fail the startswith check and be
    # silently encoded as 1.
    gender_code = 0 if data.gender.strip().lower().startswith("m") else 1
    features = [
        gender_code,
        data.age,
        data.idh1,
        data.tp53,
        data.atrx,
        data.pten,
        data.egfr,
        data.cic,
        data.pik3ca,
    ]
    batch = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = glioma_model(batch)
    idx = int(logits.argmax(1))

    # Assumed to match the model's output-class order — confirm.
    stages = ["Stage 1", "Stage 2", "Stage 3", "Stage 4"]
    return {"glioma_stage": stages[idx]}