rhendz commited on
Commit
d08594a
·
1 Parent(s): 893e284

Upload model

Browse files
Files changed (2) hide show
  1. modeling_spice_cnn.py +13 -24
  2. pytorch_model.bin +2 -2
modeling_spice_cnn.py CHANGED
@@ -11,36 +11,25 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
11
  def __init__(self, config: SpiceCNNConfig):
12
  super().__init__(config)
13
  layers = [
14
- nn.Conv2d(
15
- config.in_channels,
16
- 32,
17
- kernel_size=3,
18
- ),
19
- nn.ReLU(),
20
- nn.BatchNorm2d(32),
21
- nn.Conv2d(32, 32, kernel_size=3),
22
  nn.ReLU(),
 
 
 
23
  nn.BatchNorm2d(32),
24
- nn.Conv2d(32, 32, kernel_size=5, stride=2),
25
- nn.ReLU(),
26
- nn.Dropout(0.4),
27
-
28
- nn.Conv2d(32, 64, kernel_size=3),
29
- nn.ReLU(),
30
- nn.BatchNorm2d(64),
31
- nn.Conv2d(64, 64, kernel_size=3),
32
  nn.ReLU(),
 
 
 
33
  nn.BatchNorm2d(64),
34
- nn.Conv2d(64, 64, kernel_size=5, stride=2),
35
  nn.ReLU(),
36
- nn.BatchNorm2d(64),
37
- nn.Dropout(0.4),
38
-
39
- nn.Flatten(),
40
- nn.BatchNorm1d(64),
41
  nn.ReLU(),
42
- nn.Dropout(0.4),
43
- nn.Linear(64, config.num_classes),
44
  ]
45
  self.model = nn.Sequential(*layers)
46
 
 
11
  def __init__(self, config: SpiceCNNConfig):
12
  super().__init__(config)
13
  layers = [
14
+ nn.Conv2d(config.in_channels, 16, kernel_size=config.kernel_size, padding=1),
15
+ nn.BatchNorm2d(16),
 
 
 
 
 
 
16
  nn.ReLU(),
17
+ nn.MaxPool2d(kernel_size=config.pooling_size),
18
+
19
+ nn.Conv2d(16, 32, kernel_size=config.kernel_size, padding=1),
20
  nn.BatchNorm2d(32),
 
 
 
 
 
 
 
 
21
  nn.ReLU(),
22
+ nn.MaxPool2d(kernel_size=config.pooling_size),
23
+
24
+ nn.Conv2d(32, 64, kernel_size=config.kernel_size, padding=1),
25
  nn.BatchNorm2d(64),
 
26
  nn.ReLU(),
27
+ nn.MaxPool2d(kernel_size=2),
28
+
29
+ nn.Linear(64 * 4 * 4, 128),
 
 
30
  nn.ReLU(),
31
+ nn.Dropout(0.5),
32
+ nn.Linear(128, config.num_classes)
33
  ]
34
  self.model = nn.Sequential(*layers)
35
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:46f937126b04976c9954c667a252be6c1ae0e7283a81c620e8ed53d50d7dcded
3
- size 792757
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93eef51c610e5fe5b2e79ceb111c0a2b090f38167f859a2394a1cd65af2bef78
3
+ size 632188