# models/cnn_model.py
import torch.nn as nn


class MonkeyCNN(nn.Module):
    """Four-block convolutional image classifier.

    Blocks 1-3 are Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    widening channels in_channels -> 32 -> 64 -> 128. Block 4 is
    Conv2d(128 -> 256) -> BatchNorm2d -> ReLU followed by a global
    AdaptiveAvgPool2d((1, 1)), so the head always sees a 256-dim feature
    vector regardless of input height/width. The head is
    Flatten -> Dropout -> Linear(256, num_classes).

    The three 2x max-pools require input height and width of at least 8
    pixels; smaller inputs shrink to zero size and raise an error.

    Args:
        num_classes: Number of output units of the final Linear layer.
        in_channels: Channels of the input image (default 3, as before).
        dropout: Dropout probability applied before the classifier
            (default 0.3, as before).
    """

    def __init__(self, num_classes, *, in_channels=3, dropout=0.3):
        super().__init__()
        # Layer order (and thus the Sequential indices net.0 ... net.18) is
        # kept exactly as in the original so existing state_dicts still load.
        # NOTE(review): each Conv2d keeps its default bias even though the
        # following BatchNorm2d makes it redundant; removing it would change
        # the state_dict keys, so it is left in place.
        self.net = nn.Sequential(
            # Conv Block 1
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Conv Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Conv Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Conv Block 4: global average pooling makes the head independent
            # of the input spatial size. Output size: [B, 256, 1, 1]
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            # Classifier head: [B, 256, 1, 1] -> [B, 256] -> [B, num_classes]
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class scores for a batch of images.

        Args:
            x: Tensor of shape [B, in_channels, H, W] with H, W >= 8.

        Returns:
            Tensor of shape [B, num_classes]; no softmax is applied.
        """
        return self.net(x)