import torch.nn as nn

from transformers import PreTrainedModel

from .configuration_spice_cnn import SpiceCNNConfig


class SpiceCNNModelForImageClassification(PreTrainedModel):
    """A small convolutional image classifier wrapped as a Hugging Face `PreTrainedModel`."""

    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        super().__init__(config)
        layers = [
            # Block 1: two 3x3 convs, then a strided 5x5 conv for downsampling.
            nn.Conv2d(config.in_channels, 32, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Dropout(0.4),

            # Block 2: same pattern with 64 channels.
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(0.4),

            # Classifier head. The 64 * 28 * 28 input size assumes the feature map
            # is 28x28 when it reaches the flatten, i.e. a fixed input resolution.
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, 128),
            nn.BatchNorm1d(128),
            nn.Dropout(0.4),
            nn.Linear(128, config.num_classes),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)
        if labels is not None:
            # When labels are provided, also return the cross-entropy loss so the
            # model plugs into the `Trainer` API.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}