Spaces:
Runtime error
Runtime error
import os | |
from fastapi import FastAPI, UploadFile, File | |
from fastapi.responses import JSONResponse | |
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
from utils import BrainTumorModel, GliomaStageModel | |
app = FastAPI()

# Checkpoint files holding the weights of the two classifiers; both are
# relative paths to local .pth files.
btd_model_path = "brain_tumor_model.pth"
glioma_model_path = "glioma_stage_model.pth"

# Brain tumor detection model: instantiate, restore its weights with all
# tensors mapped to the CPU, then call eval() (presumably a torch.nn.Module,
# in which case this switches it to inference mode — confirm in utils).
btd_model = BrainTumorModel()
btd_model.load_state_dict(
    torch.load(btd_model_path, map_location="cpu")
)
btd_model.eval()

# Glioma stage model: same procedure with its own checkpoint.
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(
    torch.load(glioma_model_path, map_location="cpu")
)
glioma_model.eval()
# Preprocessing shared by both endpoints: resize the image to 224x224 and
# convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
async def predict(file: UploadFile = File(...)):
    """Run the brain tumor detection model on an uploaded image.

    Args:
        file: Uploaded file; its stream is opened with PIL and converted
            to RGB before preprocessing.

    Returns:
        JSONResponse: ``{"prediction": <label>}`` on success, where the label
        is one of 'No Tumor', 'Pituitary', 'Meningioma' or 'Glioma'. On
        failure the body is ``{"error": <message>}`` with status 400 when the
        upload cannot be read as an image and status 500 otherwise.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this handler in this view — confirm it is registered on ``app``.
    classes = ['No Tumor', 'Pituitary', 'Meningioma', 'Glioma']
    try:
        image = Image.open(file.file).convert("RGB")
        # unsqueeze(0) adds a leading dimension (presumably the batch axis
        # the model expects — confirm against BrainTumorModel).
        image = transform(image).unsqueeze(0)
        with torch.no_grad():
            output = btd_model(image)
        predicted = torch.argmax(output, dim=1).item()
        result = classes[predicted]
    except OSError as e:
        # PIL reports unreadable or truncated image data as OSError
        # (UnidentifiedImageError is a subclass): treat it as a client error.
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:
        # Keep the original catch-all so the caller always receives JSON, but
        # signal the failure through the HTTP status instead of 200 OK.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={"prediction": result})
async def glioma_stage(file: UploadFile = File(...)):
    """Run the glioma stage model on an uploaded image.

    Args:
        file: Uploaded file; its stream is opened with PIL and converted
            to RGB before preprocessing.

    Returns:
        JSONResponse: ``{"glioma_stage": <label>}`` on success, where the
        label is one of 'Stage 1' through 'Stage 4'. On failure the body is
        ``{"error": <message>}`` with status 400 when the upload cannot be
        read as an image and status 500 otherwise.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this handler in this view — confirm it is registered on ``app``.
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    try:
        image = Image.open(file.file).convert("RGB")
        # unsqueeze(0) adds a leading dimension (presumably the batch axis
        # the model expects — confirm against GliomaStageModel).
        image = transform(image).unsqueeze(0)
        with torch.no_grad():
            output = glioma_model(image)
        predicted = torch.argmax(output, dim=1).item()
        result = stages[predicted]
    except OSError as e:
        # PIL reports unreadable or truncated image data as OSError
        # (UnidentifiedImageError is a subclass): treat it as a client error.
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:
        # Keep the original catch-all so the caller always receives JSON, but
        # signal the failure through the HTTP status instead of 200 OK.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={"glioma_stage": result})