MNIST_Classifier / model.py
hikmatfarhat's picture
new version
54159da
raw
history blame contribute delete
580 Bytes
###import os,sys
###sys.path.insert(1,os.path.join(sys.path[0],".."))
from .network import Net
from .config import MNIST_config
from transformers import PreTrainedModel
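# `utils` is not used directly, but importing it forces the upload to the Hugging Face Hub to include it.
# NOTE(review): no `utils` import is present in this file — confirm whether it was removed intentionally.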
class MNIST_Classifier(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``Net``.

    All computation is delegated to a single ``Net`` instance stored as
    ``self.classifier``; this class only builds it from the config and
    forwards inputs to it.
    """

    # Configuration class associated with this model (used by
    # ``from_pretrained`` to build the config).
    config_class = MNIST_config

    def __init__(self, config):
        """Build the wrapped ``Net`` from the sizes stored on ``config``.

        Args:
            config: an ``MNIST_config`` instance; must provide ``input_size``,
                ``hidden_size1``, ``hidden_size2`` and ``output_size``.
        """
        super().__init__(config)
        # Passed positionally, in this exact order, to Net — presumably
        # input width, two hidden widths, and output width (TODO confirm
        # against Net's signature).
        layer_sizes = (
            config.input_size,
            config.hidden_size1,
            config.hidden_size2,
            config.output_size,
        )
        self.classifier = Net(*layer_sizes)

    def forward(self, input):
        """Run ``input`` through the wrapped ``Net`` and return its output unchanged.

        The expected shape/dtype of ``input`` is whatever ``Net`` accepts —
        presumably a batch of MNIST images (TODO confirm).
        """
        net = self.classifier
        return net(input)