import json
import logging
import os
import random
from dataclasses import dataclass

from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_roberta import RobertaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence

logger = logging.getLogger(__name__)


class MultiDDataset(Dataset):
  """
  Dataset for training task: SQL (+ schema) -> text
  """

  def __init__(self, tokenizer: PreTrainedTokenizer, file_path, block_size, local_rank=-1):
    assert os.path.isfile(file_path)
    logger.info("Creating features from dataset file at {}".format(file_path))

    self.examples = []
    total, valid = 0, 0
    # Byte-level BPE tokenizers (BART/RoBERTa) encode a word differently with
    # and without a leading space, so pass add_prefix_space for consistency.
    add_prefix_space = isinstance(tokenizer, (BartTokenizer, RobertaTokenizer))
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)

        sql = " ".join(example["sql"].split()).lower()
        text = example["question"].strip().lower()

        text_tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=add_prefix_space) + [tokenizer.sep_token]
        sql_tokens = [tokenizer.cls_token] + tokenizer.tokenize(sql, add_prefix_space=add_prefix_space) + [tokenizer.sep_token]

        text_token_ids = tokenizer.convert_tokens_to_ids(text_tokens)
        sql_token_ids = tokenizer.convert_tokens_to_ids(sql_tokens)
        # Skip examples that are too long to fit the model.
        if len(text_token_ids) > 800 or len(sql_token_ids) > 800:
          continue

        valid += 1
        self.examples.append({
          "text_token_ids": text_token_ids,
          "sql_token_ids": sql_token_ids})
    logger.info("Total {} examples, {} valid.".format(total, valid))

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, i):
    return self.examples[i]
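

# Illustrative sketch, not part of the original module: each line of the
# dataset file is expected to be a JSON object with "sql" and "question"
# fields, e.g.
#   {"sql": "select name from city where population > 1000000",
#    "question": "which cities have more than one million people ?"}
# Assuming a BART checkpoint, the dataset could then be built roughly as
#   tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#   dataset = MultiDDataset(tokenizer, file_path="train.json", block_size=512)
# where "train.json" and the checkpoint name are placeholders.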


@dataclass
class DataCollatorForMultiD:
  """
  Collates MultiDDataset examples into padded input/label tensors for
  seq2seq training. With ``bi_direc`` enabled, each batch is randomly
  flipped between text -> SQL and SQL -> text.
  """
  tokenizer: PreTrainedTokenizer
  bi_direc: bool = False

  def __post_init__(self):
    # Label sequences are wrapped in [CLS] ... [SEP] by MultiDDataset, so
    # expose those ids as the label BOS/EOS markers.
    self.label_eos_id = self.tokenizer.sep_token_id
    self.label_bos_id = self.tokenizer.cls_token_id

  def collate_batch(self, examples):
    text_ids_sequences = [example["text_token_ids"] for example in examples]
    sql_ids_sequences = [example["sql_token_ids"] for example in examples]

    padded_text_ids_tensor = pad_and_tensorize_sequence(
      text_ids_sequences, padding_value=self.tokenizer.pad_token_id)
    padded_sql_ids_tensor = pad_and_tensorize_sequence(
      sql_ids_sequences, padding_value=self.tokenizer.pad_token_id)

    # Default direction is text -> SQL; in bi-directional mode, flip to
    # SQL -> text with probability 0.5.
    if self.bi_direc and random.random() < 0.5:
      input_ids, labels = padded_sql_ids_tensor, padded_text_ids_tensor
    else:
      input_ids, labels = padded_text_ids_tensor, padded_sql_ids_tensor

    return {
      "input_ids": input_ids,
      "labels": labels,
      "pad_token_id": self.tokenizer.pad_token_id,
      "label_eos_id": self.label_eos_id,
      "label_bos_id": self.label_bos_id,
      "label_padding_id": self.tokenizer.pad_token_id
    }
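

if __name__ == "__main__":
  # Minimal usage sketch, not part of the original module: wire the dataset
  # and collator into a torch DataLoader. The checkpoint name and file path
  # are placeholders.
  from torch.utils.data import DataLoader

  demo_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
  demo_dataset = MultiDDataset(demo_tokenizer, file_path="train.json", block_size=512)
  demo_collator = DataCollatorForMultiD(tokenizer=demo_tokenizer, bi_direc=True)
  loader = DataLoader(demo_dataset, batch_size=8, shuffle=True,
                      collate_fn=demo_collator.collate_batch)
  for batch in loader:
    # Each batch carries padded "input_ids"/"labels" tensors plus the special
    # token ids consumed by the seq2seq training loop.
    print(batch["input_ids"].shape, batch["labels"].shape)
    break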