File size: 3,201 Bytes
e4acaca
 
 
 
 
 
 
daee81e
e4acaca
 
 
 
 
ac416f7
 
a4a23df
ac416f7
daee81e
e4acaca
 
 
ac416f7
e4acaca
 
 
 
 
 
 
 
ac416f7
e4acaca
 
daee81e
bc39385
e4acaca
 
 
 
 
ac416f7
e4acaca
 
daee81e
bc39385
e4acaca
 
 
 
 
ac416f7
e4acaca
 
 
 
 
 
 
 
 
 
 
ac416f7
e4acaca
 
 
 
 
a4a23df
e4acaca
 
 
 
 
 
 
 
 
 
 
 
 
ac416f7
e4acaca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac416f7
e4acaca
 
daee81e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import io
import os

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

from models.TumorModel import TumorClassification, GliomaStageModel
from utils import get_precautions_from_gemini

# Directory where hf_hub_download stores the model checkpoints.
# The default is the Hugging Face cache location under /home/user, which is
# assumed to exist and be writable on the deployment host. Set MODEL_CACHE_DIR
# to use a different location, e.g. for local development where /home/user may
# not exist.
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/home/user/.cache/huggingface")

# Create the FastAPI application object that all routes below attach to.
app = FastAPI(title="Brain Tumor Detection API")

# Cross-origin settings: accept requests from any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True means any
# site may make credentialed cross-origin calls; confirm credentials are needed.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Hugging Face Hub repository that hosts both model checkpoints.
MODEL_REPO_ID = "Codewithsalty/brain-tumor-models"


def _load_weights(model, path):
    """Load the state dict stored at *path* into *model* and put it in eval mode.

    ``weights_only=True`` restricts unpickling to tensors and plain containers.
    The checkpoint comes from a remote repository, so a tampered file must not
    be able to execute arbitrary code when it is loaded. Tensors are mapped to
    CPU regardless of the device they were saved from.

    Returns:
        The same *model* instance, ready for inference.
    """
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Load tumor classification model
btd_model_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="BTD_model.pth",
    cache_dir=cache_dir,
)
tumor_model = _load_weights(TumorClassification(), btd_model_path)

# Load glioma stage model
glioma_model_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="glioma_stages.pth",
    cache_dir=cache_dir,
)
glioma_model = _load_weights(GliomaStageModel(), glioma_model_path)

# Preprocessing applied to every uploaded image before classification:
# 1. collapse to one grayscale channel,
# 2. resize to a fixed 224x224,
# 3. convert to a tensor with values in [0, 1],
# 4. normalize with mean 0.5 / std 0.5, mapping [0, 1] onto [-1, 1].
_preprocess_steps = [
    transforms.Grayscale(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocess_steps)

@app.get("/")
async def root():
    # Simple status endpoint: always answers with the same fixed JSON message.
    payload = {"message": "Brain Tumor Detection API is running."}
    return payload

# Tumor classifier class names, listed in output-index order
# (predict_image maps the arg-max index straight into this list).
labels = [
    'glioma',
    'meningioma',
    'notumor',
    'pituitary',
]

def _classify_image(img):
    """Return the tumor-type label predicted for a decoded PIL image.

    This function is synchronous so it can run in a worker thread.
    ``torch.no_grad`` is thread-local, so it is entered here, inside the
    worker, rather than in the async caller.
    """
    x = transform(img).unsqueeze(0)  # add a leading batch dimension of size 1
    with torch.no_grad():
        out = tumor_model(x)
    idx = torch.argmax(out, dim=1).item()
    return labels[idx]


@app.post("/predict-image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image by tumor type.

    The upload is decoded, converted to single-channel grayscale and run
    through the tumor classifier. For a "glioma" result, the response's
    "next" field asks the client to submit mutation data (presumably to
    /predict-glioma-stage). For any other class, the response includes
    precautions from get_precautions_from_gemini.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    img_bytes = await file.read()
    try:
        # Image.open is lazy; convert() forces a full decode, so corrupt or
        # truncated files are caught here as well.
        img = Image.open(io.BytesIO(img_bytes)).convert("L")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError includes PIL.UnidentifiedImageError (non-image or empty
        # upload). Report bad client input as 400 rather than an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Inference is CPU-bound. Run it off the event loop so concurrent requests
    # are not stalled while the model executes.
    tumor_type = await run_in_threadpool(_classify_image, img)

    if tumor_type == "glioma":
        return {"tumor_type": tumor_type, "next": "submit_mutation_data"}

    # get_precautions_from_gemini is called synchronously. It presumably does
    # network I/O, so it also runs in the thread pool.
    # NOTE(review): confirm the helper is safe to call from worker threads.
    precautions = await run_in_threadpool(get_precautions_from_gemini, tumor_type)
    return {"tumor_type": tumor_type, "precaution": precautions}

# Request body for POST /predict-glioma-stage. Each attribute name is a JSON key
# clients must send; predict_glioma_stage reads every field to build the
# model's feature vector. Fields are documented with comments rather than a
# docstring so the generated OpenAPI schema stays unchanged.
class MutationInput(BaseModel):
    # Free-text gender; predict_glioma_stage encodes it to a number.
    gender: str
    # Passed to the model unscaled; units presumably years — TODO confirm.
    age: float
    # Per-gene values passed to the model as-is.
    # Presumably 0/1 mutation flags — TODO confirm against the training data.
    idh1: int
    tp53: int
    atrx: int
    pten: int
    egfr: int
    cic: int
    pik3ca: int

def _encode_gender(gender):
    """Map a free-text gender value to the model's numeric encoding.

    Male is encoded as 0 and everything else as 1. Both 'm' and 'male' count as
    male, ignoring case and surrounding whitespace. Previously only an exact
    'm'/'M' matched, so "male" was silently encoded as 1.
    """
    return 0 if gender.strip().lower() in ("m", "male") else 1


@app.post("/predict-glioma-stage")
async def predict_glioma_stage(data: MutationInput):
    """Predict a glioma stage from demographic and mutation features.

    Features are assembled in a fixed order (gender, age, then the seven gene
    fields) into a single row for the glioma stage model. The arg-max output
    index selects the stage label.
    """
    features = [
        _encode_gender(data.gender), data.age, data.idh1, data.tp53, data.atrx,
        data.pten, data.egfr, data.cic, data.pik3ca,
    ]
    x = torch.tensor(features).float().unsqueeze(0)  # shape (1, 9): one sample

    with torch.no_grad():
        out = glioma_model(x)
    idx = torch.argmax(out, dim=1).item()

    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
    return {"glioma_stage": stages[idx]}

# For local development only: running this file directly serves the app.
if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "newapi:app" import string. With the string,
    # uvicorn imported this file a second time under the module name "newapi".
    # That re-ran all top-level code, including both model downloads and loads,
    # and failed outright if the file was saved under a different name.
    # PORT may override the listening port; the default stays 10000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "10000")))