File size: 377 Bytes
4c19b30
 
 
 
 
6a6ee6a
4c19b30
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import PretrainedConfig
from typing import List


class TestConfig(PretrainedConfig):
    """Configuration for the ``my_test_model`` model type.

    Stores two integer settings, ``input_dim`` and ``output_dim``, and hands
    every other keyword argument to ``PretrainedConfig``.

    Args:
        input_dim: Value stored as ``self.input_dim`` (default 20).
            Presumably the model's input feature size -- confirm against
            the model code.
        output_dim: Value stored as ``self.output_dim`` (default 10).
            Presumably the model's output size -- confirm against the
            model code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "my_test_model"

    def __init__(self, input_dim: int = 20, output_dim: int = 10, **kwargs):
        # Set the dimensions first, then let the base class consume the
        # remaining keyword arguments.
        self.input_dim, self.output_dim = input_dim, output_dim
        super().__init__(**kwargs)