File size: 1,591 Bytes
3394a77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import Optional

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig


class RelikReaderConfig(PretrainedConfig):
    """Configuration object for the ``relik-reader`` model type.

    The constructor stores each of its named arguments as an attribute of the
    same name and forwards every remaining keyword argument to
    ``PretrainedConfig.__init__``. How the stored values are interpreted is
    decided by the code that consumes this config, not by this class.
    """

    model_type = "relik-reader"

    def __init__(
        self,
        transformer_model: str = "microsoft/deberta-v3-base",
        additional_special_symbols: int = 101,
        additional_special_symbols_types: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        entity_type_loss: bool = False,
        add_entity_embedding: Optional[bool] = None,
        training: bool = False,
        default_reader_class: Optional[str] = None,
        **kwargs
    ) -> None:
        """Store the reader hyper-parameters and initialise the base config.

        Args:
            transformer_model: Stored unchanged. Presumably the name or path
                of the underlying encoder checkpoint (confirm against the
                reader model).
            additional_special_symbols: Stored unchanged.
            additional_special_symbols_types: Stored unchanged.
            num_layers: Stored unchanged; ``None`` by default.
            activation: Stored unchanged. Presumably an activation-function
                name resolved by the model (confirm against the reader model).
            linears_hidden_size: Stored unchanged.
            use_last_k_layers: Stored unchanged.
            entity_type_loss: Stored unchanged. Also supplies the default for
                ``add_entity_embedding`` (see below).
            add_entity_embedding: When ``None`` and ``entity_type_loss`` is
                truthy, the stored value is ``True``. In every other case the
                passed value is stored as-is, so it stays ``None`` when both
                are left at their defaults.
            training: Stored unchanged.
            default_reader_class: Stored unchanged; ``None`` by default.
            **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
        """
        # TODO: add name_or_path to kwargs
        self.transformer_model = transformer_model
        self.additional_special_symbols = additional_special_symbols
        self.additional_special_symbols_types = additional_special_symbols_types
        self.num_layers = num_layers
        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers
        self.entity_type_loss = entity_type_loss

        # An unspecified ``add_entity_embedding`` follows ``entity_type_loss``
        # only when that flag is enabled; otherwise the caller's value
        # (including ``None``) is kept untouched.
        if add_entity_embedding is None and entity_type_loss:
            add_entity_embedding = True
        self.add_entity_embedding = add_entity_embedding

        self.training = training
        self.default_reader_class = default_reader_class
        # Own attributes are set first; the base initialiser then consumes
        # the remaining keyword arguments.
        super().__init__(**kwargs)


# Register the config with AutoConfig under its own ``model_type`` key
# ("relik-reader") at import time.
AutoConfig.register(RelikReaderConfig.model_type, RelikReaderConfig)