"""FastAPI service exposing two image-classification endpoints.

``POST /predict/`` runs the brain-tumor detection model and
``POST /glioma-stage/`` runs the glioma stage model.  Both accept a single
uploaded image file and reply with a JSON object.
"""

import logging
import os
from typing import Sequence

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

from utils import BrainTumorModel, GliomaStageModel

logger = logging.getLogger(__name__)

app = FastAPI()

# Load models (updated to local .pth files).  The paths are relative, so they
# are resolved against the process's current working directory.
btd_model_path = "brain_tumor_model.pth"
glioma_model_path = "glioma_stage_model.pth"

# Labels indexed by the argmax of each model's output.
# NOTE(review): the order is assumed to match the class indices used when the
# checkpoints were trained -- confirm against the training code.
TUMOR_CLASSES = ('No Tumor', 'Pituitary', 'Meningioma', 'Glioma')
GLIOMA_STAGES = ('Stage 1', 'Stage 2', 'Stage 3', 'Stage 4')


def _load_model(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """Load a CPU-mapped state dict into *model* and switch it to eval mode."""
    # NOTE(review): torch.load unpickles the checkpoint, so only trusted files
    # must be placed at these paths.  Consider passing ``weights_only=True``
    # once the minimum supported torch version is known to accept it.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Initialize and load Brain Tumor Detection Model
btd_model = _load_model(BrainTumorModel(), btd_model_path)

# Initialize and load Glioma Stage Detection Model
glioma_model = _load_model(GliomaStageModel(), glioma_model_path)

# Define preprocessing: resize to 224x224 and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _load_image_batch(upload: UploadFile) -> torch.Tensor:
    """Decode the upload as RGB and return the preprocessed one-image batch."""
    # The context manager releases PIL's decoder resources; ``convert``
    # returns an independent, fully loaded image that outlives the block.
    with Image.open(upload.file) as source_image:
        rgb_image = source_image.convert("RGB")
    return transform(rgb_image).unsqueeze(0)


def _predict_label(
    model: torch.nn.Module, batch: torch.Tensor, labels: Sequence[str]
) -> str:
    """Run *model* on *batch* without gradients and map the argmax to a label."""
    with torch.no_grad():
        output = model(batch)
    predicted = torch.argmax(output, dim=1).item()
    return labels[predicted]


def _classify_upload(
    upload: UploadFile,
    model: torch.nn.Module,
    labels: Sequence[str],
    response_key: str,
) -> JSONResponse:
    """Classify an uploaded image and build the JSON response.

    Args:
        upload: The uploaded image file.
        model: The model to run on the preprocessed image.
        labels: Labels indexed by the model's predicted class.
        response_key: JSON key under which the label is returned.

    Returns:
        ``{response_key: label}`` on success.  On failure the body is
        ``{"error": message}`` with status 400 when the upload cannot be
        decoded as an image, or 500 when inference fails.
    """
    try:
        batch = _load_image_batch(upload)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # PIL's UnidentifiedImageError is an OSError subclass, so this covers
        # non-image and truncated uploads: a client-side problem.
        logger.warning("Rejected upload that could not be decoded: %s", exc)
        return JSONResponse(status_code=400, content={"error": str(exc)})
    except Exception as exc:
        logger.exception("Unexpected failure while preprocessing upload")
        return JSONResponse(status_code=500, content={"error": str(exc)})

    try:
        label = _predict_label(model, batch, labels)
    except Exception as exc:
        # Top-level boundary: keep the JSON error contract, but record the
        # traceback and signal a server-side failure.
        logger.exception("Inference failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})

    return JSONResponse(content={response_key: label})


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the brain-tumor detection model.

    Returns ``{"prediction": <label>}`` on success or ``{"error": <message>}``
    with a 4xx/5xx status on failure.
    """
    # Decoding and inference are blocking, CPU-bound work; run them in the
    # thread pool so the event loop keeps serving other requests.
    return await run_in_threadpool(
        _classify_upload, file, btd_model, TUMOR_CLASSES, "prediction"
    )


@app.post("/glioma-stage/")
async def glioma_stage(file: UploadFile = File(...)):
    """Classify an uploaded image with the glioma stage model.

    Returns ``{"glioma_stage": <label>}`` on success or
    ``{"error": <message>}`` with a 4xx/5xx status on failure.
    """
    # See ``predict``: keep blocking work off the event loop.
    return await run_in_threadpool(
        _classify_upload, file, glioma_model, GLIOMA_STAGES, "glioma_stage"
    )