File size: 580 Bytes
233e6e6
 
54159da
 
233e6e6
 
 
 
0bd2cee
233e6e6
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
###import os,sys
###sys.path.insert(1,os.path.join(sys.path[0],".."))
from .network import Net
from .config import MNIST_config
from transformers import PreTrainedModel
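# utils not used but importing it forces the upload to huggingface hub to include it
# NOTE(review): no `utils` import follows this comment — either restore that import or drop this note.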

class MNIST_Classifier(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around the project's ``Net``.

    The wrapped network is built from layer sizes read off an
    ``MNIST_config`` and stored as ``self.classifier``. ``forward`` simply
    delegates to it.
    """

    # Configuration class that transformers pairs with this model.
    config_class = MNIST_config

    def __init__(self, config):
        """Create the wrapped ``Net`` from the sizes held on *config*.

        Args:
            config: configuration object whose ``input_size``,
                ``hidden_size1``, ``hidden_size2`` and ``output_size``
                attributes are passed, in that order, to ``Net``.
        """
        super().__init__(config)
        layer_sizes = (
            config.input_size,
            config.hidden_size1,
            config.hidden_size2,
            config.output_size,
        )
        # The attribute name "classifier" is the prefix of this submodule's
        # parameter names, so renaming it would break loading of previously
        # saved weights (assuming Net is a torch nn.Module — TODO confirm).
        self.classifier = Net(*layer_sizes)

    def forward(self, input):
        """Feed *input* through the wrapped ``Net`` and return its result as-is."""
        # NOTE(review): `input` shadows the builtin; kept so keyword callers still work.
        output = self.classifier(input)
        return output