from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from torchvision import transforms
from PIL import Image
import io
from huggingface_hub import hf_hub_download

from models.TumorModel import TumorClassification, GliomaStageModel
from utils import get_precautions_from_gemini
# Initialize FastAPI app
app = FastAPI(title="Brain Tumor Detection API")

# Enable CORS for all origins (adjust for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load tumor classification model from Hugging Face
btd_model_path = hf_hub_download(
    repo_id="Codewithsalty/brain-tumor-models",
    filename="BTD_model.pth",
)
tumor_model = TumorClassification()
tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
tumor_model.eval()
# Load glioma stage prediction model from Hugging Face
glioma_model_path = hf_hub_download(
    repo_id="Codewithsalty/brain-tumor-models",
    filename="glioma_stages.pth",
)
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
glioma_model.eval()
# Image preprocessing pipeline
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
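# Note: this pipeline yields a 1x224x224 tensor (grayscale, normalized to
# [-1, 1]), which is presumably the input shape TumorClassification was
# trained on; the .unsqueeze(0) in the prediction endpoint below adds the
# batch dimension.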
# Health check endpoint
@app.get("/")
async def root():
    return {"message": "Brain Tumor Detection API is running."}
# Tumor type labels
labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
# Predict tumor type from uploaded image
@app.post("/predict")  # route path assumed for illustration; adjust to match your client
async def predict_image(file: UploadFile = File(...)):
    img_bytes = await file.read()
    img = Image.open(io.BytesIO(img_bytes)).convert("L")  # ensure grayscale
    x = transform(img).unsqueeze(0)
    with torch.no_grad():
        out = tumor_model(x)
        idx = torch.argmax(out, dim=1).item()
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    else:
        precautions = get_precautions_from_gemini(tumor_type)
        return {"tumor_type": tumor_type, "precaution": precautions}
# Input model for glioma mutation data
class MutationInput(BaseModel):
    gender: str
    age: float
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int
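# Illustrative request body for this schema (placeholder values, not
# clinically meaningful):
#   {"gender": "m", "age": 45, "idh1": 1, "tp53": 0, "atrx": 1,
#    "pten": 0, "egfr": 0, "cic": 0, "pik3ca": 0}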
# Predict glioma stage based on mutation data
@app.post("/predict-glioma-stage")  # route path assumed for illustration
async def predict_glioma_stage(data: MutationInput):
    gender_val = 0 if data.gender.lower() == 'm' else 1
    features = [
        gender_val, data.age, data.idh1, data.tp53, data.atrx,
        data.pten, data.egfr, data.cic, data.pik3ca,
    ]
    x = torch.tensor(features).float().unsqueeze(0)
    with torch.no_grad():
        out = glioma_model(x)
        idx = torch.argmax(out, dim=1).item()
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    return {"glioma_stage": stages[idx]}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("newapi:app", host="0.0.0.0", port=10000)
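# --- Example client calls (a minimal sketch, assuming the route paths
# registered above: "/", "/predict", and "/predict-glioma-stage"; run this
# from a separate script while the server is running) ---
#
#   import requests
#
#   base = "http://localhost:10000"
#
#   # Health check
#   print(requests.get(f"{base}/").json())
#
#   # Tumor type prediction from an MRI image file ("scan.png" is a placeholder)
#   with open("scan.png", "rb") as f:
#       print(requests.post(f"{base}/predict", files={"file": f}).json())
#
#   # Glioma stage prediction from mutation data (placeholder values)
#   payload = {"gender": "m", "age": 45, "idh1": 1, "tp53": 0, "atrx": 1,
#              "pten": 0, "egfr": 0, "cic": 0, "pik3ca": 0}
#   print(requests.post(f"{base}/predict-glioma-stage", json=payload).json())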