File size: 3,235 Bytes
e4acaca
 
 
 
 
 
 
 
 
 
 
 
28fd1d7
daee81e
e4acaca
 
 
28fd1d7
e4acaca
 
 
 
 
 
 
 
28fd1d7
e4acaca
 
28fd1d7
e4acaca
 
 
 
 
28fd1d7
e4acaca
 
28fd1d7
e4acaca
 
 
 
 
28fd1d7
e4acaca
 
 
 
 
 
 
28fd1d7
e4acaca
 
 
 
28fd1d7
e4acaca
 
28fd1d7
e4acaca
 
 
28fd1d7
e4acaca
 
 
 
 
 
 
 
 
 
 
 
 
28fd1d7
e4acaca
 
 
 
 
 
 
 
 
 
 
28fd1d7
e4acaca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28fd1d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

from models.TumorModel import TumorClassification, GliomaStageModel
from utils import get_precautions_from_gemini

# Hugging Face Hub manages its own download cache; this module deliberately
# creates no cache directories of its own.

# ASGI application object served by the platform's server.
app = FastAPI(title="Brain Tumor Detection API")

# CORS: accept cross-origin requests from any origin, method and header.
# NOTE(review): allow_credentials=True together with a "*" origin causes
# Starlette to echo the caller's Origin on cookie-bearing requests; nothing in
# this module reads cookies, so consider disabling credentials — confirm with
# the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Hugging Face Hub repository that hosts both model checkpoints.
MODEL_REPO_ID = "Codewithsalty/brain-tumor-models"


def _load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """Load the state dict at ``checkpoint_path`` into ``model`` and return it in eval mode.

    Tensors are mapped to CPU. ``weights_only=True`` restricts unpickling to
    tensors and plain containers, so a tampered checkpoint downloaded from the
    Hub cannot execute arbitrary code when it is loaded.

    Args:
        model: Freshly constructed module whose parameters will be overwritten.
        checkpoint_path: Local path to a saved state dict.

    Returns:
        The same ``model`` object, switched to evaluation mode.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference behaviour for dropout / batch-norm layers, if present
    return model


# ✅ Load Tumor Classification Model from Hugging Face
btd_model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="BTD_model.pth")
tumor_model = _load_checkpoint(TumorClassification(), btd_model_path)

# ✅ Load Glioma Stage Prediction Model from Hugging Face
glioma_model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="glioma_stages.pth")
glioma_model = _load_checkpoint(GliomaStageModel(), glioma_model_path)

# Preprocessing applied to each uploaded scan before inference: collapse to a
# single channel, resize to 224x224, convert to a float tensor, then normalise
# every pixel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Liveness probe
@app.get("/")
async def root():
    """Report that the API process is up and serving requests."""
    status_message = "Brain Tumor Detection API is running."
    return {"message": status_message}

# Class names for the tumor classifier: position i names output index i.
labels = [
    "glioma",
    "meningioma",
    "notumor",
    "pituitary",
]

# Predict tumor type from uploaded image
@app.post("/predict-image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded scan as one of ``labels``.

    Returns:
        ``{"tumor_type": "glioma", "next": "submit_mutation_data"}`` for gliomas
        (presumably prompting the client to call ``/predict-glioma-stage``), or
        ``{"tumor_type": <label>, "precaution": <text>}`` for every other class,
        where the text comes from ``get_precautions_from_gemini``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    img_bytes = await file.read()
    try:
        # convert() forces PIL to decode the whole image, so truncated or corrupt
        # files fail here instead of escaping later as a 500.
        img = Image.open(io.BytesIO(img_bytes)).convert("L")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError (unknown format, empty upload) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    x = transform(img).unsqueeze(0)  # add a batch dimension of size 1

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop; get_precautions_from_gemini is also called
    # synchronously (it presumably performs a network request). Consider
    # fastapi.concurrency.run_in_threadpool if latency under load matters.
    with torch.no_grad():
        idx = torch.argmax(tumor_model(x), dim=1).item()
    tumor_type = labels[idx]

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
    precautions = get_precautions_from_gemini(tumor_type)
    return {"tumor_type": tumor_type, "precaution": precautions}

# Input model for glioma mutation data
class MutationInput(BaseModel):
    """Request body for ``/predict-glioma-stage``: patient sex, age and mutation markers."""

    gender: str  # patient sex as free text; the stage endpoint maps it to a 0/1 feature
    age: float  # NOTE(review): unit presumably years — confirm against training data
    # Mutation markers for each gene. Presumably 0/1 flags
    # (e.g. not mutated / mutated); TODO confirm the encoding used in training.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int

# Predict glioma stage based on mutations
@app.post("/predict-glioma-stage")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage label from patient sex, age and mutation markers.

    ``gender`` is encoded as 0 for male and 1 for any other value. Male is
    recognised as "m" or "male", ignoring case and surrounding whitespace.
    This keeps the original ``'m' -> 0`` encoding and also accepts the
    spelled-out form.

    Returns:
        ``{"glioma_stage": "Stage N"}``, where N is the model's argmax index + 1.
    """
    # Previously only the exact string "m" (any case) mapped to 0, so "male" or
    # " M" were silently encoded as the non-male value 1.
    gender_val = 0 if data.gender.strip().lower() in ("m", "male") else 1
    features = [
        gender_val, data.age, data.idh1, data.tp53, data.atrx,
        data.pten, data.egfr, data.cic, data.pik3ca,
    ]
    # Shape (1, 9): a single-row float32 batch, same values as before.
    x = torch.tensor([features], dtype=torch.float32)

    with torch.no_grad():
        idx = torch.argmax(glioma_model(x), dim=1).item()
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    return {"glioma_stage": stages[idx]}

# ✅ No need to run uvicorn manually in Hugging Face Spaces