Spaces:
Runtime error
Runtime error
import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

from utils import BrainTumorModel, get_precautions_from_gemini
app = FastAPI()

# Build the classifier and load its trained weights once, at import time,
# mapped onto the CPU.
btd_model = BrainTumorModel()
btd_model_path = "brain_tumor_model.pth"
try:
    btd_model.load_state_dict(
        torch.load(btd_model_path, map_location=torch.device("cpu"))
    )
    btd_model.eval()  # inference mode (affects layers such as dropout/batch-norm)
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Keeping the freshly constructed model here would serve predictions from
    # untrained weights (and without eval()), i.e. confident-looking but
    # meaningless results. Drop it so request handlers fail explicitly instead,
    # while the server itself keeps running.
    btd_model = None
# Preprocessing applied to each uploaded image before inference: resize it to a
# fixed 224x224 input, then convert the PIL image to a torch tensor.
# NOTE(review): no normalization step is applied; presumably this mirrors the
# preprocessing used at training time — confirm against the training code.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Human-readable labels indexed by the model's predicted output index.
# The order must match the model's output units; update it if the model was
# trained with a different label set or ordering.
classes = [
    "glioma",
    "meningioma",
    "notumor",
    "pituitary",
]
@app.get("/")
def read_root():
    """Return a static status message so clients can confirm the API is reachable.

    Registered on ``app`` at ``/``; without the route decorator FastAPI never
    exposes this function.
    """
    return {"message": "Brain Tumor Detection API is running 🚀"}
def _load_rgb_image(stream):
    """Decode an uploaded file stream into a PIL image converted to RGB mode.

    Raises whatever Pillow raises for undecodable input (for unrecognised or
    truncated files that is an ``OSError`` subclass).
    """
    return Image.open(stream).convert("RGB")


def _classify_image(image):
    """Run ``btd_model`` on an RGB PIL image and return the predicted class label."""
    batch = transform(image).unsqueeze(0)  # add a batch dimension: [1, 3, 224, 224]
    with torch.no_grad():
        outputs = btd_model(batch)
    # Index of the highest-scoring output unit, mapped to its label.
    predicted_index = outputs.argmax(dim=1).item()
    return classes[predicted_index]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and attach precautions for the predicted class.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A ``JSONResponse``:
        - 200 with ``{"prediction": <label>, "precautions": <gemini result>}``;
        - 400 with ``{"error": ...}`` if the upload cannot be decoded as an image;
        - 503 with ``{"error": ...}`` if no model object is available;
        - 500 with ``{"error": str(exc)}`` for any other failure.
    """
    if btd_model is None:
        # No model to run inference with; report it explicitly instead of
        # failing inside the handler.
        return JSONResponse(content={"error": "Model is not available"}, status_code=503)

    try:
        # Decoding, inference and the Gemini call are all blocking; run them in
        # the threadpool so this async handler does not stall the event loop.
        try:
            image = await run_in_threadpool(_load_rgb_image, file.file)
        except (OSError, Image.DecompressionBombError) as exc:
            # Bad upload: a client error, not a server failure.
            return JSONResponse(
                content={"error": f"Invalid image file: {exc}"}, status_code=400
            )
        predicted_class = await run_in_threadpool(_classify_image, image)
        precautions = await run_in_threadpool(get_precautions_from_gemini, predicted_class)
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)

    return JSONResponse(content={
        "prediction": predicted_class,
        "precautions": precautions,
    })