Spaces:
Runtime error
Runtime error
File size: 3,203 Bytes
e4acaca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

from models.TumorModel import TumorClassification, GliomaStageModel
from utils import get_precautions_from_gemini
# Initialize FastAPI app
app = FastAPI(title="Brain Tumor Detection API")

# Enable CORS for all origins (adjust for production): every origin, method and
# header is accepted, and credentials are allowed.
# NOTE(review): Starlette's CORS docs advise against combining wildcard settings
# with allow_credentials=True; depending on the Starlette version the middleware
# may echo the caller's Origin for credentialed requests. Confirm and restrict
# allow_origins to known front-ends before a production deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Hugging Face Hub repository that hosts both model checkpoints.
MODEL_REPO_ID = "Codewithsalty/brain-tumor-models"


def _load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """Load the state dict stored at *checkpoint_path* into *model* and return it in eval mode.

    The checkpoint is deserialized onto the CPU. ``weights_only=True`` restricts
    unpickling to tensors and plain containers, which is all a state dict needs,
    so a tampered checkpoint fetched from the Hub cannot execute arbitrary code
    during ``torch.load``.

    Args:
        model: Freshly constructed model whose architecture matches the checkpoint.
        checkpoint_path: Local path of the ``.pth`` file returned by ``hf_hub_download``.

    Returns:
        The same *model* object, with weights loaded and switched to ``eval()`` mode.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Load Tumor Classification Model from Hugging Face
btd_model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="BTD_model.pth")
tumor_model = _load_checkpoint(TumorClassification(), btd_model_path)

# Load Glioma Stage Prediction Model from Hugging Face
glioma_model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="glioma_stages.pth")
glioma_model = _load_checkpoint(GliomaStageModel(), glioma_model_path)
# Image preprocessing pipeline applied to every uploaded scan before inference:
#   1. collapse to a single grayscale channel,
#   2. resize to a fixed 224x224,
#   3. convert to a float tensor (ToTensor scales 8-bit pixels to [0, 1]),
#   4. normalize with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, giving [-1, 1].
_PIPELINE_STEPS = (
    transforms.Grayscale(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
)
transform = transforms.Compose(list(_PIPELINE_STEPS))
# Health check endpoint
@app.get("/")
async def root():
    """Liveness probe: confirm the API process is up and serving requests."""
    payload = {"message": "Brain Tumor Detection API is running."}
    return payload
# Tumor type labels, ordered by the classifier's output index: the prediction
# endpoint takes argmax over the model output and indexes this list with it.
labels = "glioma meningioma notumor pituitary".split()
# Predict tumor type from uploaded image
@app.post("/predict-image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded brain scan image into one of ``labels``.

    The upload is decoded with PIL, converted to grayscale, run through the
    module-level ``transform`` pipeline and scored by ``tumor_model``; the label
    with the highest score wins.

    Returns:
        For a glioma: ``{"tumor_type": "glioma", "next": "submit_mutation_data"}``,
        signalling the client to call ``/predict-glioma-stage`` next.
        Otherwise: ``{"tumor_type": <label>, "precaution": <result of
        get_precautions_from_gemini(label)>}``.

    Raises:
        HTTPException: 400 if the uploaded bytes cannot be decoded as an image.
    """
    img_bytes = await file.read()
    try:
        # Image.open is lazy; convert() forces the full decode, so corrupt or
        # truncated data surfaces inside this try as well.
        img = Image.open(io.BytesIO(img_bytes)).convert("L")  # Ensure grayscale
    except OSError as exc:  # PIL.UnidentifiedImageError is a subclass of OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    x = transform(img).unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():
        out = tumor_model(x)
    idx = torch.argmax(out, dim=1).item()
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}

    # get_precautions_from_gemini is invoked synchronously (presumably a blocking
    # network call — confirm in utils); run it in a worker thread so it does not
    # stall the event loop for every other in-flight request.
    precautions = await run_in_threadpool(get_precautions_from_gemini, tumor_type)
    return {"tumor_type": tumor_type, "precaution": precautions}
# Input model for glioma mutation data
class MutationInput(BaseModel):
    """JSON request body for the ``/predict-glioma-stage`` endpoint.

    Field order matches the order in which the stage endpoint assembles its
    feature vector (gender, age, then the seven gene fields).
    """

    # Patient gender as free text; the stage endpoint turns it into a numeric flag.
    gender: str
    # Patient age; units are not stated here (presumably years — confirm).
    age: float
    # Gene fields below are integers; presumably 0/1 mutation indicators — confirm
    # against the data the glioma model was trained on.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int
# Predict glioma stage based on mutations
@app.post("/predict-glioma-stage")
async def predict_glioma_stage(data: MutationInput):
    """Predict the glioma stage from patient demographics and mutation fields.

    The request is flattened into a 9-element feature vector in the order
    gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca. Gender is encoded
    as male = 0 and female = 1 (the encoding the original code used; presumably
    the model's training encoding — confirm).

    Returns:
        ``{"glioma_stage": "Stage N"}``, where N is 1-4 and is taken from the
        argmax of ``glioma_model``'s output.

    Raises:
        HTTPException: 422 if ``gender`` is not a recognised male/female value.
            This is the same status FastAPI already uses for other invalid bodies.
    """
    # Normalise before comparing: previously "male", "Male" or " m" all fell
    # through to the female code, and any unrecognised text was silently
    # treated as female.
    gender_key = data.gender.strip().lower()
    if gender_key in ("m", "male"):
        gender_val = 0
    elif gender_key in ("f", "female"):
        gender_val = 1
    else:
        raise HTTPException(
            status_code=422,
            detail="gender must be 'M'/'male' or 'F'/'female'.",
        )

    features = [
        gender_val, data.age, data.idh1, data.tp53, data.atrx,
        data.pten, data.egfr, data.cic, data.pik3ca,
    ]
    x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # shape (1, 9)
    with torch.no_grad():
        out = glioma_model(x)
    idx = torch.argmax(out, dim=1).item()
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    return {"glioma_stage": stages[idx]}
if __name__ == "__main__":
    import os

    import uvicorn

    # Pass the app object instead of the "newapi:app" import string. With the string
    # form, running this file directly makes uvicorn import it a second time under
    # the module name "newapi". That re-executes all module-level setup (including
    # model loading), and it fails outright if the file is not named newapi.py.
    # PORT, when set in the environment, overrides the previous fixed port 10000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "10000")))