File size: 1,363 Bytes
c1dc7e0
 
 
 
 
 
 
 
 
 
 
 
 
 
92f435d
c1dc7e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d8bbad
 
c1dc7e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch.nn as nn

from transformers import PreTrainedModel

from .configuration_spice_cnn import SpiceCNNConfig


class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Small two-stage convolutional image classifier.

    The network is a single flat ``nn.Sequential`` stored in ``self.model``:
    two ``Conv2d -> ReLU -> MaxPool2d`` stages (3 -> 32 -> 64 channels)
    followed by ``Flatten`` and a two-layer fully connected head that emits
    ``config.num_classes`` logits.

    The first fully connected layer is fixed at ``7 * 7 * 64`` input
    features, so the configured kernel/stride/padding/pooling values and the
    input resolution must together yield a 64 x 7 x 7 feature map.
    NOTE(review): this presumably targets 28x28 inputs with size-preserving
    convolutions and a pooling size of 2 -- confirm against the config
    defaults.
    """

    # Configuration class associated with this model.
    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        """Build the network from ``config``.

        Args:
            config: Supplies ``kernel_size``, ``stride`` and ``padding`` for
                both convolutions, ``pooling_size`` for both max-pool
                layers, and ``num_classes`` for the output layer.
        """
        super().__init__(config)

        # Both convolutions share the same geometry settings.
        conv_kwargs = {
            "kernel_size": config.kernel_size,
            "stride": config.stride,
            "padding": config.padding,
        }
        pool_size = config.pooling_size

        # Keep this a single flat Sequential: each layer's position in it
        # determines its parameter names (e.g. "model.0.weight").
        self.model = nn.Sequential(
            # Stage 1: 3 -> 32 channels.
            nn.Conv2d(3, 32, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_size),
            # Stage 2: 32 -> 64 channels.
            nn.Conv2d(32, 64, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool_size),
            # Classification head.
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 128),
            nn.ReLU(),
            nn.Linear(128, config.num_classes),
        )

    def forward(self, tensor, labels=None):
        """Run the network and optionally compute a cross-entropy loss.

        Args:
            tensor: Input batch fed to the network. The first convolution
                expects 3 input channels; assumed to be shaped
                (batch, 3, height, width) -- TODO confirm with callers.
            labels: Optional targets passed to ``nn.CrossEntropyLoss``
                together with the logits (presumably class indices --
                verify against the caller).

        Returns:
            A dict with key ``"logits"``; when ``labels`` is given the dict
            also carries ``"loss"`` (listed first).
        """
        logits = self.model(tensor)

        # Inference path: nothing to score against.
        if labels is None:
            return {"logits": logits}

        criterion = nn.CrossEntropyLoss()
        return {"loss": criterion(logits, labels), "logits": logits}