File size: 403 Bytes
8b17fbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from transformers import PreTrainedModel
import torch.nn as nn
from .configuration_simple_model import SimpleNNConfig


# Define the model class
class SimpleNN(PreTrainedModel):
    """Minimal model: a single linear projection from input features to class scores.

    Layer sizes are read from the config: ``config.input_size`` input features
    and ``config.num_classes`` outputs.
    """

    # Config type paired with this model; PreTrainedModel uses it, e.g. when
    # building the config during ``from_pretrained``.
    config_class = SimpleNNConfig

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: A ``SimpleNNConfig``; its ``input_size`` and
                ``num_classes`` attributes set the linear layer's dimensions.
        """
        super(SimpleNN, self).__init__(config)
        in_dim = config.input_size
        out_dim = config.num_classes
        # The attribute name ``dense`` becomes the parameter prefix in the
        # state dict (``dense.weight`` / ``dense.bias``); renaming it would
        # break loading previously saved weights.
        self.dense = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        """Return the linear layer's output for ``x``.

        Args:
            x: Input tensor. ``nn.Linear`` acts on the last dimension, which
                presumably has size ``config.input_size`` (not checked here).

        Returns:
            The raw scores from ``self.dense``; no activation or softmax is
            applied.
        """
        logits = self.dense(x)
        return logits