Commit
·
233e6e6
1
Parent(s):
7b91b13
new version
Browse files- config.json +16 -0
- config.py +14 -0
- model.py +17 -0
- model.safetensors +3 -0
config.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"MNIST_Classifier"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "config.Config",
|
7 |
+
"AutoModel": "model.MNIST_Classifier"
|
8 |
+
},
|
9 |
+
"hidden_size1": 128,
|
10 |
+
"hidden_size2": 64,
|
11 |
+
"input_size": 784,
|
12 |
+
"model_type": "MNIST_Classifier",
|
13 |
+
"output_size": 10,
|
14 |
+
"torch_dtype": "float32",
|
15 |
+
"transformers_version": "4.35.2"
|
16 |
+
}
|
config.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
class Config(PretrainedConfig):
|
3 |
+
model_type = "MNIST_Classifier"
|
4 |
+
def __init__(self, **kwargs):
|
5 |
+
super().__init__(**kwargs)
|
6 |
+
|
7 |
+
for key,value in kwargs.items():
|
8 |
+
setattr(self,key,value)
|
9 |
+
#print(key,value)
|
10 |
+
#self.input_size=kwargs['input_size']
|
11 |
+
#self.hidden_size1=kwargs["hidden_size1"]
|
12 |
+
#self.hidden_size2=kwargs["hidden_size2"]
|
13 |
+
#self.output_size=kwargs["output_size"]
|
14 |
+
|
model.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###import os,sys
|
2 |
+
###sys.path.insert(1,os.path.join(sys.path[0],".."))
|
3 |
+
from network import Net
|
4 |
+
from config import Config
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
# utils not used but importing it forces the upload to huggingface hub to include it
|
7 |
+
|
8 |
+
|
9 |
+
class MNIST_Classifier(PreTrainedModel):
|
10 |
+
config_class = Config
|
11 |
+
def __init__(self, config):
|
12 |
+
super().__init__(config)
|
13 |
+
self.classifier=Net(config.input_size,config.hidden_size1,config.hidden_size2,
|
14 |
+
config.output_size)
|
15 |
+
|
16 |
+
def forward(self, input):
|
17 |
+
return self.classifier(input)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:be564e6f102c09fc41e0f6dc4ed302dab8828457c67964eb698091c8795580a0
|
3 |
+
size 438104
|