# newapi.py
"""Brain tumor detection API.

Loads a small CNN classifier at import time and exposes it through FastAPI.
"""
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
import os

# Point the Hugging Face caches at a writable location. huggingface_hub is only
# imported further down (when the weights must be downloaded), i.e. after these
# variables are set.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
os.environ["HF_HOME"] = "/tmp/.cache"

# FastAPI setup
app = FastAPI(title="🧠 Brain Tumor Detection API")

# Allow CORS from any origin, with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Preprocessing: resize to 224x224, force a single grayscale channel, convert
# to a float tensor, then normalize.
# NOTE(review): mean/std should match the values used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.286], std=[0.229]),
])


class BrainTumorModel(nn.Module):
    """Three conv/ReLU/max-pool stages followed by three linear layers.

    Attribute names (``con1d``, ``con2d``, ``con3d``, ``fc1``, ``fc2``,
    ``output``) are the keys of the saved state dict and must not change.

    Args:
        padding: Zero padding for each 3x3 convolution. With ``padding=1``
            every convolution keeps the spatial size, so a 1x224x224 input
            leaves the third pooling layer as 128x28x28, which is exactly the
            input size ``fc1`` is built for. With no padding the same input
            ends at 128x26x26 and cannot be fed to ``fc1``. Padding does not
            change any parameter shape, so the checkpoint loads either way.
            NOTE(review): confirm padding=1 against the training code.
    """

    def __init__(self, padding: int = 1):
        super().__init__()
        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=padding)  # 1 input channel (grayscale)
        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=padding)
        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=padding)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)  # matches the saved checkpoint
        self.fc2 = nn.Linear(512, 256)
        self.output = nn.Linear(256, 4)  # 4 classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) class scores, one row per input sample."""
        x = self.pool(torch.relu(self.con1d(x)))
        x = self.pool(torch.relu(self.con2d(x)))
        x = self.pool(torch.relu(self.con3d(x)))
        # Flatten per sample. Unlike view(-1, 128 * 28 * 28), this never folds
        # a spatial-size mismatch into the batch dimension; fc1 raises instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.output(x)


# Load model weights: prefer a local checkpoint, otherwise download it.
model_path = "BTD_model.pth"
if not os.path.exists(model_path):
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id="Codewithsalty/brain-tumor-models",
        filename="BTD_model.pth",
        cache_dir="/tmp/.cache",
    )

btd_model = BrainTumorModel()
# SECURITY NOTE(review): torch.load unpickles the file, which may have been
# downloaded above. Consider weights_only=True if the installed torch
# supports it.
btd_model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
btd_model.eval()
import logging

from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

# Label for each index of the model's 4-way output.
_CLASS_NAMES = ("No tumor", "Glioma", "Meningioma", "Pituitary tumor")


def _classify(contents: bytes) -> str:
    """Decode *contents* as an image, run the model, and return its label.

    This is blocking, CPU-bound work; ``predict`` runs it in a worker thread
    so the event loop is not stalled.
    """
    image = Image.open(io.BytesIO(contents)).convert("L")  # grayscale
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    # no_grad is thread-local, so it is entered here, in the worker thread.
    with torch.no_grad():
        output = btd_model(image_tensor)
    prediction = torch.argmax(output, dim=1).item()
    return _CLASS_NAMES[prediction]


# Prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns ``{"prediction": <label>}`` on success. On failure the body is
    ``{"error": <message>}``:
    - HTTP 400 when the upload cannot be decoded as an image;
    - HTTP 500 for any other error, which is also logged.
    """
    try:
        contents = await file.read()
        label = await run_in_threadpool(_classify, contents)
    except (OSError, Image.DecompressionBombError) as exc:
        # Pillow reports unreadable/truncated images as OSError
        # (UnidentifiedImageError subclasses it).
        return JSONResponse(status_code=400, content={"error": str(exc)})
    except Exception as exc:  # request boundary: report instead of crashing
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": str(exc)})
    return {"prediction": label}


# Health check
@app.get("/")
def root():
    """Report that the service is up."""
    return {"message": "🧠 Brain Tumor Detection API is running!"}