# brain-tumor-api / newapi.py
# Author: Codewithsalty
# Last commit: 39333b1 "Update newapi.py" (verified)
# Size: 3.05 kB
import logging
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
# ---- Constants ----
# Directory holding the checkpoint files. Defaults to "models", a relative
# path (so it is resolved against the process's working directory); set the
# MODEL_DIR environment variable to load checkpoints from somewhere else.
MODEL_DIR = os.environ.get("MODEL_DIR", "models")
BTD_FILENAME = "BTD_model.pth"  # checkpoint for the tumor-type classifier
GLIO_FILENAME = "glioma_stages.pth"  # checkpoint for the glioma stage model
# ---- App setup ----
app = FastAPI(title="Brain Tumor Detection API")

# CORS is fully permissive: any origin, any method, any header.
# TODO: restrict allow_origins before exposing this in production.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ---- Device & transforms ----
# Inference always runs on the CPU.
DEVICE = torch.device("cpu")

# Image preprocessing shared by the image route: resize to 224x224, convert
# to a tensor, then normalize every channel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# ---- Load & init models ----
def load_model(cls, filename):
    """Create a ``cls`` network on DEVICE and load its weights from MODEL_DIR.

    Args:
        cls: Model class; it is instantiated with no arguments.
        filename: Checkpoint file name, joined onto MODEL_DIR.

    Returns:
        The model, moved to DEVICE and switched to eval mode, with the
        checkpoint passed to ``load_state_dict``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist. Errors from
            ``torch.load`` / ``load_state_dict`` propagate unchanged.
    """
    checkpoint = os.path.join(MODEL_DIR, filename)
    if os.path.isfile(checkpoint):
        net = cls().to(DEVICE).eval()
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted checkpoints (consider weights_only=True if the files are
        # plain state dicts).
        state = torch.load(checkpoint, map_location=DEVICE)
        net.load_state_dict(state)
        return net
    raise FileNotFoundError(f"Model file not found: {checkpoint}")
# Model loading is best-effort: a failure is logged with its traceback and the
# affected model stays ``None``, so the app still starts (and the health route
# keeps answering). Each checkpoint is loaded on its own so that one bad file
# does not prevent the other model from loading.
logger = logging.getLogger(__name__)

tumor_model = None
glioma_model = None

try:
    tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
except Exception:
    logger.exception("❌ Error loading model %s", BTD_FILENAME)

try:
    glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
except Exception:
    logger.exception("❌ Error loading model %s", GLIO_FILENAME)
# ---- Routes ----
@app.get("/")
async def health():
    """Report that the API process is up.

    Always returns the same static payload; it does not check whether the
    models were loaded.
    """
    return dict(status="ok", message="Brain Tumor API is live")
@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image into one of four tumor labels.

    Args:
        file: Multipart upload whose content type must be ``image/*``.

    Returns:
        For a "glioma" prediction, ``{"tumor_type": "glioma",
        "next": "submit_mutation_data"}`` (presumably telling the client to
        call ``/predict-glioma-stage/`` next). Otherwise ``{"tumor_type": ...,
        "precaution": ...}``, with the precaution text produced by
        ``get_precautions_from_gemini``.

    Raises:
        HTTPException: 400 if the upload is not an image or cannot be decoded;
            503 if the tumor model failed to load at startup.
    """
    # content_type may be None when the client omits the header; treat that as
    # "not an image" rather than crashing with AttributeError.
    content_type = file.content_type or ""
    if content_type.split("/")[0] != "image":
        raise HTTPException(400, "Upload an image file")
    if tumor_model is None:
        raise HTTPException(503, "Tumor model is not loaded")

    # Corrupt/unrecognized data raises OSError (incl. UnidentifiedImageError);
    # oversized images raise DecompressionBombError. Both are client errors.
    try:
        with Image.open(file.file) as src:
            img = src.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(400, "Could not decode the uploaded image") from err

    # NOTE(review): the forward pass below and get_precautions_from_gemini are
    # synchronous calls inside an async route, so they block the event loop
    # while they run; consider a plain `def` route or run_in_threadpool.
    tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = tumor_model(tensor)
    idx = torch.argmax(out, dim=1).item()

    # Assumed to match the class-index order used in training — TODO confirm.
    labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }
class MutationInput(BaseModel):
    """Request body for ``/predict-glioma-stage/``.

    The route packs these fields into the model's feature vector in a fixed
    order, so field meanings must match what the glioma model was trained on.
    """
    # Encoded to 0/1 by predict_glioma_stage based on its first letter
    # ("m..." -> 0, anything else -> 1); see that route.
    gender: str
    age: float
    # Gene fields below are passed to the model as numbers; presumably 0/1
    # mutation flags — TODO confirm the expected encoding.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int
@app.post("/predict-glioma-stage/")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage from demographic and mutation features.

    The request fields are packed, in a fixed order, into a (1, 9) float32
    tensor and fed to ``glioma_model``. The arg-max class index is mapped to
    a stage label.

    Args:
        data: Validated request body.

    Returns:
        ``{"glioma_stage": "Stage N"}`` with N in 1..4.

    Raises:
        HTTPException: 503 if the glioma model failed to load at startup.
    """
    if glioma_model is None:
        raise HTTPException(503, "Glioma stage model is not loaded")

    # Strip surrounding whitespace so " male" is not misread as female.
    # Anything not starting with "m" is still encoded as 1.
    gender_val = 0 if data.gender.strip().lower().startswith('m') else 1
    # Feature order must match the order used in training — TODO confirm.
    features = [gender_val, data.age, data.idh1, data.tp53,
                data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
    tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = glioma_model(tensor)
    idx = torch.argmax(out, dim=1).item()
    # NOTE(review): assumes the model outputs exactly 4 classes; a wider head
    # would raise IndexError here.
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    return {"glioma_stage": stages[idx]}