brain-tumor-api / newapi.py
Codewithsalty's picture
Update newapi.py
f99073d verified
raw
history blame
2.41 kB
# newapi.py
import io
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
# Hugging Face Spaces: point both Hugging Face cache locations at a writable
# directory. This runs before huggingface_hub is imported further down.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/.cache",
        "HF_HOME": "/tmp/.cache",
    }
)
# Define FastAPI app
app = FastAPI(title="🧠 Brain Tumor Detection API")

# Enable CORS: every origin, HTTP method and request header is allowed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Image transform: the preprocessing steps applied, in order, to an image
# before it is handed to the model.
_preprocessing_steps = [
    # Force a fixed 224x224 spatial size.
    transforms.Resize((224, 224)),
    # Convert the PIL image to a float tensor.
    transforms.ToTensor(),
    # Shift/scale each of the three channels with mean 0.5 and std 0.5.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
# Define your model directly inside this file (to avoid import errors)
import torch.nn as nn


class BrainTumorModel(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    All layers live in one ``nn.Sequential`` stored as ``self.model``. The
    attribute name and the position of each layer determine the
    ``state_dict`` keys (``model.0.*``, ``model.3.*``, ``model.7.*``), so
    neither may change if previously saved checkpoints are to keep loading.

    The final linear layer expects 32 * 54 * 54 features, which is what a
    224x224 input yields: each unpadded 3x3 convolution trims 2 pixels and
    each 2x2 max-pool halves the size (flooring), i.e.
    224 -> 222 -> 111 -> 109 -> 54.

    Args:
        num_classes: Number of output logits. Defaults to 2, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 54 * 54, num_classes),
        )

    def forward(self, x):
        """Return the raw (un-normalised) class logits for the batch ``x``."""
        return self.model(x)
# Load model
# Use a checkpoint found at the relative path "BTD_model.pth" when one exists;
# otherwise download it from the Hugging Face Hub into the writable cache.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
btd_model.load_state_dict(
    torch.load(
        model_path,
        map_location=torch.device("cpu"),
        # The checkpoint may have been downloaded, and torch.load unpickles
        # it. weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state_dict needs, instead of allowing
        # arbitrary objects. Requires torch >= 1.13.
        weights_only=True,
    )
)
# Put the model in evaluation mode; this process only runs inference.
btd_model.eval()
# Human-readable label for each class index the model can predict.
_PREDICTION_LABELS = {0: "No tumor", 1: "Tumor detected"}


def _classify_image(contents: bytes) -> str:
    """Decode raw image bytes and return the predicted label.

    Everything here is synchronous (image decoding, preprocessing and the
    forward pass), so it is meant to be called from a worker thread rather
    than directly on the event loop.

    Args:
        contents: The raw bytes of the uploaded image file.

    Returns:
        ``"No tumor"`` or ``"Tumor detected"``.

    Raises:
        Exception: Whatever PIL, the transform or the model raise, e.g. when
            the bytes are not a readable image. The caller handles these.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    # unsqueeze(0) adds the leading batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0)
    # no_grad() is thread-local, so it must be entered inside the worker.
    with torch.no_grad():
        output = btd_model(image_tensor)
    prediction = torch.argmax(output, dim=1).item()
    return _PREDICTION_LABELS[prediction]


# Define prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    The blocking decode/inference work is run in the thread pool so that
    other requests are not stalled while one image is being processed.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        ``{"prediction": <label>}`` on success. On any failure the exception
        text is returned as ``{"error": <message>}`` in a normal response
        body rather than raised; this contract is kept for existing clients.
    """
    try:
        contents = await file.read()
        result = await run_in_threadpool(_classify_image, contents)
        return {"prediction": result}
    except Exception as e:
        return {"error": str(e)}
# Health check endpoint
@app.get("/")
def root():
    """Report that the service is up by returning a fixed status message."""
    status_message = "🧠 Brain Tumor Detection API is running!"
    return {"message": status_message}