File size: 5,166 Bytes
c3b1078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import ModelOutput

import torch


def get_range_vector(size: int, device: Union[int, str, torch.device]) -> torch.Tensor:
    """
    Return the 1-D tensor ``[0, 1, ..., size - 1]`` with dtype ``torch.long``.

    The tensor is allocated directly on ``device``, so no separate host-to-device
    copy of the range is needed afterwards.

    Args:
        size: Number of elements in the range; ``0`` yields an empty tensor.
        device: Forwarded unchanged as the ``device`` argument of ``torch.arange``
            (a ``torch.device``, a device string, or a device ordinal).

    Returns:
        A ``torch.long`` tensor of shape ``(size,)`` located on ``device``.
    """
    return torch.arange(size, dtype=torch.long, device=device)

@dataclass
class Seq2LabelsOutput(ModelOutput):
    """
    Structured return value of `Seq2LabelsModel.forward` (used when `return_dict` is true).

    Every field defaults to `None`; the comments below describe what `forward` stores in
    each one. Field order is kept as declared because the parent output class exposes
    the fields positionally.
    """

    # Sum of the label loss and the detection loss; stays None unless both
    # `labels` and `d_tags` were passed to `forward`.
    loss: Optional[torch.FloatTensor] = None
    # Output of the `classifier` linear head (last dimension is `config.vocab_size`).
    logits: torch.FloatTensor = None
    # Output of the `detector` linear head (last dimension is `config.num_detect_classes`).
    detect_logits: torch.FloatTensor = None
    # Copied from the encoder output's `hidden_states` attribute.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Copied from the encoder output's `attentions` attribute.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # `forward` currently fills this with a tensor of ones, one entry per row of `logits`
    # along dimension 0 (no real probability is computed there).
    max_error_probability: Optional[torch.FloatTensor] = None


class Seq2LabelsModel(BertPreTrainedModel):
    """
    Transformer encoder with two token-level linear heads.

    * ``classifier`` maps each encoder vector to ``config.vocab_size`` label logits
      (dropout is applied to its input).
    * ``detector`` maps each encoder vector to ``config.num_detect_classes`` logits
      (no dropout on its input).

    The encoder is created through ``AutoModel`` from ``config.pretrained_name_or_path``
    and stored as ``self.bert``.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """
        Build the encoder and both heads from ``config``.

        Config attributes read here: ``num_labels``, ``num_detect_classes``,
        ``label_smoothing``, ``load_pretrained``, ``pretrained_name_or_path``,
        ``special_tokens_fix``, ``predictor_dropout`` and ``vocab_size``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_detect_classes = config.num_detect_classes
        self.label_smoothing = config.label_smoothing

        if config.load_pretrained:
            # Load encoder weights from the named checkpoint.
            self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
            bert_config = self.bert.config
        else:
            # Only fetch the encoder configuration; the encoder weights start uninitialised
            # here (e.g. because they will be restored from this model's own checkpoint).
            bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
            self.bert = AutoModel.from_config(bert_config)

        if config.special_tokens_fix:
            # Grow the token embedding matrix by one row for an extra special token.
            try:
                vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # Encoders that expose `word_embedding` directly: reserve more space.
                vocab_size = self.bert.word_embedding.num_embeddings + 5
            self.bert.resize_token_embeddings(vocab_size + 1)

        predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
        self.dropout = nn.Dropout(predictor_dropout)
        self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
        self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_offsets: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_tags: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
        r"""
        Run the encoder and both heads, optionally computing the joint loss.

        input_offsets (`torch.LongTensor`, *optional*):
            Positions to gather from the encoder output for every batch row. When given, the
            encoder output is re-indexed so that the heads only see the selected positions.
            Presumably `(batch_size, orig_sequence_length)` (first sub-token of each original
            token) -- confirm against the data pipeline.
        labels (`torch.LongTensor`, *optional*):
            Targets for the label head. One class index per position seen by the heads; indices
            must be valid classes of the classifier, i.e. in `[0, ..., config.vocab_size - 1]`
            (or the `CrossEntropyLoss` ignore index).
        d_tags (`torch.LongTensor`, *optional*):
            Targets for the detection head, indices in `[0, ..., config.num_detect_classes - 1]`
            (or the `CrossEntropyLoss` ignore index).

        The remaining arguments are forwarded unchanged to the encoder.

        Returns:
            A `Seq2LabelsOutput` when `return_dict` resolves to true; otherwise a tuple
            `(logits, detect_logits, *extra_encoder_outputs)`, prefixed with the loss when it was
            computed. The loss is only computed when both `labels` and `d_tags` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if input_offsets is not None:
            # Column vector of batch indices; it broadcasts against `input_offsets` so that
            # row b of the result holds sequence_output[b, input_offsets[b, ...]].
            range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
            sequence_output = sequence_output[range_vector, input_offsets]

        # Dropout is applied only on the label head's input, not on the detector's.
        logits = self.classifier(self.dropout(sequence_output))
        logits_d = self.detector(sequence_output)

        loss = None
        if labels is not None and d_tags is not None:
            loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss_d_fct = CrossEntropyLoss()
            # Flatten using each head's actual output width. The classifier is built with
            # `config.vocab_size` outputs, so flattening with `self.num_labels` would raise or
            # silently misalign rows and labels whenever the two config values differ.
            # `reshape` (rather than `view`) also accepts non-contiguous target tensors.
            loss_labels = loss_labels_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            loss_d = loss_d_fct(logits_d.reshape(-1, logits_d.size(-1)), d_tags.reshape(-1))
            loss = loss_labels + loss_d

        if not return_dict:
            # Skip the first two encoder outputs (index 0 is the sequence output used above;
            # index 1 is presumably the pooled output -- confirm for pooler-less encoders).
            output = (logits, logits_d) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2LabelsOutput(
            loss=loss,
            logits=logits,
            detect_logits=logits_d,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # Placeholder: a constant 1.0 per batch row; no probability is derived here.
            max_error_probability=torch.ones(logits.size(0), device=logits.device),
        )