Spaces:
Runtime error
Runtime error
File size: 2,951 Bytes
f99073d 39333b1 a75fe7e f99073d e4acaca 37adc03 f99073d e4acaca 37adc03 f99073d 28addcf 37adc03 39333b1 e8c5868 39333b1 28addcf 37adc03 e4acaca 28addcf 37adc03 39333b1 f0abd4b e4acaca 37adc03 f99073d 39333b1 f99073d 37adc03 b81c13c f0abd4b 37adc03 28addcf f99073d b81c13c f0abd4b b81c13c 39333b1 f99073d 39333b1 f99073d 39333b1 37adc03 f99073d 37adc03 f99073d 37adc03 f99073d 37adc03 f99073d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# newapi.py
import io
import logging
import os

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Point the Hugging Face cache variables at a writable directory under /tmp.
# This runs before huggingface_hub is imported (lazily, in the model-loading
# code below), so the library sees these values.
_CACHE_DIR = "/tmp/.cache"
os.environ.update(
    {
        "TRANSFORMERS_CACHE": _CACHE_DIR,
        "HF_HOME": _CACHE_DIR,
    }
)
# FastAPI application instance.
app = FastAPI(title="🧠 Brain Tumor Detection API")

# CORS: accept requests from any origin, using any HTTP method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Preprocessing applied to each uploaded image: resize, force a single channel,
# convert to a float tensor, then normalize.
IMAGE_SIZE = (224, 224)
# NOTE(review): these statistics look hand-picked; confirm they match the ones
# used when the checkpoint was trained.
GRAYSCALE_MEAN = [0.286]
GRAYSCALE_STD = [0.229]

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        # One output channel, matching the model's 1-channel first convolution.
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=GRAYSCALE_MEAN, std=GRAYSCALE_STD),
    ]
)
# Define the exact same model used during training
import torch.nn as nn
class BrainTumorModel(nn.Module):
    """CNN mapping a 1-channel 224x224 image tensor to logits over 4 classes.

    The attribute names (``con1d``, ``con2d``, ``con3d``, ``fc1``, ``fc2``,
    ``output``) are the state-dict keys of the saved checkpoint. Renaming any
    of them breaks ``load_state_dict``.
    """

    # Flattened size after the last conv block: 128 channels x 28 x 28.
    # With a 224x224 input and size-preserving convs, the three 2x2 pools
    # give 224 -> 112 -> 56 -> 28.
    _FLAT_FEATURES = 128 * 28 * 28

    def __init__(self):
        super().__init__()
        # padding=1 makes each 3x3 conv size-preserving. Without it, a 224x224
        # input shrinks to 26x26 (128 * 26 * 26 = 86528 features). That cannot
        # be reshaped to fc1's 100352 inputs, so every forward pass raised.
        # NOTE(review): the padding is inferred from fc1's shape; confirm it
        # against the training code. Padding carries no weights, so a mismatch
        # here would still load the checkpoint without any error.
        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # grayscale input (1 channel)
        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 512)  # must match the checkpoint's fc1.weight
        self.fc2 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 4)  # 4 classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits: one row of 4 scores per input image.

        Assumes ``x`` is (batch, 1, 224, 224), as produced by the preprocessing
        transform followed by ``unsqueeze(0)``.
        """
        x = self.pool(torch.relu(self.con1d(x)))  # (B, 32, 112, 112)
        x = self.pool(torch.relu(self.con2d(x)))  # (B, 64, 56, 56)
        x = self.pool(torch.relu(self.con3d(x)))  # (B, 128, 28, 28)
        x = x.view(-1, self._FLAT_FEATURES)  # flatten the feature maps
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)
# Load model weights. Prefer a checkpoint bundled next to the app; otherwise
# download it from the Hugging Face Hub into /tmp/.cache.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    # Imported lazily: only needed when the checkpoint is not present locally.
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint fetched from the network cannot execute arbitrary code.
# It is the default from PyTorch 2.6; passing it explicitly also protects
# older versions (the argument requires PyTorch >= 1.13).
btd_model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
)
# The model is only used for prediction, so switch it to inference mode.
btd_model.eval()
# Prediction endpoint
# Maps the argmax index to a human-readable label; the order matches the
# model's 4 output logits.
CLASS_NAMES = ("No tumor", "Glioma", "Meningioma", "Pituitary tumor")
logger = logging.getLogger(__name__)


def _classify(contents: bytes) -> str:
    """Decode raw image bytes, run the model, and return the predicted label.

    This is CPU-bound work (PIL decoding plus a forward pass), so async callers
    should run it in a worker thread rather than on the event loop.
    """
    image = Image.open(io.BytesIO(contents)).convert("L")  # Grayscale
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():  # inference only; no autograd graph needed
        output = btd_model(image_tensor)
    prediction = torch.argmax(output, dim=1).item()
    return CLASS_NAMES[prediction]


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded brain-scan image.

    On success, returns ``{"prediction": <label>}``. On any failure (unreadable
    image, preprocessing or model error), the exception is logged server-side
    and ``{"error": <message>}`` is returned, still with HTTP 200.
    """
    try:
        contents = await file.read()
        # Run decoding and inference off the event loop. Otherwise they would
        # stall every other request, including the health check, while they run.
        result = await run_in_threadpool(_classify, contents)
    except Exception as e:
        logger.exception("Prediction failed for upload %r", file.filename)
        return {"error": str(e)}
    return {"prediction": result}
# Health check
@app.get("/")
def root():
    """Health check: confirm the API process is up and serving requests."""
    # The original line ended with a stray " |" (a syntax error); it is removed here.
    return {"message": "🧠 Brain Tumor Detection API is running!"}