# Spaces: Runtime error
# (Status text captured together with the file; it is not part of the program.)
import io
import os

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms
# Point the Hugging Face caches at a writable location. This has to run before
# `huggingface_hub` is imported, because the library reads its cache settings
# from the environment when it is first imported.
_HF_CACHE_DIR = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = _HF_CACHE_DIR
# `hf_hub_download` resolves its cache from HF_HOME / HF_HUB_CACHE rather than
# TRANSFORMERS_CACHE, so HF_HOME is set as well. `setdefault` keeps any value
# the deployment has already configured.
os.environ.setdefault("HF_HOME", _HF_CACHE_DIR)
from huggingface_hub import hf_hub_download | |
from models.TumorModel import TumorClassification, GliomaStageModel | |
from utils import get_precautions_from_gemini | |
# Application instance that the middleware below is attached to.
app = FastAPI()

# Enable CORS for every origin, method and header.
# NOTE(review): FastAPI's CORS documentation says wildcard origins should not
# be combined with allow_credentials=True. The settings are left unchanged here
# because the set of legitimate client origins cannot be determined from this
# file — confirm them and list them explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Download both weight files from the Hugging Face Hub and build the models.
# This runs at import time, so the module cannot be imported unless both
# downloads and both model constructors succeed.
_MODEL_REPO_ID = "Codewithsalty/brain-tumor-detection"  # both files live in this repo

btd_model_path = hf_hub_download(
    repo_id=_MODEL_REPO_ID,
    filename="brain_tumor_model.pt",
)
glioma_model_path = hf_hub_download(
    repo_id=_MODEL_REPO_ID,
    filename="glioma_stage_model.pt",
)

btd_model = TumorClassification(model_path=btd_model_path)
glioma_model = GliomaStageModel(model_path=glioma_model_path)
# Image preprocessing applied to every upload before inference: resize to a
# fixed 224x224 and convert the PIL image to a tensor.
# NOTE(review): no normalization step is applied here — confirm this matches
# how the models were trained.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
class DiagnosisResponse(BaseModel):
    """Response body describing the outcome of one diagnosis request."""

    # Tumor label for the image; "No Tumor Detected" when nothing was found.
    tumor: str
    # Stage reported by the staging model; "N/A" when no tumor was found.
    stage: str
    # Precautions for the diagnosis; an empty list when no tumor was found.
    # NOTE(review): element type is not visible in this file, so it stays untyped.
    precautions: list
async def predict(file: UploadFile = File(...)) -> DiagnosisResponse:
    """Classify an uploaded image and, when a tumor is found, stage it.

    The upload is decoded to RGB, preprocessed with the module-level
    ``transform`` and passed to the tumor classifier. If the classifier reports
    "No Tumor", a fixed negative response is returned. Otherwise the staging
    model is run on the same tensor and precautions are requested for the
    (tumor, stage) pair.

    Args:
        file: The uploaded image file.

    Returns:
        A ``DiagnosisResponse`` with the tumor label, stage and precautions.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # NOTE(review): no route decorator is visible on this function in this view
    # of the file — confirm the handler is registered on `app`.
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Pillow signals unreadable, empty or truncated data with OSError
        # subclasses; report that as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # The models are called with a batch of one, hence the extra leading axis.
    image_tensor = transform(image).unsqueeze(0)

    tumor_result = btd_model.predict(image_tensor)
    if tumor_result == "No Tumor":
        # Nothing to stage, so skip the staging model and the precautions call.
        return DiagnosisResponse(
            tumor="No Tumor Detected",
            stage="N/A",
            precautions=[],
        )

    stage_result = glioma_model.predict(image_tensor)
    precautions = get_precautions_from_gemini(tumor_result, stage_result)
    return DiagnosisResponse(
        tumor=tumor_result,
        stage=stage_result,
        precautions=precautions,
    )
def root() -> dict:
    """Return a static status message indicating that the API is running."""
    # NOTE(review): no route decorator is visible on this function in this view
    # of the file — confirm the handler is registered on `app`.
    return {"message": "Brain Tumor API is running."}