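# brain-tumor-api / newapi.py
# Source: Hugging Face file viewer — author: Codewithsalty,
# commit a75fe7e ("Update newapi.py", verified), file size 2.01 kB.
# (These header lines were copied from the web page; kept as comments so the module parses.)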
import logging
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image

from utils import BrainTumorModel, GliomaStageModel
app = FastAPI()

# Checkpoint files (relative paths, so they resolve against the current working directory).
btd_model_path = "brain_tumor_model.pth"
glioma_model_path = "glioma_stage_model.pth"


def _load_model(model, weights_path):
    """Load the state dict at *weights_path* (mapped to CPU) into *model*, switch it to eval mode, and return it.

    NOTE: ``torch.load`` unpickles the checkpoint file, so only point this at trusted files.
    """
    cpu = torch.device("cpu")
    state_dict = torch.load(weights_path, map_location=cpu)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Both models are built and loaded once, at import time, and reused by every request.
btd_model = _load_model(BrainTumorModel(), btd_model_path)
glioma_model = _load_model(GliomaStageModel(), glioma_model_path)

# Per-upload preprocessing: resize to 224x224, then convert the PIL image to a tensor.
# NOTE(review): there is no Normalize step here — confirm this matches the preprocessing
# the checkpoints were trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the brain-tumor detection model.

    The upload is decoded with PIL, converted to RGB, preprocessed with the
    module-level ``transform``, and passed to ``btd_model``.

    Args:
        file: The uploaded image file.

    Returns:
        JSONResponse:
            * 200 ``{"prediction": <label>}`` on success, where <label> is one of
              'No Tumor', 'Pituitary', 'Meningioma', 'Glioma', selected by the
              argmax of the model output along dim 1.
            * 400 ``{"error": <message>}`` if the upload cannot be decoded as an image.
            * 500 ``{"error": <message>}`` for any other failure during inference.
    """
    # Index i must correspond to output unit i of btd_model — presumably the
    # training label order; TODO confirm against the training code.
    classes = ['No Tumor', 'Pituitary', 'Meningioma', 'Glioma']

    # Decode the upload. The `with` closes PIL's handle deterministically;
    # convert() returns a new, fully loaded image, so it stays usable afterwards.
    try:
        with Image.open(file.file) as img:
            image = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError and truncated image data:
        # the client sent something that is not a usable image.
        return JSONResponse(status_code=400, content={"error": str(e)})

    # NOTE(review): inference below is synchronous CPU work inside an async
    # endpoint, so it blocks the event loop while it runs; consider a plain
    # `def` endpoint or fastapi.concurrency.run_in_threadpool.
    try:
        batch = transform(image).unsqueeze(0)  # prepend a batch dimension of size 1
        with torch.no_grad():  # inference only; no autograd graph needed
            output = btd_model(batch)
        predicted = torch.argmax(output, dim=1).item()
        result = classes[predicted]
    except Exception as e:
        # Request boundary: keep the traceback in the server log and report a
        # server error instead of a misleading 200.
        logging.getLogger(__name__).exception("Brain tumor prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    return JSONResponse(content={"prediction": result})
@app.post("/glioma-stage/")
async def glioma_stage(file: UploadFile = File(...)):
    """Predict a glioma stage for an uploaded image with the glioma staging model.

    The upload is decoded with PIL, converted to RGB, preprocessed with the
    module-level ``transform``, and passed to ``glioma_model``.

    Args:
        file: The uploaded image file.

    Returns:
        JSONResponse:
            * 200 ``{"glioma_stage": <stage>}`` on success, where <stage> is one of
              'Stage 1' .. 'Stage 4', selected by the argmax of the model output
              along dim 1.
            * 400 ``{"error": <message>}`` if the upload cannot be decoded as an image.
            * 500 ``{"error": <message>}`` for any other failure during inference.
    """
    # Index i must correspond to output unit i of glioma_model — presumably the
    # training label order; TODO confirm against the training code.
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']

    # Decode the upload. The `with` closes PIL's handle deterministically;
    # convert() returns a new, fully loaded image, so it stays usable afterwards.
    try:
        with Image.open(file.file) as img:
            image = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError and truncated image data:
        # the client sent something that is not a usable image.
        return JSONResponse(status_code=400, content={"error": str(e)})

    # NOTE(review): inference below is synchronous CPU work inside an async
    # endpoint, so it blocks the event loop while it runs; consider a plain
    # `def` endpoint or fastapi.concurrency.run_in_threadpool.
    try:
        batch = transform(image).unsqueeze(0)  # prepend a batch dimension of size 1
        with torch.no_grad():  # inference only; no autograd graph needed
            output = glioma_model(batch)
        predicted = torch.argmax(output, dim=1).item()
        result = stages[predicted]
    except Exception as e:
        # Request boundary: keep the traceback in the server log and report a
        # server error instead of a misleading 200.
        logging.getLogger(__name__).exception("Glioma stage prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    return JSONResponse(content={"glioma_stage": result})