import logging
import os
from tqdm import tqdm
import json

from dataclasses import dataclass
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_roberta import RobertaTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence

from torch.utils.data.dataset import Dataset
import random

logger = logging.getLogger(__name__)

class MultiDDataset(Dataset):
  """
  Dataset for training task: SQL (+ schema) -> text
  """
  def __init__(self, tokenizer: PreTrainedTokenizer, file_path, block_size, local_rank=-1):
    assert os.path.isfile(file_path)
    logger.info("Creating features from dataset file at {}".format(file_path))

    self.examples = []
    total, valid = 0, 0
    # block_size and local_rank are not used below (presumably kept for interface compatibility).
    # BART/RoBERTa BPE tokenizers need add_prefix_space so raw strings tokenize consistently.
    add_prefix_space = isinstance(tokenizer, BartTokenizer) or isinstance(tokenizer, RobertaTokenizer)
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)

        sql = " ".join(example["sql"].split()).lower()
        text = example["question"].strip().lower()

        text_tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=add_prefix_space) + [tokenizer.sep_token]
        sql_tokens = [tokenizer.cls_token] + tokenizer.tokenize(sql, add_prefix_space=add_prefix_space) + [tokenizer.sep_token]

        text_token_ids = tokenizer.convert_tokens_to_ids(text_tokens)
        sql_token_ids = tokenizer.convert_tokens_to_ids(sql_tokens)
        # Skip overly long pairs so sequences stay within the model's length limit.
        if len(text_token_ids) > 800 or len(sql_token_ids) > 800:
          continue

        valid += 1
        self.examples.append({
          "text_token_ids": text_token_ids,
          "sql_token_ids": sql_token_ids})
    logger.info("Total {} examples, {} kept.".format(total, valid))

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, i):
    return self.examples[i]
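
# Illustration (not part of the original file): each line of the dataset file
# consumed by MultiDDataset is expected to be a JSON object with a "sql" query
# and its natural-language "question"; the values below are made up.
#
#   {"question": "How many employees earn more than 50000?",
#    "sql": "SELECT count(*) FROM employee WHERE salary > 50000"}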

@dataclass
class DataCollatorForMultiD:
  """

  """
  tokenizer: PreTrainedTokenizer
  bi_direc: bool = False

  def __post_init__(self):
    # self.nl_token_id = self.tokenizer.convert_tokens_to_ids(["<nl>"])[0]
    # self.sql_token_id = self.tokenizer.convert_tokens_to_ids(["<sql>"])[0]
    # self.label_bos_id = [self.nl_token_id, self.sql_token_id]# self.tokenizer.cls_token_id
    self.label_eos_id = self.tokenizer.sep_token_id
    self.label_bos_id = self.tokenizer.cls_token_id


  def collate_batch(self, examples):
    text_ids_sequences = [example["text_token_ids"] for example in examples]
    sql_ids_sequences = [example["sql_token_ids"] for example in examples]

    padded_text_ids_tensor = pad_and_tensorize_sequence(
      text_ids_sequences, padding_value=self.tokenizer.pad_token_id)

    padded_sql_ids_tensor = pad_and_tensorize_sequence(
      sql_ids_sequences, padding_value=self.tokenizer.pad_token_id)

    # Default direction is text -> SQL; with bi_direc, flip to SQL -> text for
    # roughly half of the batches.
    input_ids, labels = padded_text_ids_tensor, padded_sql_ids_tensor
    if self.bi_direc and random.random() < 0.5:
      input_ids, labels = padded_sql_ids_tensor, padded_text_ids_tensor

    return {
      "input_ids": input_ids,
      "labels": labels,
      "pad_token_id": self.tokenizer.pad_token_id,
      "label_eos_id": self.label_eos_id,
      "label_bos_id": self.label_bos_id,
      "label_padding_id": self.tokenizer.pad_token_id
    }
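

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the pretrained checkpoint name and file
# path below are placeholders, not part of the original training setup, and
# the real training loop presumably lives elsewhere in the relogic project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  from torch.utils.data import DataLoader

  tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint
  dataset = MultiDDataset(tokenizer, file_path="data/train.jsonl", block_size=512)
  collator = DataCollatorForMultiD(tokenizer=tokenizer, bi_direc=True)
  loader = DataLoader(dataset, batch_size=8, collate_fn=collator.collate_batch)

  batch = next(iter(loader))
  # Each batch carries the padded source/target tensors plus the special-token
  # ids that the seq2seq training code reads from the batch dict.
  print(batch["input_ids"].shape, batch["labels"].shape)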