from torch.utils.data.dataset import Dataset
from transformers.tokenization_utils import PreTrainedTokenizer
from tqdm import tqdm
import json
from dataclasses import dataclass
import torch
from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence
import random

class TaBARTDataset(Dataset):
  """
  This dataset is used for pretraining tasks on generation-based or retrieval-based
  text-schema pair examples.
  The fields used are `question`, `table_info.header`, and `entities`.
  Every entity in `entities` (other than "limit" and "*") is guaranteed to appear
  in `table_info.header`.
  """
  def __init__(self,
               tokenizer: PreTrainedTokenizer,
               file_path: str,
               col_token: str):
    self.examples = []
    total = 0
    valid = 0
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)
        text = example["question"]
        schema = example["table_info"]["header"]
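        # Serialized input layout:
        #   [cls] question tokens [col] header_1 tokens [col] header_2 tokens ... [col]
        # The trailing [col] is replaced by [sep] below.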
        tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=True) + [col_token]
        column_spans = []
        start_idx = len(tokens)
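        # Record a half-open token span (start, end) for each column header; these
        # spans are consumed by the column-prediction ("col_pred") collation below.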
        for column in schema:
          column_tokens = tokenizer.tokenize(column.lower(), add_prefix_space=True)
          tokens.extend(column_tokens)
          column_spans.append((start_idx, start_idx + len(column_tokens)))
          tokens.append(col_token)
          start_idx += len(column_tokens) + 1
        # Change last col token to sep token
        tokens[-1] = tokenizer.sep_token
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        entities = example["entities"]
        column_labels = [0] * len(schema)
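        # Mark the columns that appear in `entities`; "limit" and "*" are not
        # column headers, so they are skipped.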
        for entity in entities:
          if entity != "limit" and entity != "*":
            column_labels[schema.index(entity)] = 1
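        # Skip overly long examples (more than 600 tokens); they count toward
        # `total` but not `valid`.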
        if len(input_ids) > 600:
          continue
        self.examples.append({
          "input_ids": input_ids,
          "column_spans": column_spans,
          "column_labels": column_labels
        })
        valid += 1
    print("Total {} examples, {} valid.".format(total, valid))

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, i):
    return self.examples[i]


@dataclass
class DataCollatorForTaBART:
  tokenizer: PreTrainedTokenizer
  task: str
  mlm_probability: float = 0.35

  def __post_init__(self):
    # For BART-style tokenizers the cls/sep tokens double as the decoder
    # bos/eos tokens, so these ids can be used to frame seq2seq targets.
    self.label_bos_id = self.tokenizer.cls_token_id
    self.label_eos_id = self.tokenizer.sep_token_id

  def collate_batch(self, examples):
    input_ids_sequences = [example["input_ids"] for example in examples]
    padded_input_ids_tensor = pad_and_tensorize_sequence(input_ids_sequences,
                                                         padding_value=self.tokenizer.pad_token_id)
    task = self.task
    if task == "mlm+col_pred":
      # Alternate between the two objectives per batch, favoring MLM (p = 0.6).
      task = "mlm" if random.random() < 0.6 else "col_pred"

    if task == "mlm":
      # The encoder input is the masked sequence; the full original sequence is
      # passed along as the reconstruction target, so the per-token labels
      # returned by mask_tokens are not needed here.
      inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
      return {
        "task": "mlm",
        "input_ids": inputs,
        "labels": padded_input_ids_tensor,
        "pad_token_id": self.tokenizer.pad_token_id,
        "label_bos_id": self.tokenizer.bos_token_id,
        "label_eos_id": self.tokenizer.eos_token_id,
        "label_padding_id": self.tokenizer.pad_token_id}
    elif task == "col_pred":
      column_labels_sequences = [example["column_labels"] for example in examples]
      # Pad labels with -100, the conventional ignore index, so padded positions
      # do not contribute to the loss.
      padded_label_ids_tensor = pad_and_tensorize_sequence(column_labels_sequences,
                                                           padding_value=-100)
      column_spans_sequences = [example["column_spans"] for example in examples]
      # Pad spans with (0, 1), a dummy one-token span over the leading cls position.
      padded_column_spans_tensor = pad_and_tensorize_sequence(column_spans_sequences,
                                                              padding_value=(0, 1))
      return {
        "task": "col_pred",
        "input_ids": padded_input_ids_tensor,
        "column_spans": padded_column_spans_tensor,
        "labels": padded_label_ids_tensor,
        "pad_token_id": self.tokenizer.pad_token_id}
    else:
      raise ValueError("Unsupported task: {}".format(self.task))

  def mask_tokens(self, inputs):
    """
    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
    """

    if self.tokenizer.mask_token is None:
      raise ValueError(
        "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
      )

    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training (with probability self.mlm_probability, 0.35 by default here)
    probability_matrix = torch.full(labels.shape, self.mlm_probability)
    special_tokens_mask = [
      self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if self.tokenizer.pad_token is not None:
      padding_mask = labels.eq(self.tokenizer.pad_token_id)
      probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
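

if __name__ == "__main__":
  # Minimal usage sketch (illustrative only, not part of the original training code).
  # The checkpoint name, data path, and column token below are assumptions.
  from torch.utils.data import DataLoader
  from transformers import BartTokenizer

  tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
  tokenizer.add_tokens(["<col>"])  # register a hypothetical column-separator token
  dataset = TaBARTDataset(tokenizer=tokenizer,
                          file_path="data/tabart_pretrain.jsonl",  # hypothetical path
                          col_token="<col>")
  collator = DataCollatorForTaBART(tokenizer=tokenizer, task="mlm+col_pred")
  loader = DataLoader(dataset, batch_size=8, shuffle=True,
                      collate_fn=collator.collate_batch)
  batch = next(iter(loader))
  print({k: (tuple(v.shape) if torch.is_tensor(v) else v) for k, v in batch.items()})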