import json
import random
from dataclasses import dataclass

import torch
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer

from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence

class TaBARTDataset(Dataset):
  """
  Dataset for pretraining tasks on generation-based or retrieval-based
  text-schema pair examples.
  The fields used are `question`, `table_info.header`, and `entities`.
  Every entity in `entities` is guaranteed to appear in `table_info.header`.
  """
  def __init__(self,
               tokenizer: PreTrainedTokenizer,
               file_path: str,
               col_token: str):
    self.examples = []
    total = 0
    valid = 0
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)
        text = example["question"]
        schema = example["table_info"]["header"]
        # Layout: [CLS] question tokens <col> col_1 tokens <col> ... col_k tokens [SEP]
        tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=True) + [col_token]
        column_spans = []
        start_idx = len(tokens)
        for column in schema:
          column_tokens = tokenizer.tokenize(column.lower(), add_prefix_space=True)
          tokens.extend(column_tokens)
          column_spans.append((start_idx, start_idx + len(column_tokens)))
          tokens.append(col_token)
          start_idx += len(column_tokens) + 1
        # Replace the trailing column separator with the sep token.
        tokens[-1] = tokenizer.sep_token
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        entities = example["entities"]
        # Binary label per column: 1 if the question references that column.
        column_labels = [0] * len(schema)
        for entity in entities:
          if entity not in ("limit", "*"):
            column_labels[schema.index(entity)] = 1
        # Skip overly long sequences to bound memory usage.
        if len(input_ids) > 600:
          continue
        self.examples.append({
          "input_ids": input_ids,
          "column_spans": column_spans,
          "column_labels": column_labels
        })
        valid += 1
    print("Total {} and Valid {}".format(total, valid))
  def __len__(self):
    return len(self.examples)

  def __getitem__(self, i):
    return self.examples[i]
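
# Each dataset item is a dict with:
#   "input_ids":     token ids for "[CLS] question <col> col_1 <col> ... col_k [SEP]"
#   "column_spans":  (start, end) token offsets of each column's subword tokens
#   "column_labels": 1 if the column appears in the example's `entities`, else 0
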
@dataclass
class DataCollatorForTaBART:
  """
  Collates TaBARTDataset examples into batches for the `mlm`, `col_pred`,
  or mixed `mlm+col_pred` objectives.
  """
  tokenizer: PreTrainedTokenizer
  task: str
  mlm_probability: float = 0.35

  def __post_init__(self):
    self.label_bos_id = self.tokenizer.cls_token_id
    self.label_eos_id = self.tokenizer.sep_token_id

  def collate_batch(self, examples):
    input_ids_sequences = [example["input_ids"] for example in examples]
    padded_input_ids_tensor = pad_and_tensorize_sequence(
      input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
    if self.task == "mlm":
      return self._mlm_batch(padded_input_ids_tensor)
    elif self.task == "col_pred":
      return self._col_pred_batch(examples, padded_input_ids_tensor)
    elif self.task == "mlm+col_pred":
      # Mix the two objectives, sampling MLM for 60% of the batches.
      if random.random() < 0.6:
        return self._mlm_batch(padded_input_ids_tensor)
      return self._col_pred_batch(examples, padded_input_ids_tensor)
    raise ValueError("Unknown task: {}".format(self.task))

  def _mlm_batch(self, padded_input_ids_tensor):
    # Inputs are corrupted in place; the labels are the original, uncorrupted
    # ids (denoising-style reconstruction), so the token-level labels returned
    # by `mask_tokens` are not needed here.
    inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
    return {
      "task": "mlm",
      "input_ids": inputs,
      "labels": padded_input_ids_tensor,
      "pad_token_id": self.tokenizer.pad_token_id,
      "label_bos_id": self.tokenizer.bos_token_id,
      "label_eos_id": self.tokenizer.eos_token_id,
      "label_padding_id": self.tokenizer.pad_token_id}

  def _col_pred_batch(self, examples, padded_input_ids_tensor):
    column_labels_sequences = [example["column_labels"] for example in examples]
    # Pad labels with -100 so padded positions are ignored by the loss.
    padded_label_ids_tensor = pad_and_tensorize_sequence(
      column_labels_sequences, padding_value=-100)
    column_spans_sequences = [example["column_spans"] for example in examples]
    padded_column_spans_tensor = pad_and_tensorize_sequence(
      column_spans_sequences, padding_value=(0, 1))
    return {
      "task": "col_pred",
      "input_ids": padded_input_ids_tensor,
      "column_spans": padded_column_spans_tensor,
      "labels": padded_label_ids_tensor,
      "pad_token_id": self.tokenizer.pad_token_id}
  def mask_tokens(self, inputs):
    """
    Prepare masked tokens inputs/labels for masked language modeling:
    80% MASK, 10% random, 10% original.
    """
    if self.tokenizer.mask_token is None:
      raise ValueError(
        "This tokenizer does not have a mask token, which is necessary for "
        "masked language modeling."
      )
    labels = inputs.clone()
    # Sample tokens in each sequence for masking with probability
    # `self.mlm_probability` (0.35 here, vs. the 0.15 used in BERT/RoBERTa).
    probability_matrix = torch.full(labels.shape, self.mlm_probability)
    special_tokens_mask = [
      self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if self.tokenizer._pad_token is not None:
      padding_mask = labels.eq(self.tokenizer.pad_token_id)
      probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # Only compute loss on masked tokens.
    # 80% of the time, replace masked input tokens with tokenizer.mask_token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
    # 10% of the time (half of the remaining 20%), replace with a random token.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    # The remaining 10% of masked tokens are left unchanged.
    return inputs, labels
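

# A minimal usage sketch (hypothetical file path and "<col>" token; this wiring
# is an assumption, not part of this module). It assumes a BART-style tokenizer,
# which defines the cls/sep/mask/pad tokens the classes above rely on.
if __name__ == "__main__":
  from transformers import BartTokenizer

  tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
  tokenizer.add_tokens(["<col>"])  # register the column separator token
  dataset = TaBARTDataset(tokenizer, file_path="train.jsonl", col_token="<col>")
  collator = DataCollatorForTaBART(tokenizer=tokenizer, task="mlm+col_pred")
  batch = collator.collate_batch([dataset[i] for i in range(min(4, len(dataset)))])
  print(batch["task"], batch["input_ids"].shape)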