File size: 630 Bytes
2c55a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# custom_model.py
from transformers import PreTrainedModel, PretrainedConfig
import torch
import torch.nn as nn

class CustomModelConfig(PretrainedConfig):
    """Configuration object consumed by ``CustomModel``.

    Attributes:
        hidden_size: Feature width; ``CustomModel`` uses it as both the
            input and output size of its linear layer.
    """

    # Identifier string stored on the config class.
    model_type = "custom-model"

    def __init__(self, hidden_size: int = 128, **kwargs) -> None:
        """Initialise the base config, then record ``hidden_size``.

        Args:
            hidden_size: Feature width of the model (defaults to 128).
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        # The base initialiser runs first, exactly as in the original
        # ordering, so the explicit attribute is assigned afterwards.
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

class CustomModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass wrapping one square linear layer."""

    # Config type paired with this model class.
    config_class = CustomModelConfig

    def __init__(self, config: CustomModelConfig) -> None:
        """Create the linear layer sized by ``config.hidden_size``.

        Args:
            config: Configuration providing ``hidden_size``.
        """
        super().__init__(config)
        width = config.hidden_size
        # Square projection: input and output widths are the same.
        self.linear = nn.Linear(width, width)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``input_ids`` and return the result.

        NOTE(review): despite its name, the argument is passed straight
        into ``nn.Linear``, so it presumably has to be a floating-point
        tensor whose last dimension equals ``config.hidden_size`` rather
        than integer token ids -- confirm against callers.

        Args:
            input_ids: Tensor handed directly to the linear layer.

        Returns:
            The output of ``self.linear`` for the given input.
        """
        return self.linear(input_ids)