import os

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch
import torchvision.transforms as transforms

from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini

# ---- Constants ----
MODEL_DIR = "models"
BTD_FILENAME = "BTD_model.pth"
GLIO_FILENAME = "glioma_stages.pth"

# ---- App setup ----
app = FastAPI(title="Brain Tumor Detection API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # adjust in production
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---- Device & transforms ----
DEVICE = torch.device("cpu")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# ---- Load & init models ----
def load_model(cls, filename):
    path = os.path.join(MODEL_DIR, filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    model = cls().to(DEVICE)
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    model.eval()
    return model

try:
    tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
    glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
except Exception as e:
    # During startup, any exception here will show in the logs.
    # Keep the names defined so the app can still start and report the failure per request.
    print(f"❌ Error loading model: {e}")
    tumor_model = None
    glioma_model = None

# ---- Routes ----
@app.get("/")  # route path assumed; adjust to match your deployment
async def health():
    return {"status": "ok", "message": "Brain Tumor API is live"}

@app.post("/predict")  # route path assumed; adjust to match your frontend
async def predict_image(file: UploadFile = File(...)):
    if tumor_model is None:
        raise HTTPException(503, "Tumor model is not loaded")
    if file.content_type.split("/")[0] != "image":
        raise HTTPException(400, "Upload an image file")

    img = Image.open(file.file).convert("RGB")
    tensor = transform(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out = tumor_model(tensor)
        idx = torch.argmax(out, dim=1).item()

    labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        # Glioma cases are routed on to the mutation-data endpoint below for staging.
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    else:
        return {
            "tumor_type": tumor_type,
            "precaution": get_precautions_from_gemini(tumor_type),
        }

class MutationInput(BaseModel):
    gender: str
    age: float
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int
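
# A JSON body matching MutationInput might look like the sketch below. The values
# are purely illustrative, and the mutation fields are assumed to be 0/1 flags
# (their int type suggests this, but the training data defines the real encoding):
# {
#   "gender": "male",
#   "age": 54,
#   "idh1": 1,
#   "tp53": 0,
#   "atrx": 1,
#   "pten": 0,
#   "egfr": 0,
#   "cic": 0,
#   "pik3ca": 0
# }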

@app.post("/predict-glioma-stage")  # route path assumed; adjust to match your frontend
async def predict_glioma_stage(data: MutationInput):
    if glioma_model is None:
        raise HTTPException(503, "Glioma model is not loaded")

    # Encode gender as a numeric feature (0 = male, 1 = female).
    gender_val = 0 if data.gender.lower().startswith('m') else 1
    features = [gender_val, data.age, data.idh1, data.tp53,
                data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
    tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out = glioma_model(tensor)
        idx = torch.argmax(out, dim=1).item()

    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    return {"glioma_stage": stages[idx]}