rhendz commited on
Commit
23be06c
·
1 Parent(s): 9d8457e

Upload modeling_spice_cnn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_spice_cnn.py +49 -0
modeling_spice_cnn.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ # from torchsummary import summary
4
+
5
+ from transformers import PreTrainedModel
6
+
7
+ from .configuration_spice_cnn import SpiceCNNConfig
8
+
9
+
10
+ class SpiceCNNModelForImageClassification(PreTrainedModel):
11
+ config_class = SpiceCNNConfig
12
+
13
+ def __init__(self, config: SpiceCNNConfig):
14
+ super().__init__(config)
15
+ layers = [
16
+ nn.Conv2d(
17
+ config.in_channels, 16, kernel_size=config.kernel_size, padding=1
18
+ ),
19
+ nn.BatchNorm2d(16),
20
+ nn.ReLU(),
21
+ nn.MaxPool2d(kernel_size=config.pooling_size),
22
+ nn.Conv2d(16, 32, kernel_size=config.kernel_size, padding=1),
23
+ nn.BatchNorm2d(32),
24
+ nn.ReLU(),
25
+ nn.MaxPool2d(kernel_size=config.pooling_size),
26
+ nn.Conv2d(32, 64, kernel_size=config.kernel_size, padding=1),
27
+ nn.BatchNorm2d(64),
28
+ nn.ReLU(),
29
+ nn.MaxPool2d(kernel_size=config.pooling_size),
30
+ nn.Flatten(),
31
+ nn.Linear(64 * 3 * 3, 128),
32
+ nn.ReLU(),
33
+ nn.Dropout(0.5),
34
+ nn.Linear(128, config.num_classes),
35
+ ]
36
+ self.model = nn.Sequential(*layers)
37
+
38
+ def forward(self, tensor, labels=None):
39
+ logits = self.model(tensor)
40
+ if labels is not None:
41
+ loss_fnc = nn.CrossEntropyLoss()
42
+ loss = loss_fnc(logits, labels)
43
+ return {"loss": loss, "logits": logits}
44
+ return {"logits": logits}
45
+
46
+
47
+ # config = SpiceCNNConfig(in_channels=1)
48
+ # cnn = SpiceCNNModelForImageClassification(config)
49
+ # summary(cnn, (1,28,28))