Codewithsalty commited on
Commit
e133eb4
·
verified ·
1 Parent(s): 86b180a

Update TumorModel.py

Browse files
Files changed (1) hide show
  1. TumorModel.py +14 -11
TumorModel.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch.nn as nn
2
 
3
- # 🎯 Tumor Type Classification Model
4
  class TumorClassification(nn.Module):
5
  def __init__(self):
6
  super(TumorClassification, self).__init__()
@@ -17,7 +17,7 @@ class TumorClassification(nn.Module):
17
  self.pool3 = nn.MaxPool2d(2)
18
 
19
  self.flatten = nn.Flatten()
20
- self.fc1 = nn.Linear(86528, 512) # fixed to match pth
21
  self.relu_fc = nn.ReLU()
22
  self.fc2 = nn.Linear(512, 256)
23
  self.relu_fc2 = nn.ReLU()
@@ -32,17 +32,20 @@ class TumorClassification(nn.Module):
32
  x = self.relu_fc2(self.fc2(x))
33
  return self.output(x)
34
 
35
- # 🧬 Glioma Stage Classifier Model
36
  class GliomaStageModel(nn.Module):
37
  def __init__(self):
38
  super(GliomaStageModel, self).__init__()
39
- self.model = nn.Sequential(
40
- nn.Linear(9, 128),
41
- nn.ReLU(),
42
- nn.Linear(128, 64),
43
- nn.ReLU(),
44
- nn.Linear(64, 4)
45
- )
46
 
47
  def forward(self, x):
48
- return self.model(x)
 
 
 
 
1
  import torch.nn as nn
2
 
3
+ # 🧠 Tumor Type Classification Model
4
  class TumorClassification(nn.Module):
5
  def __init__(self):
6
  super(TumorClassification, self).__init__()
 
17
  self.pool3 = nn.MaxPool2d(2)
18
 
19
  self.flatten = nn.Flatten()
20
+ self.fc1 = nn.Linear(86528, 512) # Adjust this number to match your original
21
  self.relu_fc = nn.ReLU()
22
  self.fc2 = nn.Linear(512, 256)
23
  self.relu_fc2 = nn.ReLU()
 
32
  x = self.relu_fc2(self.fc2(x))
33
  return self.output(x)
34
 
35
+ # 🧬 Glioma Stage Prediction Model (MATCHES `glioma_stages.pth`)
36
  class GliomaStageModel(nn.Module):
37
  def __init__(self):
38
  super(GliomaStageModel, self).__init__()
39
+ self.fc1 = nn.Linear(9, 128)
40
+ self.relu1 = nn.ReLU()
41
+ self.fc2 = nn.Linear(128, 64)
42
+ self.relu2 = nn.ReLU()
43
+ self.fc3 = nn.Linear(64, 32)
44
+ self.relu3 = nn.ReLU()
45
+ self.out = nn.Linear(32, 4)
46
 
47
  def forward(self, x):
48
+ x = self.relu1(self.fc1(x))
49
+ x = self.relu2(self.fc2(x))
50
+ x = self.relu3(self.fc3(x))
51
+ return self.out(x)