Spaces:
Runtime error
Runtime error
File size: 2,798 Bytes
39333b1 a75fe7e e8c5868 28addcf a75fe7e e4acaca 39333b1 e4acaca e8c5868 39333b1 28addcf e8c5868 39333b1 e8c5868 39333b1 28addcf 39333b1 e4acaca 28addcf 39333b1 e8c5868 e4acaca e8c5868 39333b1 e8c5868 39333b1 e8c5868 28addcf 342c341 39333b1 e8c5868 39333b1 e8c5868 39333b1 e8c5868 39333b1 e8c5868 39333b1 e8c5868 39333b1 e8c5868 39333b1 e8c5868 39333b1 e8c5868 39333b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch
import torchvision.transforms as transforms
from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
# ——— Constants ———
# Model weights are read from MODEL_DIR. Because load_model joins it with
# os.path.join, this is a path relative to the process's working directory.
MODEL_DIR: str = "models"
# Saved state dict for the four-class tumour classifier (tumor_model).
BTD_FILENAME: str = "BTD_model.pth"
# Saved state dict for the glioma stage classifier (glioma_model).
GLIO_FILENAME: str = "glioma_stages.pth"
# ——— App setup ———
app = FastAPI(title="Brain Tumor Detection API")

# Fully permissive CORS: any origin, any HTTP method, any request header.
# Credentials are not enabled (allow_credentials keeps its default).
# NOTE(review): presumably so a separately hosted front end can call the API;
# tighten allow_origins if the set of clients is known.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# All inference in this service runs on the CPU.
DEVICE = torch.device("cpu")

# Preprocessing for uploaded images:
# 1. Resize to 224x224.
# 2. Convert to a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
# 3. Map each of the three channels to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): this presumably has to match the preprocessing used when
# BTD_model.pth was trained — confirm before changing any of these values.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
def load_model(cls, fname):
    """Build ``cls()``, load its saved weights from MODEL_DIR, and return it in eval mode.

    Args:
        cls: Zero-argument model constructor; the result must support ``.to``,
            ``.load_state_dict`` and ``.eval`` (presumably an ``nn.Module``
            subclass).
        fname: File name of the saved state dict inside MODEL_DIR.

    Returns:
        Whatever ``model.eval()`` returns — the model itself for an ``nn.Module``.

    Raises:
        FileNotFoundError: If MODEL_DIR/fname is not an existing file.
    """
    weights_path = os.path.join(MODEL_DIR, fname)
    # Guard: fail with a clear message before constructing the model at all.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Model file not found: {weights_path}")

    model = cls().to(DEVICE)
    # map_location remaps tensors saved on another device onto DEVICE.
    state_dict = torch.load(weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return model.eval()
def _try_load_model(cls, fname):
    """Best-effort wrapper around load_model: return the model, or None on failure.

    Any exception is printed instead of raised, so the API process still
    starts (the health check stays reachable) when a weight file is missing
    or incompatible. Callers must treat None as "model unavailable".
    """
    try:
        return load_model(cls, fname)
    except Exception as e:  # top-level startup boundary: log and degrade
        print(f"❌ Error loading model {fname}:", e)
        return None


# Load each model independently at import time. Previously a failure here
# left the global names unbound, so every later request crashed with
# NameError. A failure loading the first model also meant the second was
# never attempted.
tumor_model = _try_load_model(BrainTumorModel, BTD_FILENAME)
glioma_model = _try_load_model(GliomaStageModel, GLIO_FILENAME)
@app.get("/")
async def health():
    """Health check on ``GET /``: report that the API process is serving requests."""
    status_payload = dict(status="ok", message="API is up")
    return status_payload
@app.post("/predict-image/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image as glioma, meningioma, notumor or pituitary.

    Response shapes:
        * Glioma: ``{"tumor_type": "glioma", "next": "submit_mutation_data"}``.
          This is presumably a cue for the client to call
          /predict-glioma-stage/ next.
        * Any other class: ``{"tumor_type": ..., "precaution": ...}``, where
          the precaution text comes from get_precautions_from_gemini.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*`` or
            Pillow cannot decode it; 503 if the tumour model is unavailable.
    """
    # content_type is None when the multipart part has no Content-Type
    # header; treat that as "not an image" rather than crashing on None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(400, "Upload an image")
    if tumor_model is None:
        raise HTTPException(503, "Tumor model is not loaded")

    try:
        # The context manager releases Pillow's decoder resources once the
        # RGB copy has been made.
        with Image.open(file.file) as raw:
            img = raw.convert("RGB")
    except OSError:
        # PIL.UnidentifiedImageError and truncated-image errors derive from
        # OSError. A corrupt or non-image upload is a client error, not a 500.
        raise HTTPException(400, "Uploaded file is not a readable image") from None

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop. get_precautions_from_gemini is also called without
    # await, so any I/O it performs blocks as well.
    batch = transform(img).unsqueeze(0).to(DEVICE)  # add a batch dimension of 1
    with torch.no_grad():
        scores = tumor_model(batch)
    idx = int(scores.argmax(1))  # single-item batch -> one class index

    labels = ["glioma", "meningioma", "notumor", "pituitary"]
    tumor_type = labels[idx]
    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    return {
        "tumor_type": tumor_type,
        "precaution": get_precautions_from_gemini(tumor_type),
    }
class MutationInput(BaseModel):
    """Request body for /predict-glioma-stage/: demographics plus per-gene flags.

    The field order mirrors the feature vector that predict_glioma_stage
    builds (gender, age, then the seven genes below).
    """
    # Free-text gender. predict_glioma_stage encodes values starting with "m"
    # (case-insensitive) as 0 and everything else as 1.
    gender: str
    # NOTE(review): presumably age in years — confirm against the training data.
    age: float
    # Per-gene integer indicators; they are fed to the model as float32.
    # NOTE(review): presumably 0 = not mutated / 1 = mutated — confirm.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int
@app.post("/predict-glioma-stage/")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage label from demographics and mutation indicators.

    The model input is the 9-element vector
    [gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca].
    NOTE(review): this order presumably matches training — do not reorder.

    Returns:
        ``{"glioma_stage": "Stage 1" | "Stage 2" | "Stage 3" | "Stage 4"}``.

    Raises:
        HTTPException: 503 if the glioma stage model is unavailable.
    """
    if glioma_model is None:
        raise HTTPException(503, "Glioma stage model is not loaded")

    # Gender encoding: 0 for values starting with "m" (e.g. "male", "M"),
    # 1 otherwise. Strip first so inputs like " male" are not silently
    # encoded as 1.
    gen = 0 if data.gender.strip().lower().startswith("m") else 1
    feats = [gen, data.age, data.idh1, data.tp53,
             data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
    t = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # shape (1, 9)
    with torch.no_grad():
        out = glioma_model(t)
    idx = int(out.argmax(1))  # single-item batch -> one class index

    stages = ["Stage 1", "Stage 2", "Stage 3", "Stage 4"]
    return {"glioma_stage": stages[idx]}
|