rhendz commited on
Commit
cb3238b
·
1 Parent(s): b21a5a9

Upload model

Browse files
Files changed (2) hide show
  1. modeling_spice_cnn.py +14 -4
  2. pytorch_model.bin +2 -2
modeling_spice_cnn.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch.nn as nn
 
2
 
3
  from transformers import PreTrainedModel
4
 
5
- from .configuration_spice_cnn import SpiceCNNConfig
6
-
7
 
8
  class SpiceCNNModelForImageClassification(PreTrainedModel):
9
  config_class = SpiceCNNConfig
@@ -21,7 +21,13 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
21
  nn.ReLU(),
22
  nn.MaxPool2d(kernel_size=config.pooling_size),
23
 
24
- nn.Linear(7*7*32, 128),
 
 
 
 
 
 
25
  nn.ReLU(),
26
  nn.Dropout(0.5),
27
  nn.Linear(128, config.num_classes)
@@ -34,4 +40,8 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
34
  loss_fnc = nn.CrossEntropyLoss()
35
  loss = loss_fnc(logits, labels)
36
  return {"loss": loss, "logits": logits}
37
- return {"logits": logits}
 
 
 
 
 
1
  import torch.nn as nn
2
+ # from torchsummary import summary
3
 
4
  from transformers import PreTrainedModel
5
 
6
+ from hf_models.models.spice_cnn.configuration_spice_cnn import SpiceCNNConfig
 
7
 
8
  class SpiceCNNModelForImageClassification(PreTrainedModel):
9
  config_class = SpiceCNNConfig
 
21
  nn.ReLU(),
22
  nn.MaxPool2d(kernel_size=config.pooling_size),
23
 
24
+ nn.Conv2d(32, 64, kernel_size=config.kernel_size, padding=1),
25
+ nn.BatchNorm2d(64),
26
+ nn.ReLU(),
27
+ nn.MaxPool2d(kernel_size=config.pooling_size),
28
+
29
+ nn.Flatten(),
30
+ nn.Linear(64*3*3, 128),
31
  nn.ReLU(),
32
  nn.Dropout(0.5),
33
  nn.Linear(128, config.num_classes)
 
40
  loss_fnc = nn.CrossEntropyLoss()
41
  loss = loss_fnc(logits, labels)
42
  return {"loss": loss, "logits": logits}
43
+ return {"logits": logits}
44
+
45
+ # config = SpiceCNNConfig(in_channels=1)
46
+ # cnn = SpiceCNNModelForImageClassification(config)
47
+ # summary(cnn, (1,28,28))
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1d40410140194428e35927afeb15389e18b894f35bd109991e281597b3938623
3
- size 833767
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:548d7b0aa1d5b69f63f128dfa0d9d343bd9ca053a29c2dae2114d5b6f2d7b5c6
3
+ size 402812