# newapi.py
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import os
# Use a writable directory in Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
os.environ["HF_HOME"] = "/tmp/.cache"
# Define FastAPI app
app = FastAPI(title="🧠 Brain Tumor Detection API")
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Image transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Define your model directly inside this file (to avoid import errors)
import torch.nn as nn
class BrainTumorModel(nn.Module):
    def __init__(self):
        super(BrainTumorModel, self).__init__()
        # Two conv + pool stages on a 224x224 input: 224 -> 222 -> 111 -> 109 -> 54,
        # which is where the 32 * 54 * 54 input size of the final Linear layer comes from.
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 54 * 54, 2),
        )

    def forward(self, x):
        return self.model(x)
# Load model
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )
btd_model = BrainTumorModel()
btd_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
btd_model.eval()
# Define prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            output = btd_model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()
        result = {0: "No tumor", 1: "Tumor detected"}[prediction]
        return {"prediction": result}
    except Exception as e:
        return {"error": str(e)}
# Health check endpoint
@app.get("/")
def root():
return {"message": "🧠 Brain Tumor Detection API is running!"} |