File size: 580 Bytes
233e6e6 54159da 233e6e6 0bd2cee 233e6e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
###import os,sys
###sys.path.insert(1,os.path.join(sys.path[0],".."))
from .network import Net
from .config import MNIST_config
from transformers import PreTrainedModel
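# `utils` is not used directly, but importing it forces the upload to the Hugging Face Hub to include it.
# NOTE(review): no `utils` import is present in this file; restore it (e.g. `from . import utils`)
# if the Hub upload still needs that module, otherwise drop this comment.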
class MNIST_Classifier(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the local ``Net`` module.

    Subclassing ``PreTrainedModel`` and pointing ``config_class`` at
    ``MNIST_config`` lets the model go through the ``transformers`` config /
    save / load machinery. All computation is delegated to ``Net``.
    """

    # Config class that transformers associates with this model.
    config_class = MNIST_config

    def __init__(self, config):
        """Build the wrapped ``Net`` from the layer sizes stored in *config*.

        Args:
            config: an ``MNIST_config`` instance. It must provide
                ``input_size``, ``hidden_size1``, ``hidden_size2`` and
                ``output_size``.
        """
        super().__init__(config)
        # Collect the sizes in the exact positional order Net's
        # constructor receives them.
        layer_sizes = (
            config.input_size,
            config.hidden_size1,
            config.hidden_size2,
            config.output_size,
        )
        self.classifier = Net(*layer_sizes)

    def forward(self, input):
        """Return the output of the wrapped ``Net`` applied to *input*.

        NOTE(review): the expected shape/dtype of *input* is whatever ``Net``
        accepts; it is not visible here. TODO confirm against network.py.
        """
        # The parameter is named ``input`` (shadowing the builtin) because
        # callers may pass it by keyword; renaming it would break them.
        net = self.classifier
        return net(input)
|