File size: 2,679 Bytes
f99073d
 
 
39333b1
e8c5868
a75fe7e
 
f99073d
 
 
e4acaca
f99073d
 
 
e4acaca
f99073d
 
28addcf
f99073d
39333b1
 
e8c5868
39333b1
 
 
28addcf
f99073d
e4acaca
28addcf
39333b1
b81c13c
e4acaca
 
f99073d
 
39333b1
f99073d
 
 
b81c13c
 
 
 
 
 
 
28addcf
f99073d
b81c13c
 
 
 
 
 
 
 
39333b1
f99073d
 
 
 
 
39333b1
f99073d
 
 
39333b1
f99073d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# newapi.py

import io
import logging
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel

# Send Hugging Face cache files to /tmp, a writable location on Hugging Face Spaces.
os.environ.update({
    "TRANSFORMERS_CACHE": "/tmp/.cache",
    "HF_HOME": "/tmp/.cache",
})

# Define FastAPI app
app = FastAPI(title="🧠 Brain Tumor Detection API")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Image preprocessing for BrainTumorModel.
# The network applies three unpadded 3x3 convolutions, each followed by a 2x2
# max-pool, and its fc1 layer (sized by the saved checkpoint) expects a
# 128 x 25 x 25 feature map. Only input sides of 214-221 px reduce to 25;
# a 224 x 224 input reduces to 26 x 26 and the model fails with a shape error.
# NOTE(review): 220 was picked from that valid range -- confirm it matches the
# resolution the checkpoint was trained at.
transform = transforms.Compose([
    transforms.Resize((220, 220)),
    transforms.ToTensor(),
    # A single mean/std value is broadcast across all three RGB channels.
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Define your model directly inside this file (to avoid import errors)
import torch.nn as nn

class BrainTumorModel(nn.Module):
    """Small CNN that maps an RGB image batch to two class logits.

    Architecture:
    - Three stages of conv(3x3, no padding) -> ReLU -> 2x2 max-pool.
    - Two hidden fully connected layers, then a 2-way output layer.

    ``fc1`` is sized for a 128 x 25 x 25 feature map, which the convs and
    pools produce only for input heights/widths of 214-221 pixels. Other
    input sizes raise a RuntimeError in ``fc1``.

    Attribute names must stay as they are: they are the keys that
    ``load_state_dict`` matches against the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.con1d = nn.Conv2d(3, 32, kernel_size=3)
        self.con2d = nn.Conv2d(32, 64, kernel_size=3)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(128 * 25 * 25, 256)
        self.fc2 = nn.Linear(256, 128)
        self.output = nn.Linear(128, 2)

    def forward(self, x):
        """Return raw logits of shape (batch, 2) for ``x`` of shape (batch, 3, H, W)."""
        x = self.pool(torch.relu(self.con1d(x)))
        x = self.pool(torch.relu(self.con2d(x)))
        x = self.pool(torch.relu(self.con3d(x)))
        # Flatten per sample, keeping the batch dimension. A wrong spatial
        # size then fails loudly in fc1 instead of being silently regrouped
        # across samples, which view(-1, K) permits when sizes divide evenly.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)

# Load the trained classifier. A checkpoint in the current working directory
# is used if present; otherwise it is downloaded from the Hugging Face Hub
# into the writable /tmp cache.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    # Imported only on this branch, so huggingface_hub is needed only when
    # the local checkpoint is absent.
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
# torch.load unpickles the file, and the file may come from the network.
# weights_only=True (torch >= 1.13) limits unpickling to tensors and plain
# containers, which is all a state_dict needs. A tampered checkpoint therefore
# cannot execute arbitrary code while loading.
btd_model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
)
btd_model.eval()  # switch to inference mode before serving requests

def _run_inference(contents: bytes) -> str:
    """Decode raw image bytes, run the classifier and return the label string.

    Synchronous and CPU-bound; ``predict`` runs it in a worker thread so the
    event loop stays responsive while the model computes.
    """
    with Image.open(io.BytesIO(contents)) as img:
        image = img.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        output = btd_model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()

    return {0: "No tumor", 1: "Tumor detected"}[prediction]


# Define prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as "No tumor" or "Tumor detected".

    Returns ``{"prediction": label}`` on success. On any failure the body is
    ``{"error": message}``, and the traceback is logged server-side.
    NOTE(review): errors are still returned with HTTP 200; consider 4xx/5xx
    status codes once the clients of this API have been checked.
    """
    try:
        contents = await file.read()
        # Decoding and model inference are CPU-bound. Running them inline in
        # this coroutine would block every other request, including "/".
        result = await run_in_threadpool(_run_inference, contents)
        return {"prediction": result}

    except Exception as e:
        logging.getLogger(__name__).exception(
            "Prediction failed for upload %r", file.filename
        )
        return {"error": str(e)}

# Health check endpoint
@app.get("/")
def root():
    """Liveness check: report that the API process is up and serving requests."""
    status = {"message": "🧠 Brain Tumor Detection API is running!"}
    return status