"""Brain tumor detection API.

Exposes a health-check route and a ``/predict`` route that classifies an
uploaded image with ``BrainTumorModel`` and attaches the precautions returned
by ``get_precautions_from_gemini`` for the predicted class.
"""

import logging

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

from utils import BrainTumorModel, get_precautions_from_gemini

logger = logging.getLogger(__name__)

app = FastAPI()

# Load the model
btd_model = BrainTumorModel()
btd_model_path = "brain_tumor_model.pth"

# Set to True only when the weights were loaded and the model was put in eval
# mode. /predict refuses to answer otherwise, so an uninitialised model can
# never produce a "diagnosis".
_model_loaded = False
try:
    # NOTE(review): torch.load unpickles the checkpoint; if the installed
    # torch supports it, consider weights_only=True for a plain state dict.
    btd_model.load_state_dict(
        torch.load(btd_model_path, map_location=torch.device('cpu'))
    )
    btd_model.eval()
    _model_loaded = True
except Exception as e:
    # Deliberately best-effort: keep the app importable so "/" still answers,
    # but remember the failure for /predict.
    logger.error("❌ Error loading model: %s", e)

# Define image transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Class labels (adjust if your model uses different labels)
classes = ['glioma', 'meningioma', 'notumor', 'pituitary']


class _InvalidImageError(ValueError):
    """Raised when the uploaded bytes cannot be decoded as an image."""


def _decode_image(stream):
    """Decode *stream* (a binary file object) into an RGB PIL image.

    Raises:
        _InvalidImageError: if Pillow cannot identify or read the data.
    """
    try:
        # convert() forces a full decode, so truncated files fail here too.
        return Image.open(stream).convert("RGB")
    except OSError as err:
        # Pillow reports unidentified/corrupt images as OSError subclasses.
        raise _InvalidImageError(f"Invalid image file: {err}") from err


def _classify(image):
    """Return the label from ``classes`` that the model predicts for *image*."""
    batch = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]
    with torch.no_grad():
        outputs = btd_model(batch)
    _, predicted = torch.max(outputs, 1)
    return classes[predicted.item()]


@app.get("/")
def read_root():
    """Health-check endpoint confirming that the API is up."""
    return {"message": "Brain Tumor Detection API is running 🚀"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return precautions for the result.

    Args:
        file: The uploaded image file.

    Returns:
        A ``JSONResponse``. On success it holds ``"prediction"`` (one of
        ``classes``) and ``"precautions"`` (whatever
        ``get_precautions_from_gemini`` returns for that class). On failure it
        holds a single ``"error"`` message with status 503 when the model
        weights could not be loaded at startup, 400 when the upload is not a
        readable image, or 500 for any other error.
    """
    if not _model_loaded:
        return JSONResponse(
            content={"error": "Model is not loaded; prediction is unavailable."},
            status_code=503,
        )

    try:
        # Decoding, inference and the precautions lookup are all synchronous;
        # run them in the thread pool so they do not stall the event loop.
        image = await run_in_threadpool(_decode_image, file.file)
        predicted_class = await run_in_threadpool(_classify, image)
        precautions = await run_in_threadpool(
            get_precautions_from_gemini, predicted_class
        )
    except _InvalidImageError as e:
        # A bad upload is the client's fault, not a server error.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        logger.exception("Prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={
        "prediction": predicted_class,
        "precautions": precautions,
    })