Codewithsalty commited on
Commit
cf81d06
·
verified ·
1 Parent(s): e133eb4

Update TumorModel.py

Browse files
Files changed (1) hide show
  1. TumorModel.py +4 -39
TumorModel.py CHANGED
@@ -1,48 +1,13 @@
1
- import torch.nn as nn
2
-
3
- # 🧠 Tumor Type Classification Model
4
- class TumorClassification(nn.Module):
5
- def __init__(self):
6
- super(TumorClassification, self).__init__()
7
- self.con1d = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
8
- self.relu1 = nn.ReLU()
9
- self.pool1 = nn.MaxPool2d(2)
10
-
11
- self.con2d = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
12
- self.relu2 = nn.ReLU()
13
- self.pool2 = nn.MaxPool2d(2)
14
-
15
- self.con3d = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
16
- self.relu3 = nn.ReLU()
17
- self.pool3 = nn.MaxPool2d(2)
18
-
19
- self.flatten = nn.Flatten()
20
- self.fc1 = nn.Linear(86528, 512) # Adjust this number to match your original
21
- self.relu_fc = nn.ReLU()
22
- self.fc2 = nn.Linear(512, 256)
23
- self.relu_fc2 = nn.ReLU()
24
- self.output = nn.Linear(256, 4)
25
-
26
- def forward(self, x):
27
- x = self.pool1(self.relu1(self.con1d(x)))
28
- x = self.pool2(self.relu2(self.con2d(x)))
29
- x = self.pool3(self.relu3(self.con3d(x)))
30
- x = self.flatten(x)
31
- x = self.relu_fc(self.fc1(x))
32
- x = self.relu_fc2(self.fc2(x))
33
- return self.output(x)
34
-
35
- # 🧬 Glioma Stage Prediction Model (MATCHES `glioma_stages.pth`)
36
  class GliomaStageModel(nn.Module):
37
  def __init__(self):
38
  super(GliomaStageModel, self).__init__()
39
- self.fc1 = nn.Linear(9, 128)
40
  self.relu1 = nn.ReLU()
41
- self.fc2 = nn.Linear(128, 64)
42
  self.relu2 = nn.ReLU()
43
- self.fc3 = nn.Linear(64, 32)
44
  self.relu3 = nn.ReLU()
45
- self.out = nn.Linear(32, 4)
46
 
47
  def forward(self, x):
48
  x = self.relu1(self.fc1(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  class GliomaStageModel(nn.Module):
2
  def __init__(self):
3
  super(GliomaStageModel, self).__init__()
4
+ self.fc1 = nn.Linear(9, 100)
5
  self.relu1 = nn.ReLU()
6
+ self.fc2 = nn.Linear(100, 50)
7
  self.relu2 = nn.ReLU()
8
+ self.fc3 = nn.Linear(50, 30)
9
  self.relu3 = nn.ReLU()
10
+ self.out = nn.Linear(30, 2) # Only 2 classes in this model
11
 
12
  def forward(self, x):
13
  x = self.relu1(self.fc1(x))