rhendz commited on
Commit
ea69011
·
1 Parent(s): 8121e82

Upload configuration_spice_cnn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_spice_cnn.py +49 -0
configuration_spice_cnn.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ """Spice CNN model configuration"""
4
+
5
+ SPICE_CNN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
6
+ "spicecloud/spice-cnn-base": "https://huggingface.co/spice-cnn-base/resolve/main/config.json"
7
+ }
8
+
9
+
10
+ # Define custom convnet configuration
11
+ class SpiceCNNConfig(PretrainedConfig):
12
+ """
13
+ This is the configuration class to store the configuration of a [`SpiceCNNModel`].
14
+ It is used to instantiate an SpiceCNN model according to the specified arguments,
15
+ defining the model architecture. Instantiating a configuration with the defaults
16
+ will yield a similar configuration to that of the SpiceCNN
17
+ [spicecloud/spice-cnn-base](https://huggingface.co/spicecloud/spice-cnn-base)
18
+ architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control
21
+ the model outputs. Read the documentation from [`PretrainedConfig`] for more
22
+ information.
23
+ """
24
+
25
+ model_type = "spicecnn"
26
+
27
+ def __init__(
28
+ self,
29
+ in_channels: int = 3,
30
+ num_classes: int = 10,
31
+ dropout_rate: float = 0.4,
32
+ hidden_size: int = 128,
33
+ num_filters: int = 16,
34
+ kernel_size: int = 3,
35
+ stride: int = 1,
36
+ padding: int = 1,
37
+ pooling_size: int = 2,
38
+ **kwargs
39
+ ):
40
+ super().__init__(**kwargs)
41
+ self.in_channels = in_channels
42
+ self.num_classes = num_classes
43
+ self.dropout_rate = dropout_rate
44
+ self.hidden_size = hidden_size
45
+ self.num_filters = num_filters
46
+ self.kernel_size = kernel_size
47
+ self.stride = stride
48
+ self.padding = padding
49
+ self.pooling_size = pooling_size