import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_spice_cnn import SpiceCNNConfig


class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Small convolutional image classifier packaged as a ``PreTrainedModel``.

    The whole network lives in one ``nn.Sequential`` stored as ``self.model``:

    * stage 1: three unpadded convolutions to 32 channels (3x3, 3x3, then a
      stride-2 5x5), each followed by ReLU, with BatchNorm after the first
      two, then ``Dropout(0.4)``;
    * stage 2: the same pattern to 64 channels, with BatchNorm after all
      three convolutions, then ``Dropout(0.4)``;
    * head: ``Flatten`` -> ``Linear(64*28*28, 128)`` -> ``BatchNorm1d`` ->
      ``Dropout(0.4)`` -> ``Linear(128, config.num_classes)``.

    NOTE(review): the head's input width is fixed at ``64 * 28 * 28``. Working
    through the unpadded conv arithmetic, that only matches inputs whose
    height and width are each between 133 and 136 pixels. Confirm the
    intended image size.

    NOTE(review): there is no activation between the two ``nn.Linear`` layers,
    and stage 1 lacks the post-stride BatchNorm that stage 2 has. Fixing either
    would renumber the ``Sequential`` indices (``model.<i>.*`` state-dict keys)
    and break existing checkpoints, so the layout is intentionally left as is.
    """

    # Configuration type associated with this model class.
    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        """Build the network from ``config.in_channels`` and ``config.num_classes``."""
        super().__init__(config)

        def conv_relu(in_ch, out_ch, kernel, stride=1, norm=True):
            """Return ``[Conv2d, ReLU]``, plus a trailing ``BatchNorm2d`` when *norm*."""
            unit = [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride),
                nn.ReLU(),
            ]
            if norm:
                unit.append(nn.BatchNorm2d(out_ch))
            return unit

        # Modules are created strictly in order: the Sequential indices (and
        # therefore checkpoint keys) depend on it.
        feature_extractor = [
            # Stage 1: 32 channels; no BatchNorm after the stride-2 conv.
            *conv_relu(config.in_channels, 32, 3),
            *conv_relu(32, 32, 3),
            *conv_relu(32, 32, 5, stride=2, norm=False),
            nn.Dropout(0.4),
            # Stage 2: 64 channels.
            *conv_relu(32, 64, 3),
            *conv_relu(64, 64, 3),
            *conv_relu(64, 64, 5, stride=2),
            nn.Dropout(0.4),
        ]
        classifier_head = [
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, 128),
            nn.BatchNorm1d(128),
            nn.Dropout(0.4),
            nn.Linear(128, config.num_classes),
        ]
        self.model = nn.Sequential(*feature_extractor, *classifier_head)

    def forward(self, tensor, labels=None):
        """Compute class logits and, when targets are given, the loss.

        Args:
            tensor: Image batch of shape ``(N, in_channels, H, W)``. The
                ``BatchNorm2d`` layers reject inputs of any other rank.
            labels: Optional targets in any form ``nn.CrossEntropyLoss``
                accepts (class indices or class probabilities).

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``; otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        criterion = nn.CrossEntropyLoss()
        return {"loss": criterion(logits, labels), "logits": logits}