File size: 1,989 Bytes
b367dc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import PretrainedConfig

class ProtoConfig(PretrainedConfig):
    """Configuration container for the ``"proto"`` model type.

    Each named constructor argument is stored verbatim as an instance
    attribute of the same name. No validation or conversion happens here.
    Extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``, which runs *before* the attributes below
    are assigned.

    Args:
        pretrained_model_name_or_path: Backbone name or path to store
            (default ``"xlm-roberta-base"``). How it is used is up to the
            model code, not this class.
        num_classes: Class count (default 10).
        label_order_path: Optional path, ``None`` by default. Presumably it
            points to a file defining label order; confirm with the consumer.
        use_sigmoid: Boolean flag (default False).
        use_cuda: Boolean flag (default True).
        lr_prototypes, lr_features, lr_others: Presumably learning rates,
            one per parameter group (defaults 5e-2, 2e-6, 2e-2). Confirm
            against the optimizer setup.
        num_training_steps, num_warmup_steps: Step counts (defaults 5000 and
            1000). Presumably consumed by a warmup schedule; confirm.
        loss: Loss identifier string (default ``'BCE'``).
        save_dir: Output directory name (default ``'output'``).
        use_attention, dot_product, final_layer, use_prototype_loss: Boolean
            switches (defaults True, False, False, False).
        normalize: Optional setting, ``None`` by default.
        reduce_hidden_size: Optional size, ``None`` by default.
        prototype_vector_path, attention_vector_path: Optional paths,
            ``None`` by default.
        eval_buckets: Optional setting, ``None`` by default.
        seed: Seed value (default 7).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "proto"

    def __init__(self,
                 pretrained_model_name_or_path="xlm-roberta-base",
                 num_classes=10,
                 label_order_path=None,
                 use_sigmoid=False,
                 use_cuda=True,
                 lr_prototypes=5e-2,
                 lr_features=2e-6,
                 lr_others=2e-2,
                 num_training_steps=5000,
                 num_warmup_steps=1000,
                 loss='BCE',
                 save_dir='output',
                 use_attention=True,
                 dot_product=False,
                 normalize=None,
                 final_layer=False,
                 reduce_hidden_size=None,
                 use_prototype_loss=False,
                 prototype_vector_path=None,
                 attention_vector_path=None,
                 eval_buckets=None,
                 seed=7,
                 **kwargs):
        # Let the base class consume its own keyword arguments first; the
        # model-specific attributes are layered on afterwards.
        super().__init__(**kwargs)

        # Listed in declaration order so attributes land on the instance in
        # the same sequence as individual assignments would produce.
        proto_settings = {
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
            "num_classes": num_classes,
            "label_order_path": label_order_path,
            "use_sigmoid": use_sigmoid,
            "use_cuda": use_cuda,
            "lr_prototypes": lr_prototypes,
            "lr_features": lr_features,
            "lr_others": lr_others,
            "num_training_steps": num_training_steps,
            "num_warmup_steps": num_warmup_steps,
            "loss": loss,
            "save_dir": save_dir,
            "use_attention": use_attention,
            "dot_product": dot_product,
            "normalize": normalize,
            "final_layer": final_layer,
            "reduce_hidden_size": reduce_hidden_size,
            "use_prototype_loss": use_prototype_loss,
            "prototype_vector_path": prototype_vector_path,
            "attention_vector_path": attention_vector_path,
            "eval_buckets": eval_buckets,
            "seed": seed,
        }
        # setattr routes through the same __setattr__ as `self.x = ...`.
        for attr_name, attr_value in proto_settings.items():
            setattr(self, attr_name, attr_value)