Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import torch.nn.functional as F | |
class CNN(nn.Module):
    """
    Convolutional Neural Network (CNN) for classifying 'normal' and 'red' eye images.

    The network consists of four convolutional layers followed by two fully
    connected layers. Each convolutional layer is followed by batch
    normalization and a LeakyReLU activation function. A dropout layer is
    applied before the final fully connected layer to prevent overfitting.

    Every convolution uses kernel_size=4, stride=2, padding=1, which halves
    (with flooring) the spatial size. The number of input features of ``fc1``
    is therefore derived from ``input_size`` instead of being hard-coded; the
    defaults reproduce the original 32x32, 3-channel, 2-class configuration
    (``fc1`` takes 64 * 2 * 2 features).

    Args:
        in_channels (int): Number of channels of the input images. Default 3.
        num_classes (int): Number of output classes. Default 2.
        input_size (int): Height and width of the square input images.
            Default 32. Must be at least 16 so that four stride-2
            convolutions still leave a 1x1 feature map.

    Raises:
        ValueError: If ``input_size`` is too small for four stride-2
            convolutions.

    Attributes:
        conv1 (nn.Sequential): First convolutional layer block.
        conv2 (nn.Sequential): Second convolutional layer block.
        conv3 (nn.Sequential): Third convolutional layer block.
        conv4 (nn.Sequential): Fourth convolutional layer block.
        fc1 (nn.Linear): First fully connected layer.
        fc2 (nn.Linear): Second fully connected layer (output layer).
        dropout (nn.Dropout): Dropout layer with a probability of 0.5.
    """

    # Convolution geometry shared by all four conv blocks.
    _KERNEL_SIZE = 4
    _STRIDE = 2
    _PADDING = 1

    def __init__(self, in_channels=3, num_classes=2, input_size=32):
        super().__init__()
        self.conv1 = self._conv_block(in_channels, 8)
        self.conv2 = self._conv_block(8, 16)
        self.conv3 = self._conv_block(16, 32)
        self.conv4 = self._conv_block(32, 64)

        spatial = self._output_size(input_size, num_layers=4)
        if spatial < 1:
            raise ValueError(
                f"input_size={input_size} is too small: four stride-2 "
                f"convolutions require input_size >= 16"
            )
        self.fc1 = nn.Linear(64 * spatial * spatial, 32)
        self.fc2 = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(0.5)

    @classmethod
    def _conv_block(cls, in_ch, out_ch):
        """Build one Conv2d -> BatchNorm2d -> LeakyReLU(0.2) block.

        The Sequential indices (0: conv, 1: batch norm, 2: activation) match
        the original hand-written blocks, so state_dict keys are unchanged.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_ch,
                out_ch,
                cls._KERNEL_SIZE,
                stride=cls._STRIDE,
                padding=cls._PADDING,
            ),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    @classmethod
    def _output_size(cls, size, num_layers):
        """Return the spatial size after ``num_layers`` conv blocks.

        Applies the Conv2d output-size formula
        ``(size + 2 * padding - kernel_size) // stride + 1`` once per layer.
        """
        for _ in range(num_layers):
            size = (size + 2 * cls._PADDING - cls._KERNEL_SIZE) // cls._STRIDE + 1
        return size

    def forward(self, x):
        """
        Defines the forward pass of the CNN.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, in_channels, input_size, input_size);
                (batch_size, 3, 32, 32) with the default configuration.

        Returns:
            torch.Tensor: Log-probabilities (``log_softmax`` over the class
            dimension) of shape (batch_size, num_classes).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # flatten() copies when needed (e.g. channels_last inputs) where
        # view() would raise; for contiguous tensors the result is identical.
        x = x.flatten(start_dim=1)
        # NOTE: the conv blocks use slope 0.2, but fc1 has always used
        # F.leaky_relu's default slope of 0.01. It is kept (made explicit) so
        # existing checkpoints produce the same outputs.
        x = F.leaky_relu(self.fc1(x), negative_slope=0.01)
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)