brain-tumor-api / newapi.py
Codewithsalty's picture
Update newapi.py
f0abd4b verified
# newapi.py
import io
import logging
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
# Point the Hugging Face cache locations at /tmp, presumably because the
# default cache directory is not writable where this service is deployed.
# These must be set before huggingface_hub is imported (it is imported lazily
# further down, only when the model has to be downloaded).
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; it is kept here for older versions.
os.environ.update(
    TRANSFORMERS_CACHE="/tmp/.cache",
    HF_HOME="/tmp/.cache",
)
# FastAPI setup
app = FastAPI(title="🧠 Brain Tumor Detection API")
# Allow CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Preprocessing applied to every uploaded image before inference:
#   1. resize to 224x224 pixels,
#   2. force a single grayscale channel,
#   3. convert to a float tensor (C x H x W) scaled to [0, 1],
#   4. standardize with a single-channel mean/std.
# Size and normalization constants must match what the model saw in training.
# NOTE(review): 0.286 / 0.229 cannot be derived from this file — confirm them
# against the training pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.286], std=[0.229]),
    ]
)
# Define the exact same model used during training
import torch.nn as nn
class BrainTumorModel(nn.Module):
    """Small CNN that maps a 1-channel image to logits for 4 classes.

    Layout: three stages of (3x3 convolution -> ReLU -> 2x2 max-pool) with
    32, 64 and 128 filters, then fully connected layers
    128*28*28 -> 512 -> 256 -> 4. Attribute names and layer shapes must match
    the checkpoint loaded into this model with ``load_state_dict``, so they
    must not be renamed or resized.

    Args:
        conv_padding: zero-padding used by every 3x3 convolution. With the
            default of 1 each convolution preserves spatial size, so a 224x224
            input is reduced to exactly 28x28 by the three poolings, which is
            the size ``fc1`` expects. (Unpadded convolutions shrink 224x224 to
            26x26, which cannot feed ``fc1``.) Padding adds no parameters, so
            any value loads the same checkpoint.
            NOTE(review): if training used unpadded convolutions, pass
            ``conv_padding=0`` and feed 240x240 inputs instead — confirm
            against the training code.
    """

    def __init__(self, conv_padding: int = 1):
        super().__init__()
        # Input is grayscale (1 channel).
        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=conv_padding)
        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=conv_padding)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=conv_padding)
        self.pool = nn.MaxPool2d(2)
        # in_features is fixed by the saved checkpoint: 128 channels x 28 x 28.
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 4)  # 4 classes expected

    def forward(self, x):
        """Return raw logits of shape (batch, 4) for ``x`` of shape (batch, 1, H, W)."""
        x = self.pool(torch.relu(self.con1d(x)))
        x = self.pool(torch.relu(self.con2d(x)))
        x = self.pool(torch.relu(self.con3d(x)))
        # Flatten per sample. Unlike view(-1, N), a wrong spatial size cannot be
        # silently folded into the batch dimension; it fails in fc1 instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)
# Load the classifier weights: prefer a local "BTD_model.pth"; if it is absent,
# download the same file from the Hugging Face Hub into the /tmp cache.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    # Imported lazily: only the download path needs huggingface_hub.
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
# NOTE(review): torch.load unpickles the file, and this checkpoint may come from
# a remote repository. Because it is loaded as a plain state_dict, consider
# torch.load(..., weights_only=True) (torch >= 1.13) so arbitrary pickled
# objects are refused.
btd_model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"))
)
btd_model.eval()  # switch to inference mode before serving requests
# Prediction endpoint
# Class index -> human-readable label, in the order of the model's 4 logits.
# NOTE(review): this order must match the label encoding used in training — confirm.
CLASS_LABELS = ("No tumor", "Glioma", "Meningioma", "Pituitary tumor")

logger = logging.getLogger(__name__)


def _error_response(status_code, exc):
    """Build the JSON error response used by /predict/: {"error": <message>}."""
    return JSONResponse(status_code=status_code, content={"error": str(exc)})


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted tumor class.

    The upload is decoded with Pillow, converted to 8-bit grayscale ("L"),
    preprocessed with the module-level ``transform`` and scored by
    ``btd_model``. The arg-max logit is mapped to a label in ``CLASS_LABELS``.

    Returns:
        ``{"prediction": <label>}`` with HTTP 200 on success. Failures keep the
        ``{"error": <message>}`` body but use a non-200 status:

        * 400 when the upload cannot be decoded as an image (unreadable,
          truncated, unsupported mode, or a Pillow decompression-bomb guard);
        * 500 for any other failure, whose traceback is logged.
    """
    # Decoding step: bad input here is the client's fault.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("L")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        return _error_response(400, e)
    except Exception as e:
        logger.exception("Failed to read uploaded file")
        return _error_response(500, e)

    # Inference step: failures here are server-side (e.g. a model/input mismatch).
    # NOTE(review): this runs synchronously inside an async handler and blocks
    # the event loop for the duration of the forward pass; consider offloading
    # (e.g. fastapi.concurrency.run_in_threadpool) if concurrency matters.
    try:
        image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            output = btd_model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()
        result = CLASS_LABELS[prediction]
    except Exception as e:
        logger.exception("Prediction failed")
        return _error_response(500, e)
    return {"prediction": result}
# Health check
@app.get("/")
def root():
    """Health-check endpoint: return a fixed status message."""
    status_message = "🧠 Brain Tumor Detection API is running!"
    return {"message": status_message}