Codewithsalty commited on
Commit
4285f13
·
verified ·
1 Parent(s): dbd7a98

Upload 5 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -35
  2. BTD_model.pth +3 -0
  3. TumorModel.py +37 -0
  4. app.py +77 -0
  5. glioma_stages.pth +3 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
BTD_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:269506d46e3854e3561379a7d4eb977194e66430b781b36bc56fab481e224f17
3
+ size 178115794
TumorModel.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class TumorClassification(nn.Module):
4
+ def __init__(self):
5
+ super().__init__()
6
+ self.con1d = nn.Conv2d(1, 32, 3, padding=1)
7
+ self.con2d = nn.Conv2d(32, 64, 3, padding=1)
8
+ self.con3d = nn.Conv2d(64, 128, 3, padding=1)
9
+ self.pool = nn.MaxPool2d(2, 2)
10
+ self.fc1 = nn.Linear(128 * 28 * 28, 512)
11
+ self.fc2 = nn.Linear(512, 256)
12
+ self.output = nn.Linear(256, 4)
13
+ self.relu = nn.ReLU()
14
+
15
+ def forward(self, x):
16
+ x = self.pool(self.relu(self.con1d(x)))
17
+ x = self.pool(self.relu(self.con2d(x)))
18
+ x = self.pool(self.relu(self.con3d(x)))
19
+ x = x.view(x.size(0), -1)
20
+ x = self.relu(self.fc1(x))
21
+ x = self.relu(self.fc2(x))
22
+ return self.output(x)
23
+
24
+ class GliomaStageModel(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.fc1 = nn.Linear(9, 100)
28
+ self.fc2 = nn.Linear(100, 50)
29
+ self.fc3 = nn.Linear(50, 30)
30
+ self.out = nn.Linear(30, 4)
31
+ self.relu = nn.ReLU()
32
+
33
+ def forward(self, x):
34
+ x = self.relu(self.fc1(x))
35
+ x = self.relu(self.fc2(x))
36
+ x = self.relu(self.fc3(x))
37
+ return self.out(x)
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from TumorModel import TumorClassification, GliomaStageModel
7
+
8
+ # Title
9
+ st.set_page_config(page_title="Brain Tumor Predictor", layout="centered")
10
+ st.title("🧠 Brain Tumor Classifier & Glioma Stage Predictor")
11
+
12
+ # Load models
13
+ @st.cache_resource
14
+ def load_models():
15
+ tumor_model = TumorClassification()
16
+ tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
17
+ tumor_model.eval()
18
+
19
+ glioma_model = GliomaStageModel()
20
+ glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
21
+ glioma_model.eval()
22
+
23
+ return tumor_model, glioma_model
24
+
25
+ tumor_model, glioma_model = load_models()
26
+
27
+ # Labels
28
+ tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
29
+ stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
30
+
31
+ # Image Preprocessing
32
+ transform = transforms.Compose([
33
+ transforms.Grayscale(),
34
+ transforms.Resize((224, 224)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize([0.5], [0.5])
37
+ ])
38
+
39
+ # Tabs
40
+ tab1, tab2 = st.tabs(["🧠 Tumor Type Detection", "🧬 Glioma Stage Prediction"])
41
+
42
+ # Tab 1: Tumor Classification
43
+ with tab1:
44
+ st.header("Upload Brain MRI")
45
+ image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
46
+ if image_file:
47
+ image = Image.open(image_file).convert("L")
48
+ st.image(image, caption="Uploaded Image", use_column_width=True)
49
+
50
+ image_tensor = transform(image).unsqueeze(0)
51
+ with torch.no_grad():
52
+ prediction = tumor_model(image_tensor)
53
+ pred_label = tumor_labels[torch.argmax(prediction)]
54
+ st.success(f"🧠 Tumor Prediction: **{pred_label.upper()}**")
55
+
56
+ # Tab 2: Glioma Stage Prediction
57
+ with tab2:
58
+ st.header("Patient & Genetic Info")
59
+ gender = st.radio("Gender", ["Male", "Female"])
60
+ age = st.slider("Age", 1, 100, 25)
61
+
62
+ st.subheader("Gene Mutations (0 = No, 1 = Yes)")
63
+ idh1 = st.radio("IDH1", [0, 1], horizontal=True)
64
+ tp53 = st.radio("TP53", [0, 1], horizontal=True)
65
+ atrx = st.radio("ATRX", [0, 1], horizontal=True)
66
+ pten = st.radio("PTEN", [0, 1], horizontal=True)
67
+ egfr = st.radio("EGFR", [0, 1], horizontal=True)
68
+ cic = st.radio("CIC", [0, 1], horizontal=True)
69
+ pik3ca = st.radio("PIK3CA", [0, 1], horizontal=True)
70
+
71
+ if st.button("🔍 Predict Glioma Stage"):
72
+ gender_val = 0 if gender == "Male" else 1
73
+ features = torch.tensor([[gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]], dtype=torch.float32)
74
+ with torch.no_grad():
75
+ prediction = glioma_model(features)
76
+ stage = stage_labels[torch.argmax(prediction)]
77
+ st.success(f"🧬 Predicted Glioma Stage: **{stage}**")
glioma_stages.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c64f506838505e6cfa74775ceed9260b03a978bf64243d123396393932cb1a2b
3
+ size 33544