# brain-tumor-api / newapi.py
# Author: Codewithsalty. Last change: "Update newapi.py" (commit 28addcf, verified).
import logging

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

from utils import BrainTumorModel, get_precautions_from_gemini
app = FastAPI()

# Load the model once at import time. Loading is best-effort: if the checkpoint
# is missing or incompatible, the failure is logged (with traceback) and the API
# still starts. In that case the model keeps whatever weights it already had
# (the ones BrainTumorModel() initializes, possibly partially overwritten), so
# predictions are not meaningful until the checkpoint problem is fixed.
btd_model = BrainTumorModel()
btd_model_path = "brain_tumor_model.pth"
try:
    # NOTE: torch.load unpickles the file, so the checkpoint must come from a
    # trusted source.
    btd_model.load_state_dict(
        torch.load(btd_model_path, map_location=torch.device("cpu"))
    )
except Exception:
    logging.getLogger(__name__).exception(
        "❌ Error loading model from %s", btd_model_path
    )
# Switch to inference mode unconditionally, so layers such as dropout or
# batch-norm (if the model has any) behave consistently even when loading
# failed above.
btd_model.eval()
# Preprocessing applied to each uploaded image before inference: resize to
# 224x224 pixels, then convert the PIL image to a float tensor (ToTensor).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Label for each model output index: the predicted index is used directly to
# look up this list, so the order must match the model's outputs
# (adjust if your model uses different labels).
classes = [
    "glioma",
    "meningioma",
    "notumor",
    "pituitary",
]
@app.get("/")
def read_root():
    """Liveness endpoint: report that the API process is up and serving."""
    status_text = "Brain Tumor Detection API is running 🚀"
    payload = {"message": status_text}
    return payload
class _InvalidImageError(ValueError):
    """Raised when an uploaded file cannot be decoded as an image."""


def _decode_upload(upload: UploadFile) -> Image.Image:
    """Decode an uploaded file into an RGB PIL image.

    Raises:
        _InvalidImageError: the bytes are not a readable image, the image data
            is truncated, or it exceeds Pillow's decompression-bomb limit.
    """
    try:
        with Image.open(upload.file) as img:
            # convert() fully decodes and returns a new image, so the result
            # stays valid after the source image is closed by the `with`.
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is a subclass of OSError.
        raise _InvalidImageError(str(exc)) from exc


def _classify(image: Image.Image) -> str:
    """Return the predicted class label for a single RGB image."""
    batch = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]
    with torch.no_grad():
        outputs = btd_model(batch)
    predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
    return classes[predicted.item()]


def _run_prediction(upload: UploadFile) -> dict:
    """Decode *upload*, classify it and fetch precautions (blocking work)."""
    image = _decode_upload(upload)
    predicted_class = _classify(image)
    precautions = get_precautions_from_gemini(predicted_class)
    return {"prediction": predicted_class, "precautions": precautions}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of ``classes``.

    Returns:
        JSON ``{"prediction": <label>, "precautions": <value returned by
        get_precautions_from_gemini>}`` on success.
        Status 400 with ``{"error": ...}`` if the upload cannot be decoded as an
        image; status 500 with ``{"error": ...}`` on any other failure.
    """
    try:
        # Image decoding and model inference are blocking, CPU-bound calls, and
        # the precaution lookup presumably makes a network request (confirm in
        # utils). Run them in the threadpool so the event loop is not stalled.
        result = await run_in_threadpool(_run_prediction, file)
    except _InvalidImageError as e:
        # A bad upload is a client error, not a server failure.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content=result)