Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -56,8 +56,8 @@ class CNN(nn.Module):
|
|
56 |
super(CNN, self).__init__()
|
57 |
self.conv = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1)
|
58 |
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
59 |
-
#
|
60 |
-
self.fc = nn.Linear(16 * 4 *
|
61 |
|
62 |
def forward(self, x):
|
63 |
x = self.pool(torch.relu(self.conv(x)))
|
@@ -65,6 +65,7 @@ class CNN(nn.Module):
|
|
65 |
x = x.view(x.size(0), -1) # Flatten for the fully connected layer
|
66 |
return torch.sigmoid(self.fc(x))
|
67 |
|
|
|
68 |
class PhiModel(nn.Module):
|
69 |
def __init__(self, input_dim):
|
70 |
super(PhiModel, self).__init__()
|
|
|
56 |
super(CNN, self).__init__()
|
57 |
self.conv = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1)
|
58 |
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
59 |
+
# Adjust the fully connected layer to accommodate the correct input size
|
60 |
+
self.fc = nn.Linear(16 * 4 * 8, output_dim) # 16 * 4 * 8 = 512
|
61 |
|
62 |
def forward(self, x):
|
63 |
x = self.pool(torch.relu(self.conv(x)))
|
|
|
65 |
x = x.view(x.size(0), -1) # Flatten for the fully connected layer
|
66 |
return torch.sigmoid(self.fc(x))
|
67 |
|
68 |
+
|
69 |
class PhiModel(nn.Module):
|
70 |
def __init__(self, input_dim):
|
71 |
super(PhiModel, self).__init__()
|