Codewithsalty committed
Commit 33617dd · verified · 1 Parent(s): c73e079

Update TumorModel.py

Files changed (1): TumorModel.py (+28, −15)
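Summary of the diff below: the two-block nn.Sequential classifier is replaced by an explicit three-block CNN (32 → 64 → 128 channels, each followed by ReLU and MaxPool2d) with a deeper fully connected head (512 → 256 → 4), keeping four output classes; a stale comment on GliomaStageModel's output layer is also dropped.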
TumorModel.py CHANGED

```diff
@@ -2,21 +2,34 @@ import torch.nn as nn
 
 class TumorClassification(nn.Module):
     def __init__(self):
-        super().__init__()
-        self.model = nn.Sequential(
-            nn.Conv2d(1, 16, 3, 1, 1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Conv2d(16, 32, 3, 1, 1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Flatten(),
-            nn.Linear(32 * 56 * 56, 128),
-            nn.ReLU(),
-            nn.Linear(128, 4)
-        )
+        super(TumorClassification, self).__init__()
+        self.con1d = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(2)
+
+        self.con2d = nn.Conv2d(32, 64, kernel_size=3, padding=1)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(2)
+
+        self.con3d = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.relu3 = nn.ReLU()
+        self.pool3 = nn.MaxPool2d(2)
+
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(128 * 28 * 28, 512)
+        self.relu4 = nn.ReLU()
+        self.fc2 = nn.Linear(512, 256)
+        self.relu5 = nn.ReLU()
+        self.output = nn.Linear(256, 4)
+
     def forward(self, x):
-        return self.model(x)
+        x = self.pool1(self.relu1(self.con1d(x)))
+        x = self.pool2(self.relu2(self.con2d(x)))
+        x = self.pool3(self.relu3(self.con3d(x)))
+        x = self.flatten(x)
+        x = self.relu4(self.fc1(x))
+        x = self.relu5(self.fc2(x))
+        return self.output(x)
 
 
 class GliomaStageModel(nn.Module):
@@ -28,7 +41,7 @@ class GliomaStageModel(nn.Module):
         self.relu2 = nn.ReLU()
         self.fc3 = nn.Linear(50, 30)
         self.relu3 = nn.ReLU()
-        self.out = nn.Linear(30, 2)  # only 2 classes in your .pth
+        self.out = nn.Linear(30, 2)
 
     def forward(self, x):
         x = self.relu1(self.fc1(x))
```
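For readers checking the new fc1 dimension: both the old (32 * 56 * 56) and new (128 * 28 * 28) flatten sizes imply a 224×224 single-channel input, since each MaxPool2d(2) stage halves the spatial size (224 → 112 → 56 → 28). Below is a minimal shape-check sketch, assuming this repo's TumorModel.py is importable; the input size itself is an inference from the diff, not something stated in the file.

```python
import torch

from TumorModel import TumorClassification  # the class updated in this commit

model = TumorClassification()
model.eval()

# 224x224 grayscale input inferred from fc1: 224 -> 112 -> 56 -> 28 after
# three MaxPool2d(2) stages, giving 128 * 28 * 28 flattened features.
with torch.no_grad():
    dummy = torch.randn(1, 1, 224, 224)  # (batch, channels, height, width)
    logits = model(dummy)

print(logits.shape)  # torch.Size([1, 4]) -- one logit per tumor class
```

One caveat worth noting: because the rewrite moves layers out of the nn.Sequential wrapper and renames them, the state_dict keys change (e.g. model.0.weight becomes con1d.weight), so .pth checkpoints saved against the old definition will not load into the new class without key remapping.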