eaglelandsonce commited on
Commit
7ff3f27
·
verified ·
1 Parent(s): 2e0ae2e

Update pages/15_CNN.py

Browse files
Files changed (1) hide show
  1. pages/15_CNN.py +21 -10
pages/15_CNN.py CHANGED
@@ -1,4 +1,4 @@
1
- import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
@@ -47,20 +47,31 @@ class CNN(nn.Module):
47
  nn.BatchNorm2d(64),
48
  nn.ReLU(),
49
  nn.MaxPool2d(2))
50
- self.fc1 = nn.Linear(6*6*64, 600)
 
 
 
 
 
51
  self.drop = nn.Dropout2d(0.25)
52
  self.fc2 = nn.Linear(600, 100)
53
  self.fc3 = nn.Linear(100, 10)
54
 
 
 
 
 
 
 
 
55
  def forward(self, x):
56
- out = self.layer1(x)
57
- out = self.layer2(out)
58
- out = out.view(out.size(0), -1)
59
- out = self.fc1(out)
60
- out = self.drop(out)
61
- out = self.fc2(out)
62
- out = self.fc3(out)
63
- return out
64
 
65
  model = CNN().to(device)
66
 
 
1
+ ]import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
 
47
  nn.BatchNorm2d(64),
48
  nn.ReLU(),
49
  nn.MaxPool2d(2))
50
+
51
+ # Automatically determine the size of the flattened features after convolution and pooling
52
+ self._to_linear = None
53
+ self.convs(torch.randn(1, 3, 32, 32))
54
+
55
+ self.fc1 = nn.Linear(self._to_linear, 600)
56
  self.drop = nn.Dropout2d(0.25)
57
  self.fc2 = nn.Linear(600, 100)
58
  self.fc3 = nn.Linear(100, 10)
59
 
60
+ def convs(self, x):
61
+ x = self.layer1(x)
62
+ x = self.layer2(x)
63
+ if self._to_linear is None:
64
+ self._to_linear = x.view(x.size(0), -1).shape[1]
65
+ return x
66
+
67
  def forward(self, x):
68
+ x = self.convs(x)
69
+ x = x.view(x.size(0), -1)
70
+ x = self.fc1(x)
71
+ x = self.drop(x)
72
+ x = self.fc2(x)
73
+ x = self.fc3(x)
74
+ return x
 
75
 
76
  model = CNN().to(device)
77