charlesxsh commited on
Commit
2c55a32
·
1 Parent(s): 5377e9f
Files changed (1) hide show
  1. custom_model.py +21 -0
custom_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # custom_model.py
2
+ from transformers import PreTrainedModel, PretrainedConfig
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ class CustomModelConfig(PretrainedConfig):
7
+ model_type = "custom-model"
8
+ def __init__(self, hidden_size=128, **kwargs):
9
+ super().__init__(**kwargs)
10
+ self.hidden_size = hidden_size
11
+
12
+ class CustomModel(PreTrainedModel):
13
+ config_class = CustomModelConfig
14
+
15
+ def __init__(self, config):
16
+ super().__init__(config)
17
+ self.linear = nn.Linear(config.hidden_size, config.hidden_size)
18
+
19
+ def forward(self, input_ids):
20
+ output = self.linear(input_ids)
21
+ return output