# newapi.py

from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
import os

# Set writable cache directories
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
os.environ["HF_HOME"] = "/tmp/.cache"

# FastAPI setup
app = FastAPI(title="🧠 Brain Tumor Detection API")

# Allow CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define image transform (grayscale)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1),  # Ensure grayscale
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.286], std=[0.229]),  # Adjust mean/std if needed
])
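
# A quick illustration of what the transform produces (the file name below is a
# placeholder for any MRI slice on disk, not something this API reads itself):
#
#   img = Image.open("example_mri.jpg").convert("L")
#   x = transform(img).unsqueeze(0)   # -> torch.Size([1, 1, 224, 224])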

# Define the exact same model used during training

class BrainTumorModel(nn.Module):
    def __init__(self):
        super(BrainTumorModel, self).__init__()
        # padding=1 keeps each conv output at its input's spatial size, so the
        # three conv+pool stages reduce 224x224 to 28x28, matching fc1 below.
        # NOTE: layer names must match the keys in the saved state_dict.
        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # Input is grayscale (1 channel)
        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)      # Match the saved model's input size
        self.fc2 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 4)               # 4 classes expected

    def forward(self, x):
        x = self.pool(torch.relu(self.con1d(x)))
        x = self.pool(torch.relu(self.con2d(x)))
        x = self.pool(torch.relu(self.con3d(x)))
        x = x.view(-1, 128 * 28 * 28)  # Flatten the feature maps
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.output(x)
        return x
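
# Shape sanity check (illustrative only): with padding=1 above, a dummy
# grayscale batch of shape [1, 1, 224, 224] leaves the conv/pool stack as
# [1, 128, 28, 28], which is exactly what the flatten in forward() expects:
#
#   BrainTumorModel()(torch.zeros(1, 1, 224, 224)).shape   # torch.Size([1, 4])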

# Load model
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-models", filename="BTD_model.pth", cache_dir="/tmp/.cache")

btd_model = BrainTumorModel()
btd_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
btd_model.eval()

# Prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("L")  # Grayscale
        image_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            output = btd_model(image_tensor)
            prediction = torch.argmax(output, dim=1).item()

        result = {
            0: "No tumor",
            1: "Glioma",
            2: "Meningioma",
            3: "Pituitary tumor"
        }[prediction]

        return {"prediction": result}

    except Exception as e:
        return {"error": str(e)}

# Health check
@app.get("/")
def root():
    return {"message": "🧠 Brain Tumor Detection API is running!"}