"""FastAPI service exposing brain-tumor classification and glioma staging.

Both model checkpoints are downloaded from the ``Codewithsalty/brain-tumor-api``
Hugging Face Space at import time and loaded onto the CPU in eval mode.
"""

import io
import logging
import os
from typing import Sequence

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError

from utils import BrainTumorModel, GliomaStageModel

logger = logging.getLogger(__name__)

app = FastAPI()

_MODEL_REPO_ID = "Codewithsalty/brain-tumor-api"

# Output-index -> label mappings. The order is assumed to match the label
# order used at training time — TODO confirm against the training code.
TUMOR_CLASSES = ['No Tumor', 'Pituitary', 'Meningioma', 'Glioma']
GLIOMA_STAGES = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']


def _download_checkpoint(filename: str) -> str:
    """Download *filename* from the Space's repo and return its local path."""
    return hf_hub_download(
        repo_id=_MODEL_REPO_ID,
        filename=filename,
        repo_type="space",
    )


def _load_model(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """Load a CPU-mapped state dict into *model*, switch it to eval mode, return it."""
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; consider weights_only=True once the
    # deployed torch version is confirmed to support it.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Download both checkpoints first, then build the models (same order as before).
btd_model_path = _download_checkpoint("brain_tumor_model.pth")
glioma_model_path = _download_checkpoint("glioma_stage_model.pth")

btd_model = _load_model(BrainTumorModel(), btd_model_path)
glioma_model = _load_model(GliomaStageModel(), glioma_model_path)

# Image preprocessing shared by both endpoints: resize to 224x224 and convert
# to a float tensor. No mean/std normalization is applied.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _classify(model, image_bytes: bytes, labels: Sequence[str]) -> str:
    """Decode *image_bytes*, run *model* on it, and return the arg-max label.

    This is CPU-bound and blocking, so callers run it in a worker thread.

    Raises:
        UnidentifiedImageError: if *image_bytes* is not a decodable image.
        IndexError: if the predicted index is outside *labels*.
    """
    # The context manager closes the decoded source image once converted.
    with Image.open(io.BytesIO(image_bytes)) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    predicted = torch.argmax(output, dim=1).item()
    return labels[predicted]


async def _predict_response(
    file: UploadFile, model, labels: Sequence[str], result_key: str
) -> JSONResponse:
    """Classify an uploaded image and wrap the result in a JSON response.

    Success returns ``{result_key: label}`` with status 200. An undecodable
    upload returns ``{"error": ...}`` with status 400. Any other failure is
    logged and returns ``{"error": ...}`` with status 500.
    """
    try:
        data = await file.read()
        # Run decoding and inference off the event loop so concurrent requests
        # are not blocked behind a single inference.
        result = await run_in_threadpool(_classify, model, data, labels)
    except UnidentifiedImageError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:
        logger.exception("Inference failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={result_key: result})


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of TUMOR_CLASSES."""
    return await _predict_response(file, btd_model, TUMOR_CLASSES, "prediction")


@app.post("/glioma-stage/")
async def glioma_stage(file: UploadFile = File(...)):
    """Predict one of GLIOMA_STAGES for an uploaded image."""
    return await _predict_response(file, glioma_model, GLIOMA_STAGES, "glioma_stage")