|
import torch.nn as nn |
|
|
|
from transformers import PreTrainedModel |
|
|
|
from .configuration_spice_cnn import SpiceCNNConfig |
|
|
|
|
|
class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Plain convolutional image classifier exposed as a ``PreTrainedModel``.

    The network is a single flat ``nn.Sequential`` stored in ``self.model``:
    two stages of unpadded convolutions (32 channels, then 64 channels), each
    ending in a stride-2 5x5 convolution, followed by a flatten and a
    two-layer fully connected head that emits ``config.num_classes`` logits.

    The first linear layer is sized for ``64 * 28 * 28`` input features, so
    the convolutional stages must produce a 64 x 28 x 28 feature map.
    NOTE(review): with the default (zero) padding used here that appears to
    correspond to square inputs of about 133-136 pixels per side -- confirm
    against the preprocessing pipeline.
    """

    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        """Build the layer stack from ``config``.

        Args:
            config: Supplies ``in_channels`` (channels of the input image) and
                ``num_classes`` (width of the final logits layer).
        """
        super().__init__(config)

        # Stage 1: 32-channel convolutions; the stride-2 conv downsamples.
        # This stage has no batch norm between its last ReLU and the dropout.
        narrow_stage = [
            nn.Conv2d(config.in_channels, 32, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Dropout(0.4),
        ]

        # Stage 2: same pattern at 64 channels, with batch norm before dropout.
        wide_stage = [
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(0.4),
        ]

        # Classification head: flatten the feature map and project to logits.
        classifier_head = [
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, 128),
            nn.BatchNorm1d(128),
            nn.Dropout(0.4),
            nn.Linear(128, config.num_classes),
        ]

        # Keep one flat Sequential so sub-module indices (and therefore
        # parameter names in the state dict) follow the layer order above.
        self.model = nn.Sequential(*narrow_stage, *wide_stage, *classifier_head)

    def forward(self, tensor, labels=None):
        """Run the network and optionally compute the classification loss.

        Args:
            tensor: Input batch fed straight into ``self.model``
                (presumably ``(batch, in_channels, height, width)`` -- confirm
                with the caller).
            labels: Optional targets passed to ``nn.CrossEntropyLoss``. When
                ``None`` no loss is computed.

        Returns:
            A dict with key ``"logits"``; when ``labels`` is given the dict
            also carries ``"loss"`` (listed first).
        """
        logits = self.model(tensor)

        # Inference path: nothing to compare against, return logits only.
        if labels is None:
            return {"logits": logits}

        criterion = nn.CrossEntropyLoss()
        return {"loss": criterion(logits, labels), "logits": logits}
|
|