import torch
from transformers import PreTrainedModel

from .configuration_proto import ProtoConfig
from .proto import ProtoModule

class ProtoForMultiLabelClassification(PreTrainedModel):
    """Hugging Face-compatible wrapper that exposes ProtoModule as a PreTrainedModel for multi-label classification."""

    config_class = ProtoConfig
    def __init__(self, config: ProtoConfig):
        super().__init__(config)
        # All ProtoModule hyperparameters are taken directly from the ProtoConfig instance.
        self.proto_module = ProtoModule(
            pretrained_model=config.pretrained_model_name_or_path,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_sigmoid=config.use_sigmoid,
            use_cuda=config.use_cuda,
            lr_prototypes=config.lr_prototypes,
            lr_features=config.lr_features,
            lr_others=config.lr_others,
            num_training_steps=config.num_training_steps,
            num_warmup_steps=config.num_warmup_steps,
            loss=config.loss,
            save_dir=config.save_dir,
            use_attention=config.use_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            use_prototype_loss=config.use_prototype_loss,
            prototype_vector_path=config.prototype_vector_path,
            attention_vector_path=config.attention_vector_path,
            eval_buckets=config.eval_buckets,
            seed=config.seed
        )
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
        # ProtoModule expects a batch dict; note the plural "attention_masks" key.
        batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits, metadata = self.proto_module(batch)
        # Returns a plain dict rather than a ModelOutput: raw logits plus the
        # prototype metadata produced by ProtoModule.
        return {"logits": logits, "metadata": metadata}
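

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. It assumes a
    # checkpoint directory previously saved with save_pretrained() exists at
    # "path/to/checkpoint" (placeholder path) and that a BERT-style tokenizer
    # matching config.pretrained_model_name_or_path is available; adjust both
    # to your own setup.
    from transformers import AutoTokenizer

    model = ProtoForMultiLabelClassification.from_pretrained("path/to/checkpoint")
    tokenizer = AutoTokenizer.from_pretrained(model.config.pretrained_model_name_or_path)

    # A BERT-style tokenizer returns input_ids, attention_mask, and
    # token_type_ids, which matches forward()'s signature.
    encoded = tokenizer(
        "Example document to classify.",
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    with torch.no_grad():
        outputs = model(**encoded)

    # "logits" holds one score per label; apply a sigmoid for multi-label
    # probabilities if the module does not already do so (see config.use_sigmoid).
    print(outputs["logits"].shape)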