File size: 1,555 Bytes
28addcf
a75fe7e
28addcf
a75fe7e
 
28addcf
e4acaca
342c341
e4acaca
28addcf
 
 
 
 
 
 
 
 
 
 
e4acaca
28addcf
 
e4acaca
 
28addcf
 
 
342c341
28addcf
 
374a9d4
28addcf
60d002b
a75fe7e
28addcf
 
60d002b
a75fe7e
342c341
28addcf
 
 
60d002b
28addcf
 
 
 
e4acaca
a75fe7e
28addcf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging

from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from PIL import UnidentifiedImageError
import torch
import torchvision.transforms as transforms

from utils import BrainTumorModel, get_precautions_from_gemini

app = FastAPI()

logger = logging.getLogger(__name__)

# Load the model: build the network, then restore its trained weights.
# map_location forces all tensors onto the CPU, so a checkpoint saved from a
# GPU run still loads on a CPU-only host.
btd_model = BrainTumorModel()
btd_model_path = "brain_tumor_model.pth"

try:
    btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
except Exception:
    # Best-effort: keep the app importable so "/" still answers, but record the
    # full traceback (a bare print() loses it). eval() is deliberately skipped
    # on failure, so the model is left in its default training mode with
    # untrained weights.
    logger.exception("❌ Error loading model from %s", btd_model_path)
else:
    # Switch layers such as dropout / batch-norm to inference behaviour.
    btd_model.eval()

# Preprocessing applied to every uploaded image before inference: resize to a
# fixed 224x224 input, then convert the PIL image into a float tensor.
# NOTE(review): there is no Normalize step here — confirm this matches the
# preprocessing that was used when the model was trained.
_INPUT_SIZE = (224, 224)
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)

# Output-index -> label lookup for the model's predictions. The order must
# match the class order used at training time (adjust if your model differs).
classes = [
    'glioma',
    'meningioma',
    'notumor',
    'pituitary',
]

@app.get("/")
def read_root():
    """Liveness endpoint: report that the API process is up and serving."""
    status_text = "Brain Tumor Detection API is running 🚀"
    payload = {"message": status_text}
    return payload

def _classify_image(fileobj) -> str:
    """Decode an uploaded image and return the model's predicted class label.

    This is fully synchronous (PIL decoding and the forward pass block), so the
    async endpoint below runs it in a worker thread.

    Raises:
        PIL.UnidentifiedImageError: if *fileobj* does not contain a decodable image.
    """
    with Image.open(fileobj) as img:
        # convert() forces the lazy decode, so the source image can be closed after.
        tensor = transform(img.convert("RGB")).unsqueeze(0)  # Shape: [1, 3, 224, 224]

    with torch.no_grad():
        logits = btd_model(tensor)
    # Index of the highest score for the single item in the batch.
    return classes[logits.argmax(dim=1).item()]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return precautions for the predicted class.

    Responses:
        200: ``{"prediction": <label>, "precautions": <value from get_precautions_from_gemini>}``
        400: ``{"error": ...}`` when the upload is not a readable image.
        503: ``{"error": ...}`` when the model weights were never loaded.
        500: ``{"error": ...}`` for any other failure.
    """
    # Startup code only calls eval() after the weights load successfully, so a
    # model still in training mode holds untrained weights. Refuse rather than
    # return meaningless predictions. (Assumes BrainTumorModel is a
    # torch.nn.Module, whose default mode is training.)
    if btd_model.training:
        return JSONResponse(content={"error": "Model is not loaded"}, status_code=503)

    try:
        # Image decoding, inference and the precautions lookup may all block;
        # run them in a worker thread so they don't stall the event loop.
        predicted_class = await run_in_threadpool(_classify_image, file.file)
        precautions = await run_in_threadpool(get_precautions_from_gemini, predicted_class)
    except UnidentifiedImageError:
        # Bad client input, not a server fault.
        return JSONResponse(content={"error": "Uploaded file is not a valid image"}, status_code=400)
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={
        "prediction": predicted_class,
        "precautions": precautions,
    })