# brain-tumor-api/newapi.py
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from torchvision import transforms
from PIL import Image
import io
import os
# ✅ Point the Hugging Face cache at a writable path (must be set before importing huggingface_hub)
os.environ["HF_HOME"] = "/tmp/huggingface"
from huggingface_hub import hf_hub_download
from models.TumorModel import TumorClassification, GliomaStageModel
from utils import get_precautions_from_gemini
# Define your app
app = FastAPI()
# ✅ Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ✅ Load your models from the Hugging Face Hub
btd_model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-detection", filename="brain_tumor_model.pt")
glioma_model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-detection", filename="glioma_stage_model.pt")
btd_model = TumorClassification(model_path=btd_model_path)
glioma_model = GliomaStageModel(model_path=glioma_model_path)
# ✅ Image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
class DiagnosisResponse(BaseModel):
    tumor: str
    stage: str
    precautions: list
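# Example response shape (illustrative only; the exact tumor and stage strings
# depend on what TumorClassification and GliomaStageModel return):
#   {"tumor": "Glioma", "stage": "Stage 2", "precautions": ["Consult a neuro-oncologist", "..."]}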
@app.post("/predict", response_model=DiagnosisResponse)
async def predict(file: UploadFile = File(...)):
    # Read the uploaded scan and convert it to an RGB PIL image
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    # First stage: tumor classification
    tumor_result = btd_model.predict(image_tensor)
    if tumor_result == "No Tumor":
        return DiagnosisResponse(
            tumor="No Tumor Detected",
            stage="N/A",
            precautions=[]
        )
    # Second stage: glioma staging, then precaution lookup
    stage_result = glioma_model.predict(image_tensor)
    precautions = get_precautions_from_gemini(tumor_result, stage_result)
    return DiagnosisResponse(
        tumor=tumor_result,
        stage=stage_result,
        precautions=precautions
    )
@app.get("/")
def root():
    return {"message": "Brain Tumor API is running."}