"""Brain Tumor Detection API.

Exposes two prediction endpoints backed by PyTorch models whose weights are
downloaded from the Hugging Face Hub at import time:

* ``POST /predict-image`` classifies an uploaded scan into one of ``labels``.
* ``POST /predict-glioma-stage`` predicts a glioma stage from mutation data.
"""

import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

from models.TumorModel import TumorClassification, GliomaStageModel
from utils import get_precautions_from_gemini

# Hugging Face repository that hosts both checkpoints.
MODEL_REPO_ID = "Codewithsalty/brain-tumor-models"

# Initialize FastAPI app
app = FastAPI(title="Brain Tumor Detection API")

# Enable CORS for all origins (adjust for production).
# NOTE(review): a wildcard origin combined with allow_credentials=True is the
# most permissive setting possible; restrict allow_origins before exposing
# this service to browsers that carry credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load Tumor Classification Model from Hugging Face.
# NOTE(review): torch.load unpickles the downloaded file. Consider passing
# weights_only=True once the minimum supported torch version allows it.
btd_model_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="BTD_model.pth",
)
tumor_model = TumorClassification()
tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
tumor_model.eval()

# Load Glioma Stage Prediction Model from Hugging Face (same caveat as above).
glioma_model_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="glioma_stages.pth",
)
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
glioma_model.eval()

# Image preprocessing pipeline
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Tumor type labels, indexed by the classifier's argmax output.
labels = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Glioma stage names, indexed by the stage model's argmax output.
GLIOMA_STAGES = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']

# Accepted spellings of the gender field mapped to the numeric feature value.
# "m" -> 0 and everything else -> 1 was the original encoding; the long forms
# are accepted so that "male" is no longer silently encoded as 1.
_GENDER_CODES = {"m": 0, "male": 0, "f": 1, "female": 1}


def _classify_tumor(img_bytes: bytes) -> str:
    """Decode *img_bytes*, run the tumor classifier and return its label.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        with Image.open(io.BytesIO(img_bytes)) as raw:
            img = raw.convert("L")  # Ensure grayscale
    except OSError as exc:
        # PIL.UnidentifiedImageError is a subclass of OSError, so this covers
        # both unrecognised formats and truncated/corrupt files.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    x = transform(img).unsqueeze(0)
    with torch.no_grad():
        out = tumor_model(x)
    idx = torch.argmax(out, dim=1).item()
    return labels[idx]


# Health check endpoint
@app.get("/")
async def root():
    """Return a static message confirming the service is up."""
    return {"message": "Brain Tumor Detection API is running."}


# Predict tumor type from uploaded image
@app.post("/predict-image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted tumor type.

    For ``glioma`` the response tells the client to submit mutation data
    next; for every other label it includes the text returned by
    ``get_precautions_from_gemini``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    img_bytes = await file.read()

    # Decoding and inference are CPU-bound; run them off the event loop so
    # other requests are not stalled while the model executes.
    tumor_type = await run_in_threadpool(_classify_tumor, img_bytes)

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}

    # The helper is called synchronously in the original code, so it is a
    # blocking call (presumably a network request); keep it off the event loop.
    precautions = await run_in_threadpool(get_precautions_from_gemini, tumor_type)
    return {"tumor_type": tumor_type, "precaution": precautions}


# Input model for glioma mutation data
class MutationInput(BaseModel):
    """Request body for ``/predict-glioma-stage``."""

    gender: str
    age: float
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int


# Predict glioma stage based on mutations
@app.post("/predict-glioma-stage")
async def predict_glioma_stage(data: MutationInput):
    """Predict the glioma stage from gender, age and mutation flags.

    Raises:
        HTTPException: 422 if ``gender`` is not one of m/male/f/female
            (case-insensitive, surrounding whitespace ignored).
    """
    gender_val = _GENDER_CODES.get(data.gender.strip().lower())
    if gender_val is None:
        raise HTTPException(
            status_code=422,
            detail="gender must be one of: m, male, f, female.",
        )

    # Feature order is the order the original code fed to the model.
    features = [
        gender_val,
        data.age,
        data.idh1,
        data.tp53,
        data.atrx,
        data.pten,
        data.egfr,
        data.cic,
        data.pik3ca,
    ]
    x = torch.tensor(features).float().unsqueeze(0)

    with torch.no_grad():
        out = glioma_model(x)
    idx = torch.argmax(out, dim=1).item()
    return {"glioma_stage": GLIOMA_STAGES[idx]}


if __name__ == "__main__":
    import uvicorn

    # The import string assumes this module is importable as ``newapi``.
    uvicorn.run("newapi:app", host="0.0.0.0", port=10000)