# Hugging Face Spaces status when this file was captured: Runtime error
# newapi.py
import io
import logging
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Set writable cache directories.
# Both variables point at the same directory; they are assigned at import
# time, before the model-loading code further down runs.
_CACHE_DIR = "/tmp/.cache"
os.environ["TRANSFORMERS_CACHE"] = _CACHE_DIR
os.environ["HF_HOME"] = _CACHE_DIR
# FastAPI setup
app = FastAPI(title="🧠 Brain Tumor Detection API")

# Allow CORS: every origin, HTTP method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define image transform (grayscale).
# Applied to each uploaded image before it is passed to the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed 224x224 spatial size
    transforms.Grayscale(num_output_channels=1),  # Ensure a single channel
    transforms.ToTensor(),
    # NOTE(review): mean/std are presumably the training-set statistics —
    # confirm against the training script and adjust if needed.
    transforms.Normalize(mean=[0.286], std=[0.229]),
])
# Define the model architecture. Attribute names (con1d, con2d, con3d, fc1,
# fc2, output) are the state-dict keys and therefore must not be renamed.
import torch.nn as nn


class BrainTumorModel(nn.Module):
    """CNN classifier: three conv/ReLU/max-pool stages, then three linear layers.

    Takes a batch of single-channel images shaped (N, 1, 224, 224) and
    returns raw logits shaped (N, 4).
    """

    def __init__(self):
        super().__init__()
        # padding=1 keeps the spatial size through each 3x3 convolution, so
        # three 2x2 poolings reduce 224 -> 112 -> 56 -> 28, which is what
        # fc1's 128 * 28 * 28 input size requires. Without padding the
        # feature map is 26x26 and the forward pass fails on the reshape.
        # Padding does not change the weight shapes, so saved checkpoints
        # still load.
        # NOTE(review): assumes training also used padding=1 with 224x224
        # inputs — confirm against the training script.
        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # grayscale input
        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)  # Match the saved model's input size
        self.fc2 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 4)  # 4 classes expected

    def forward(self, x):
        """Return the (N, 4) class logits for a batch ``x`` of (N, 1, 224, 224)."""
        x = self.pool(torch.relu(self.con1d(x)))
        x = self.pool(torch.relu(self.con2d(x)))
        x = self.pool(torch.relu(self.con3d(x)))
        # Flatten per sample: keeping the batch dimension explicit means a
        # wrong input size fails in fc1 instead of silently re-batching.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)
# Load model: use a local BTD_model.pth when present, otherwise download it
# from the Hugging Face Hub.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    # Imported lazily so the Hub client is only required when downloading.
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint cannot execute arbitrary code while being loaded.
btd_model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
)
btd_model.eval()  # inference mode for layers such as dropout/batch-norm
# Prediction endpoint
logger = logging.getLogger(__name__)

# Maps the argmax index of the model output to a human-readable label.
# NOTE(review): the index order is assumed to match the training labels —
# confirm against the training script.
_CLASS_LABELS = {
    0: "No tumor",
    1: "Glioma",
    2: "Meningioma",
    3: "Pituitary tumor",
}


# NOTE(review): the function had no route decorator, so it was never exposed
# by the app; "/predict" is the path chosen here — adjust if clients expect
# a different one.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        ``{"prediction": <label>}`` on success, or ``{"error": <message>}``
        if reading, decoding or inference raises.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("L")  # Grayscale
        image_tensor = transform(image).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            output = btd_model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()
        return {"prediction": _CLASS_LABELS[prediction]}
    except Exception as e:
        # Keep the existing error payload for clients, but record the
        # traceback so failures are visible in the server logs.
        logger.exception("Prediction failed")
        return {"error": str(e)}
# Health check
@app.get("/")
def root():
    """Return a static message indicating that the API is running."""
    return {"message": "🧠 Brain Tumor Detection API is running!"}