File size: 3,413 Bytes
c73e079
 
 
 
 
 
91d7910
c73e079
91d7910
 
 
 
 
 
 
 
 
 
 
 
c73e079
 
91d7910
c73e079
91d7910
c73e079
 
91d7910
c73e079
91d7910
c73e079
 
 
91d7910
c73e079
 
 
 
91d7910
 
c73e079
91d7910
c73e079
91d7910
c73e079
91d7910
c73e079
 
 
 
 
 
 
 
91d7910
 
 
c73e079
 
 
 
 
91d7910
c73e079
 
 
 
 
 
 
91d7910
 
 
 
 
 
 
c73e079
 
91d7910
e57c445
c73e079
91d7910
c73e079
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from TumorModel import TumorClassification, GliomaStageModel

# ─── Load Tumor Classification Model ───────────────────────────────────────────
# Ordered (old, new) substring pairs applied to every checkpoint key before it
# is loaded into TumorClassification.
# NOTE(review): presumably the checkpoint was saved from a model whose layers
# were named con1d/con2d/.../output — confirm against TumorModel.
_STATE_KEY_RENAMES = (
    ("con1d.", "model.0."),
    ("con2d.", "model.3."),
    ("con3d.", "model.6."),
    ("fc1.", "model.8."),
    ("fc2.", "model.10."),
    ("output.", "model.12."),
)

tumor_model = TumorClassification()
sd = torch.load("BTD_model.pth", map_location="cpu")

renamed_sd = {}
for old_key, weights in sd.items():
    new_key = old_key
    # Apply every substitution in table order, mirroring a chained .replace().
    for old_part, new_part in _STATE_KEY_RENAMES:
        new_key = new_key.replace(old_part, new_part)
    renamed_sd[new_key] = weights

tumor_model.load_state_dict(renamed_sd)
tumor_model.eval()  # evaluation mode: the model is only used for prediction

# ─── Load Glioma Stage Model ───────────────────────────────────────────────────
# Build the stage model, then load its saved weights onto the CPU.
glioma_model = GliomaStageModel()
_glioma_state = torch.load("glioma_stages.pth", map_location="cpu")
glioma_model.load_state_dict(_glioma_state)
glioma_model.eval()  # evaluation mode: the model is only used for prediction

# ─── Labels and Image Transform ───────────────────────────────────────────────
# Class names looked up by the index each model predicts.
tumor_labels = ["glioma", "meningioma", "notumor", "pituitary"]
stage_labels = ["Stage 1", "Stage 2"]  # adjust if the stage model's classes differ

# Image preprocessing: single-channel grayscale, resize to 208x208, convert to
# a tensor, then normalize with mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((208, 208)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(_preprocess_steps)

# ─── Gradio Prediction Functions ───────────────────────────────────────────────

def predict_tumor(image):
    """Classify an uploaded scan into one of the ``tumor_labels`` classes.

    Args:
        image: The PIL image delivered by the Gradio image input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The entry of ``tumor_labels`` at the index of the model's highest
        output score.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # An empty image input arrives as None; fail with a readable message
        # instead of an opaque error from inside the transform pipeline.
        raise gr.Error("Please upload an image before submitting.")

    # Preprocess and add a leading batch dimension for the model.
    tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        out = tumor_model(tensor)
        pred = torch.argmax(out, dim=1).item()
    return tumor_labels[pred]

def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
    """Predict a glioma stage label from the form inputs.

    Args:
        gender: ``"Male"`` is encoded as 0; any other non-``None`` value is
            encoded as 1.
        age: Numeric age value.
        idh1, tp53, atrx, pten, egfr, cic, pik3ca: Numeric values passed to
            the model unchanged (the UI supplies 0 or 1 for each).

    Returns:
        The entry of ``stage_labels`` at the index of the model's highest
        output score.

    Raises:
        gr.Error: If no gender was selected.
    """
    if gender is None:
        # An unselected radio arrives as None, which would otherwise be
        # silently encoded as 1 and produce a prediction for the wrong input.
        raise gr.Error("Please select a gender before submitting.")

    gender_val = 0 if gender == "Male" else 1
    # Feature order is fixed: gender, age, then the seven remaining inputs.
    features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
    x = torch.tensor(features).float().unsqueeze(0)  # one row = batch of one
    with torch.no_grad():
        out = glioma_model(x)
        pred = torch.argmax(out, dim=1).item()
    return stage_labels[pred]

# ─── Gradio UI ────────────────────────────────────────────────────────────────

# Tab 1: image in, predicted tumor class out.
tumor_tab = gr.Interface(
    fn=predict_tumor,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Brain Tumor Detector",
)

# Tab 2: gender and age, followed by one 0/1 slider per remaining
# predict_stage parameter, listed in that function's positional order.
_binary_slider_labels = ("IDH1", "TP53", "ATRX", "PTEN", "EGFR", "CIC", "PIK3CA")
_stage_inputs = [
    gr.Radio(["Male", "Female"], label="Gender"),
    gr.Slider(0, 100, label="Age"),
]
_stage_inputs.extend(
    gr.Slider(0, 1, step=1, label=slider_label)
    for slider_label in _binary_slider_labels
)

stage_tab = gr.Interface(
    fn=predict_stage,
    inputs=_stage_inputs,
    outputs=gr.Label(),
    title="Glioma Stage Predictor",
)

# Combine both interfaces into one tabbed app and start serving it.
demo = gr.TabbedInterface(
    [tumor_tab, stage_tab],
    tab_names=["Tumor Detector", "Glioma Stage"],
)
demo.launch()