Codewithsalty commited on
Commit
64d242b
·
verified ·
1 Parent(s): 97c833a

Upload 5 files

Browse files
Files changed (5) hide show
  1. BTD_model.pth +3 -0
  2. TumorModel.py +34 -0
  3. app.py +79 -0
  4. glioma_stages.pth +3 -0
  5. requirements.txt +4 -0
BTD_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:269506d46e3854e3561379a7d4eb977194e66430b781b36bc56fab481e224f17
3
+ size 178115794
TumorModel.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class TumorClassification(nn.Module):
4
+ def __init__(self):
5
+ super().__init__()
6
+ self.model = nn.Sequential(
7
+ nn.Conv2d(1, 16, 3, 1, 1),
8
+ nn.ReLU(),
9
+ nn.MaxPool2d(2),
10
+ nn.Conv2d(16, 32, 3, 1, 1),
11
+ nn.ReLU(),
12
+ nn.MaxPool2d(2),
13
+ nn.Flatten(),
14
+ nn.Linear(32 * 56 * 56, 128),
15
+ nn.ReLU(),
16
+ nn.Linear(128, 4)
17
+ )
18
+
19
+ def forward(self, x):
20
+ return self.model(x)
21
+
22
+ class GliomaStageModel(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+ self.model = nn.Sequential(
26
+ nn.Linear(9, 128),
27
+ nn.ReLU(),
28
+ nn.Linear(128, 64),
29
+ nn.ReLU(),
30
+ nn.Linear(64, 4)
31
+ )
32
+
33
+ def forward(self, x):
34
+ return self.model(x)
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from TumorModel import TumorClassification, GliomaStageModel
6
+
7
+ # Load tumor classification model
8
+ tumor_model = TumorClassification()
9
+ tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
10
+ tumor_model.eval()
11
+
12
+ # Load glioma stage model
13
+ glioma_model = GliomaStageModel()
14
+ glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
15
+ glioma_model.eval()
16
+
17
+ # Labels
18
+ tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
19
+ stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
20
+
21
+ # Image transform
22
+ transform = transforms.Compose([
23
+ transforms.Grayscale(),
24
+ transforms.Resize((224, 224)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.5], std=[0.5])
27
+ ])
28
+
29
+ # Tumor type prediction
30
+ def predict_tumor(image):
31
+ image = transform(image).unsqueeze(0)
32
+ with torch.no_grad():
33
+ out = tumor_model(image)
34
+ pred = torch.argmax(out, dim=1).item()
35
+ return tumor_labels[pred]
36
+
37
+ # Glioma stage prediction
38
+ def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
39
+ gender_val = 0 if gender == "Male" else 1
40
+ features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
41
+ x = torch.tensor(features).float().unsqueeze(0)
42
+ with torch.no_grad():
43
+ out = glioma_model(x)
44
+ pred = torch.argmax(out, dim=1).item()
45
+ return stage_labels[pred]
46
+
47
+ # Interface 1: Tumor Detector
48
+ tumor_tab = gr.Interface(
49
+ fn=predict_tumor,
50
+ inputs=gr.Image(type="pil"),
51
+ outputs=gr.Label(),
52
+ title="🧠 Brain Tumor Detection",
53
+ description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
54
+ )
55
+
56
+ # Interface 2: Glioma Stage Prediction
57
+ stage_tab = gr.Interface(
58
+ fn=predict_stage,
59
+ inputs=[
60
+ gr.Radio(["Male", "Female"], label="Gender"),
61
+ gr.Slider(0, 100, label="Age"),
62
+ gr.Slider(0, 1, step=1, label="IDH1 Mutation"),
63
+ gr.Slider(0, 1, step=1, label="TP53 Mutation"),
64
+ gr.Slider(0, 1, step=1, label="ATRX Mutation"),
65
+ gr.Slider(0, 1, step=1, label="PTEN Mutation"),
66
+ gr.Slider(0, 1, step=1, label="EGFR Mutation"),
67
+ gr.Slider(0, 1, step=1, label="CIC Mutation"),
68
+ gr.Slider(0, 1, step=1, label="PIK3CA Mutation")
69
+ ],
70
+ outputs=gr.Label(),
71
+ title="🧬 Glioma Stage Predictor",
72
+ description="Enter patient’s mutation and demographic data to predict glioma stage."
73
+ )
74
+
75
+ # Combine in tabs
76
+ demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
77
+
78
+ # Run app
79
+ demo.launch()
glioma_stages.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c64f506838505e6cfa74775ceed9260b03a978bf64243d123396393932cb1a2b
3
+ size 33544
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ pillow