from torch.utils.data.dataset import Dataset
from transformers.tokenization_utils import PreTrainedTokenizer
from tqdm import tqdm
import json
from dataclasses import dataclass
import torch
from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence
import random
class TaBARTDataset(Dataset):
"""
This dataset is used for pretraining task on generation-based or retrieval-based
text-schema pair examples.
The fields that will be used is `question`, `table_info.header`, `entities`.
We already make sure that every entity in `entities` will be in `table_info.header`.
"""
def __init__(self,
tokenizer: PreTrainedTokenizer,
file_path: str,
col_token: str):
self.examples = []
total = 0
valid = 0
with open(file_path, encoding="utf-8") as f:
for line in tqdm(f):
total += 1
example = json.loads(line)
text = example["question"]
schema = example["table_info"]["header"]
tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=True) + [col_token]
column_spans = []
start_idx = len(tokens)
for column in schema:
column_tokens = tokenizer.tokenize(column.lower(), add_prefix_space=True)
tokens.extend(column_tokens)
column_spans.append((start_idx, start_idx + len(column_tokens)))
tokens.append(col_token)
start_idx += len(column_tokens) + 1
                # Replace the trailing col token with the sep token.
tokens[-1] = tokenizer.sep_token
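                # Resulting layout (hypothetical two-column schema):
                #   <s> question tokens <col> col1 tokens <col> col2 tokens </s>
                # where each (start, end) in column_spans indexes one column's
                # tokens, end-exclusive.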
input_ids = tokenizer.convert_tokens_to_ids(tokens)
                entities = example["entities"]
                column_labels = [0] * len(schema)
                for entity in entities:
                    # Skip SQL keywords/wildcards ("limit", "*") that never
                    # correspond to a header column.
                    if entity not in ("limit", "*"):
                        column_labels[schema.index(entity)] = 1
                if len(input_ids) > 600:
                    # Skip overlong examples.
                    continue
self.examples.append({
"input_ids": input_ids,
"column_spans": column_spans,
"column_labels": column_labels
})
valid += 1
        print("Total {} examples, {} valid".format(total, valid))
def __len__(self):
return len(self.examples)
def __getitem__(self, i):
return self.examples[i]
@dataclass
class DataCollatorForTaBART:
tokenizer: PreTrainedTokenizer
task: str
mlm_probability: float = 0.35
    def __post_init__(self):
        # For BART-style tokenizers, cls/sep coincide with bos/eos, so these
        # ids mirror the label_bos_id/label_eos_id emitted by collate_batch.
        self.label_bos_id = self.tokenizer.cls_token_id
        self.label_eos_id = self.tokenizer.sep_token_id
    def collate_batch(self, examples):
        input_ids_sequences = [example["input_ids"] for example in examples]
        padded_input_ids_tensor = pad_and_tensorize_sequence(
            input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        task = self.task
        if task == "mlm+col_pred":
            # Sample the objective per batch: 60% MLM, 40% column prediction.
            task = "mlm" if random.random() < 0.6 else "col_pred"
        if task == "mlm":
            return self._mlm_batch(padded_input_ids_tensor)
        elif task == "col_pred":
            return self._col_pred_batch(examples, padded_input_ids_tensor)
        raise ValueError("Unknown task: {}".format(self.task))

    def _mlm_batch(self, padded_input_ids_tensor):
        # Denoising objective: corrupt the encoder input and reconstruct the
        # full original sequence, so the token-level labels returned by
        # mask_tokens are not needed here.
        inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
        return {
            "task": "mlm",
            "input_ids": inputs,
            "labels": padded_input_ids_tensor,
            "pad_token_id": self.tokenizer.pad_token_id,
            "label_bos_id": self.tokenizer.bos_token_id,
            "label_eos_id": self.tokenizer.eos_token_id,
            "label_padding_id": self.tokenizer.pad_token_id}

    def _col_pred_batch(self, examples, padded_input_ids_tensor):
        column_labels_sequences = [example["column_labels"] for example in examples]
        # -100 is ignored by the cross-entropy loss.
        padded_label_ids_tensor = pad_and_tensorize_sequence(
            column_labels_sequences, padding_value=-100)
        column_spans_sequences = [example["column_spans"] for example in examples]
        # (0, 1) is a dummy span (pointing at the cls token) for padded columns.
        padded_column_spans_tensor = pad_and_tensorize_sequence(
            column_spans_sequences, padding_value=(0, 1))
        return {
            "task": "col_pred",
            "input_ids": padded_input_ids_tensor,
            "column_spans": padded_column_spans_tensor,
            "labels": padded_label_ids_tensor,
            "pad_token_id": self.tokenizer.pad_token_id}
def mask_tokens(self, inputs):
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token, which is necessary "
                "for masked language modeling."
            )
labels = inputs.clone()
        # Sample tokens in each sequence for masked-LM training with probability
        # self.mlm_probability (0.35 here, vs. the 0.15 default in BERT/RoBERTa).
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer.pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.mask_token_id
        # 10% of the time, replace masked input tokens with a random word
        # (half of the remaining 20% after the 80% MASK replacement).
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
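
# Usage sketch: a hypothetical wiring of the dataset and collator. The
# checkpoint name, file path, and "<col>" separator token below are
# assumptions for illustration; the actual pretraining setup may use a
# different column-separator token.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    # Register the column separator so it maps to a single id.
    tokenizer.add_tokens(["<col>"])
    dataset = TaBARTDataset(tokenizer, file_path="data/pretrain.jsonl", col_token="<col>")
    collator = DataCollatorForTaBART(tokenizer=tokenizer, task="mlm+col_pred")
    loader = DataLoader(dataset, batch_size=8, collate_fn=collator.collate_batch)
    batch = next(iter(loader))
    print(batch["task"], batch["input_ids"].shape)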