# Source snapshot: 403 bytes, revision 8b17fbc.
from transformers import PreTrainedModel
import torch.nn as nn
from .configuration_simple_model import SimpleNNConfig
# Define the model class
class SimpleNN(PreTrainedModel):
    """Single-layer linear classifier packaged as a Hugging Face ``PreTrainedModel``.

    Subclassing ``PreTrainedModel`` and setting ``config_class`` ties this model
    to ``SimpleNNConfig`` for the transformers configuration/save/load machinery.
    """

    # Configuration class transformers associates with this model type.
    config_class = SimpleNNConfig

    def __init__(self, config: SimpleNNConfig):
        """Build the model from *config*.

        Args:
            config: Configuration object; its ``input_size`` and ``num_classes``
                attributes set the input and output widths of the linear layer.
        """
        super().__init__(config)
        # The attribute name ``dense`` determines the parameter keys
        # (``dense.weight`` / ``dense.bias``) in saved checkpoints; keep it stable.
        self.dense = nn.Linear(
            in_features=config.input_size,
            out_features=config.num_classes,
        )

    def forward(self, x):
        """Apply the linear layer to *x* and return the raw scores.

        Args:
            x: Input tensor whose last dimension must equal ``config.input_size``
                (as required by ``nn.Linear``).

        Returns:
            Tensor with the same leading dimensions as *x* and a last dimension
            of ``config.num_classes``; no activation or softmax is applied.
        """
        logits = self.dense(x)
        return logits