alperugurcan commited on
Commit
d302107
·
verified ·
1 Parent(s): 1e860b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -8,15 +8,19 @@ class IcebergClassifier(nn.Module):
8
  def __init__(self):
9
  super().__init__()
10
  self.conv = nn.Sequential(
11
- nn.Conv2d(2,16,3), nn.ReLU(), nn.MaxPool2d(2),
12
- nn.Conv2d(16,32,3), nn.ReLU(), nn.MaxPool2d(2)
 
13
  )
14
  self.fc = nn.Sequential(
15
- nn.Linear(32*17*17,64), nn.ReLU(),
16
- nn.Linear(64,1), nn.Sigmoid()
 
 
17
  )
 
18
  def forward(self, x):
19
- return self.fc(self.conv(x).view(x.size(0),-1))
20
 
21
  @st.cache_resource
22
  def load_model():
 
8
  def __init__(self):
9
  super().__init__()
10
  self.conv = nn.Sequential(
11
+ nn.Conv2d(2, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
12
+ nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
13
+ nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
14
  )
15
  self.fc = nn.Sequential(
16
+ nn.Linear(64 * 9 * 9, 64), nn.ReLU(),
17
+ nn.Dropout(0.5),
18
+ nn.Linear(64, 1),
19
+ nn.Sigmoid()
20
  )
21
+
22
  def forward(self, x):
23
+ return self.fc(self.conv(x).view(x.size(0), -1))
24
 
25
  @st.cache_resource
26
  def load_model():