# Brain Tumor Detection API — FastAPI inference service.
import io

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
import torch
import torchvision.transforms as transforms

from utils import BrainTumorModel, get_precautions_from_gemini
app = FastAPI()
# Load the model
btd_model = BrainTumorModel()
btd_model_path = "brain_tumor_model.pth"
try:
btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
btd_model.eval()
except Exception as e:
print(f"❌ Error loading model: {e}")
# Define image transform
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# Class labels (adjust if your model uses different labels)
classes = ['glioma', 'meningioma', 'notumor', 'pituitary']
@app.get("/")
def read_root():
return {"message": "Brain Tumor Detection API is running 🚀"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
try:
image = Image.open(file.file).convert("RGB")
image = transform(image).unsqueeze(0) # Shape: [1, 3, 224, 224]
with torch.no_grad():
outputs = btd_model(image)
_, predicted = torch.max(outputs.data, 1)
predicted_class = classes[predicted.item()]
precautions = get_precautions_from_gemini(predicted_class)
return JSONResponse(content={
"prediction": predicted_class,
"precautions": precautions
})
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)