# Author: nssharmaofficial
# Commit 10385fa: "Remove unused lines"
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """
    Convolutional Neural Network (CNN) for classifying 'normal' and 'red' eye images.

    The network consists of four convolutional blocks followed by two fully
    connected layers. Each convolutional block is a 4x4 convolution with
    stride 2 and padding 1 (which halves each spatial side, rounding down),
    followed by batch normalization and a LeakyReLU(0.2) activation.
    A dropout layer is applied before the final fully connected layer to
    reduce overfitting.

    Args:
        input_size (int): Side length of the square input images. Defaults to
            32, which yields the original ``64 * 2 * 2`` input width of ``fc1``.
        num_classes (int): Number of output classes. Defaults to 2.

    Raises:
        ValueError: If ``input_size`` is smaller than 16, in which case the
            four stride-2 blocks would shrink the feature map to nothing.

    Attributes:
        conv1 (nn.Sequential): First convolutional block (3 -> 8 channels).
        conv2 (nn.Sequential): Second convolutional block (8 -> 16 channels).
        conv3 (nn.Sequential): Third convolutional block (16 -> 32 channels).
        conv4 (nn.Sequential): Fourth convolutional block (32 -> 64 channels).
        fc1 (nn.Linear): First fully connected layer (flattened features -> 32).
        fc2 (nn.Linear): Second fully connected layer (output layer, 32 -> num_classes).
        dropout (nn.Dropout): Dropout layer with a probability of 0.5.
    """

    def __init__(self, *, input_size: int = 32, num_classes: int = 2):
        super().__init__()
        # Each block maps a side of length s to (s + 2*1 - 4) // 2 + 1 == s // 2,
        # so after four blocks the feature map side is input_size // 16.
        spatial = input_size // 16
        if spatial < 1:
            raise ValueError(f"input_size must be at least 16, got {input_size}")

        # Modules are created in the same order as before so that seeded
        # weight initialisation and state_dict keys stay unchanged.
        self.conv1 = self._conv_block(3, 8)
        self.conv2 = self._conv_block(8, 16)
        self.conv3 = self._conv_block(16, 32)
        self.conv4 = self._conv_block(32, 64)
        self.fc1 = nn.Linear(64 * spatial * spatial, 32)
        self.fc2 = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one Conv(4x4, stride 2, pad 1) -> BatchNorm -> LeakyReLU(0.2) block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """
        Defines the forward pass of the CNN.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, 3, input_size, input_size), i.e.
                (batch_size, 3, 32, 32) with the default configuration.

        Returns:
            torch.Tensor: Log-probabilities of shape (batch_size, num_classes),
            computed with ``log_softmax`` over the class dimension.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # flatten() instead of view(): view() raises on non-contiguous feature
        # maps (e.g. a model converted to channels_last), while flatten()
        # yields the same values and copies only when it has to.
        x = x.flatten(1)
        # NOTE(review): F.leaky_relu uses its default slope (0.01) here, not the
        # 0.2 used in the conv blocks. Kept as-is to preserve behaviour.
        x = F.leaky_relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)