|
import collections |
|
import itertools |
|
import json |
|
import os |
|
|
|
import attr |
|
import nltk.corpus |
|
import torch |
|
import torchtext |
|
import numpy as np |
|
|
|
from seq2struct.models import abstract_preproc, transformer |
|
from seq2struct.models.spider import spider_enc_modules |
|
from seq2struct.resources import pretrained_embeddings |
|
from seq2struct.utils import registry |
|
from seq2struct.utils import vocab |
|
from seq2struct.utils import serialization |
|
from seq2struct import resources |
|
from seq2struct.resources import corenlp |
|
from transformers import BertModel, BertTokenizer, BartModel, BartTokenizer |
|
from seq2struct.models.spider.spider_match_utils import ( |
|
compute_schema_linking, |
|
compute_cell_value_linking |
|
) |
|
|
|
|
|
@attr.s |
|
class SpiderEncoderState: |
|
state = attr.ib() |
|
memory = attr.ib() |
|
question_memory = attr.ib() |
|
schema_memory = attr.ib() |
|
words = attr.ib() |
|
|
|
pointer_memories = attr.ib() |
|
pointer_maps = attr.ib() |
|
|
|
m2c_align_mat = attr.ib() |
|
m2t_align_mat = attr.ib() |
|
|
|
def find_word_occurrences(self, word): |
|
return [i for i, w in enumerate(self.words) if w == word] |
|
|
|
|
|
@attr.s |
|
class PreprocessedSchema: |
|
column_names = attr.ib(factory=list) |
|
table_names = attr.ib(factory=list) |
|
table_bounds = attr.ib(factory=list) |
|
column_to_table = attr.ib(factory=dict) |
|
table_to_columns = attr.ib(factory=dict) |
|
foreign_keys = attr.ib(factory=dict) |
|
foreign_keys_tables = attr.ib(factory=lambda: collections.defaultdict(set)) |
|
primary_keys = attr.ib(factory=list) |
|
|
|
|
|
normalized_column_names = attr.ib(factory=list) |
|
normalized_table_names = attr.ib(factory=list) |
|
|
|
def preprocess_schema_uncached(schema, |
|
tokenize_func, |
|
include_table_name_in_column, |
|
fix_issue_16_primary_keys, |
|
bert=False): |
|
"""If it's bert, we also cache the normalized version of |
|
question/column/table for schema linking""" |
|
r = PreprocessedSchema() |
|
|
|
if bert: assert not include_table_name_in_column |
|
|
|
last_table_id = None |
|
for i, column in enumerate(schema.columns): |
|
col_toks = tokenize_func( |
|
column.name, column.unsplit_name) |
|
|
|
|
|
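        # Mark the column's SQL type with a dedicated token, e.g. '<type: text>';
        # it is appended for BERT and prepended otherwise.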
type_tok = '<type: {}>'.format(column.type) |
|
if bert: |
|
|
|
column_name = col_toks + [type_tok] |
|
r.normalized_column_names.append(Bertokens(col_toks)) |
|
else: |
|
column_name = [type_tok] + col_toks |
|
|
|
if include_table_name_in_column: |
|
if column.table is None: |
|
table_name = ['<any-table>'] |
|
else: |
|
table_name = tokenize_func( |
|
column.table.name, column.table.unsplit_name) |
|
column_name += ['<table-sep>'] + table_name |
|
r.column_names.append(column_name) |
|
|
|
table_id = None if column.table is None else column.table.id |
|
r.column_to_table[str(i)] = table_id |
|
if table_id is not None: |
|
columns = r.table_to_columns.setdefault(str(table_id), []) |
|
columns.append(i) |
|
if last_table_id != table_id: |
|
r.table_bounds.append(i) |
|
last_table_id = table_id |
|
|
|
if column.foreign_key_for is not None: |
|
r.foreign_keys[str(column.id)] = column.foreign_key_for.id |
|
r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id) |
|
|
|
r.table_bounds.append(len(schema.columns)) |
|
assert len(r.table_bounds) == len(schema.tables) + 1 |
|
|
|
for i, table in enumerate(schema.tables): |
|
table_toks = tokenize_func( |
|
table.name, table.unsplit_name) |
|
r.table_names.append(table_toks) |
|
if bert: |
|
r.normalized_table_names.append(Bertokens(table_toks)) |
|
last_table = schema.tables[-1] |
|
|
|
r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables) |
|
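    # With the fix enabled, collect the primary keys of every table; otherwise
    # keep the original (issue-16) behaviour, which repeats the last table's
    # primary keys once per table in the schema.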
r.primary_keys = [ |
|
column.id |
|
for table in schema.tables |
|
for column in table.primary_keys |
|
] if fix_issue_16_primary_keys else [ |
|
column.id |
|
for column in last_table.primary_keys |
|
for table in schema.tables |
|
] |
|
|
|
return r |
|
|
|
class SpiderEncoderV2Preproc(abstract_preproc.AbstractPreproc): |
|
|
|
def __init__( |
|
self, |
|
save_path, |
|
min_freq=3, |
|
max_count=5000, |
|
include_table_name_in_column=True, |
|
word_emb=None, |
|
count_tokens_in_word_emb_for_vocab=False, |
|
|
|
fix_issue_16_primary_keys=False, |
|
compute_sc_link=False, |
|
compute_cv_link=False, |
|
db_path=None): |
|
if word_emb is None: |
|
self.word_emb = None |
|
else: |
|
self.word_emb = registry.construct('word_emb', word_emb) |
|
|
|
self.data_dir = os.path.join(save_path, 'enc') |
|
self.include_table_name_in_column = include_table_name_in_column |
|
self.count_tokens_in_word_emb_for_vocab = count_tokens_in_word_emb_for_vocab |
|
self.fix_issue_16_primary_keys = fix_issue_16_primary_keys |
|
self.compute_sc_link = compute_sc_link |
|
self.compute_cv_link = compute_cv_link |
|
self.texts = collections.defaultdict(list) |
|
|
|
self.db_path = db_path |
|
if self.compute_cv_link: assert self.db_path is not None |
|
|
|
self.vocab_builder = vocab.VocabBuilder(min_freq, max_count) |
|
self.vocab_path = os.path.join(save_path, 'enc_vocab.json') |
|
self.vocab_word_freq_path = os.path.join(save_path, 'enc_word_freq.json') |
|
self.vocab = None |
|
self.counted_db_ids = set() |
|
self.preprocessed_schemas = {} |
|
|
|
|
|
def validate_item(self, item, section): |
|
return True, None |
|
|
|
def add_item(self, item, section, validation_info): |
|
preprocessed = self.preprocess_item(item, validation_info) |
|
self.texts[section].append(preprocessed) |
|
|
|
if section == 'train': |
|
if item.schema.db_id in self.counted_db_ids: |
|
to_count = preprocessed['question'] |
|
else: |
|
self.counted_db_ids.add(item.schema.db_id) |
|
to_count = itertools.chain( |
|
preprocessed['question'], |
|
*preprocessed['columns'], |
|
*preprocessed['tables']) |
|
|
|
for token in to_count: |
|
count_token = ( |
|
self.word_emb is None or |
|
self.count_tokens_in_word_emb_for_vocab or |
|
self.word_emb.lookup(token) is None) |
|
if count_token: |
|
self.vocab_builder.add_word(token) |
|
|
|
def clear_items(self): |
|
self.texts = collections.defaultdict(list) |
|
|
|
def preprocess_item(self, item, validation_info): |
|
question, question_for_copying = self._tokenize_for_copying(item.text, item.orig['question']) |
|
preproc_schema = self._preprocess_schema(item.schema) |
|
if self.compute_sc_link: |
|
assert preproc_schema.column_names[0][0].startswith("<type:") |
|
column_names_without_types = [col[1:] for col in preproc_schema.column_names] |
|
            sc_link = compute_schema_linking(
                question, column_names_without_types, preproc_schema.table_names)
|
else: |
|
sc_link = {"q_col_match": {}, "q_tab_match": {}} |
|
|
|
if self.compute_cv_link: |
|
cv_link = compute_cell_value_linking(question, item.schema, self.db_path) |
|
else: |
|
cv_link = {"num_date_match": {}, "cell_match": {}} |
|
|
|
return { |
|
'raw_question': item.orig['question'], |
|
'question': question, |
|
'question_for_copying': question_for_copying, |
|
'db_id': item.schema.db_id, |
|
'sc_link': sc_link, |
|
'cv_link': cv_link, |
|
'columns': preproc_schema.column_names, |
|
'tables': preproc_schema.table_names, |
|
'table_bounds': preproc_schema.table_bounds, |
|
'column_to_table': preproc_schema.column_to_table, |
|
'table_to_columns': preproc_schema.table_to_columns, |
|
'foreign_keys': preproc_schema.foreign_keys, |
|
'foreign_keys_tables': preproc_schema.foreign_keys_tables, |
|
'primary_keys': preproc_schema.primary_keys, |
|
} |
|
|
|
def _preprocess_schema(self, schema): |
|
if schema.db_id in self.preprocessed_schemas: |
|
return self.preprocessed_schemas[schema.db_id] |
|
result = preprocess_schema_uncached(schema, self._tokenize, |
|
self.include_table_name_in_column, self.fix_issue_16_primary_keys) |
|
self.preprocessed_schemas[schema.db_id] = result |
|
return result |
|
|
|
def _tokenize(self, presplit, unsplit): |
|
if self.word_emb: |
|
return self.word_emb.tokenize(unsplit) |
|
return presplit |
|
|
|
def _tokenize_for_copying(self, presplit, unsplit): |
|
if self.word_emb: |
|
return self.word_emb.tokenize_for_copying(unsplit) |
|
return presplit, presplit |
|
|
|
def save(self): |
|
os.makedirs(self.data_dir, exist_ok=True) |
|
self.vocab = self.vocab_builder.finish() |
|
print(f"{len(self.vocab)} words in vocab") |
|
self.vocab.save(self.vocab_path) |
|
self.vocab_builder.save(self.vocab_word_freq_path) |
|
|
|
for section, texts in self.texts.items(): |
|
with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f: |
|
for text in texts: |
|
f.write(json.dumps(text) + '\n') |
|
|
|
def load(self): |
|
self.vocab = vocab.Vocab.load(self.vocab_path) |
|
self.vocab_builder.load(self.vocab_word_freq_path) |
|
|
|
def dataset(self, section): |
|
return [ |
|
json.loads(line) |
|
for line in open(os.path.join(self.data_dir, section + '.jsonl'))] |
|
|
|
|
|
@registry.register('encoder', 'spiderv2') |
|
class SpiderEncoderV2(torch.nn.Module): |
|
|
|
batched = True |
|
Preproc = SpiderEncoderV2Preproc |
|
|
|
def __init__( |
|
self, |
|
device, |
|
preproc, |
|
word_emb_size=128, |
|
recurrent_size=256, |
|
dropout=0., |
|
question_encoder=('emb', 'bilstm'), |
|
column_encoder=('emb', 'bilstm'), |
|
table_encoder=('emb', 'bilstm'), |
|
update_config={}, |
|
include_in_memory=('question', 'column', 'table'), |
|
batch_encs_update=True, |
|
            top_k_learnable=0):
|
super().__init__() |
|
self._device = device |
|
self.preproc = preproc |
|
|
|
self.vocab = preproc.vocab |
|
self.word_emb_size = word_emb_size |
|
self.recurrent_size = recurrent_size |
|
assert self.recurrent_size % 2 == 0 |
|
word_freq = self.preproc.vocab_builder.word_freq |
|
top_k_words = set([_a[0] for _a in word_freq.most_common(top_k_learnable)]) |
|
self.learnable_words = top_k_words |
|
|
|
self.include_in_memory = set(include_in_memory) |
|
self.dropout = dropout |
|
|
|
self.question_encoder = self._build_modules(question_encoder) |
|
self.column_encoder = self._build_modules(column_encoder) |
|
self.table_encoder = self._build_modules(table_encoder) |
|
|
|
|
|
update_modules = { |
|
'relational_transformer': |
|
spider_enc_modules.RelationalTransformerUpdate, |
|
'none': |
|
spider_enc_modules.NoOpUpdate, |
|
} |
|
|
|
self.encs_update = registry.instantiate( |
|
update_modules[update_config['name']], |
|
update_config, |
|
unused_keys={"name"}, |
|
device=self._device, |
|
hidden_size=recurrent_size, |
|
) |
|
self.batch_encs_update = batch_encs_update |
|
|
|
def _build_modules(self, module_types): |
|
module_builder = { |
|
'emb': lambda: spider_enc_modules.LookupEmbeddings( |
|
self._device, |
|
self.vocab, |
|
self.preproc.word_emb, |
|
self.word_emb_size, |
|
self.learnable_words), |
|
'linear': lambda: spider_enc_modules.EmbLinear( |
|
input_size=self.word_emb_size, |
|
output_size=self.word_emb_size), |
|
'bilstm': lambda: spider_enc_modules.BiLSTM( |
|
input_size=self.word_emb_size, |
|
output_size=self.recurrent_size, |
|
dropout=self.dropout, |
|
summarize=False), |
|
'bilstm-native': lambda: spider_enc_modules.BiLSTM( |
|
input_size=self.word_emb_size, |
|
output_size=self.recurrent_size, |
|
dropout=self.dropout, |
|
summarize=False, |
|
use_native=True), |
|
'bilstm-summarize': lambda: spider_enc_modules.BiLSTM( |
|
input_size=self.word_emb_size, |
|
output_size=self.recurrent_size, |
|
dropout=self.dropout, |
|
summarize=True), |
|
'bilstm-native-summarize': lambda: spider_enc_modules.BiLSTM( |
|
input_size=self.word_emb_size, |
|
output_size=self.recurrent_size, |
|
dropout=self.dropout, |
|
summarize=True, |
|
use_native=True), |
|
} |
|
|
|
modules = [] |
|
for module_type in module_types: |
|
modules.append(module_builder[module_type]()) |
|
return torch.nn.Sequential(*modules) |
|
|
|
|
|
def forward_unbatched(self, desc): |
|
|
|
|
|
|
|
|
|
|
|
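        # Encode the question tokens with the question encoder stack
        # (by default embeddings followed by a BiLSTM).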
q_enc, (_, _) = self.question_encoder([desc['question']]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
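        # Encode all column name token sequences. c_boundaries holds cumulative
        # offsets delimiting each column's tokens inside c_enc, so
        # zip(c_boundaries, c_boundaries[1:]) yields one (left, right) span per column.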
c_enc, c_boundaries = self.column_encoder(desc['columns']) |
|
column_pointer_maps = { |
|
i: list(range(left, right)) |
|
for i, (left, right) in enumerate(zip(c_boundaries, c_boundaries[1:])) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
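        # Encode table names; each table's pointer entries cover its own span
        # (offset past the column encodings) plus the spans of its columns.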
t_enc, t_boundaries = self.table_encoder(desc['tables']) |
|
c_enc_length = c_enc.shape[0] |
|
table_pointer_maps = { |
|
i: [ |
|
idx |
|
for col in desc['table_to_columns'][str(i)] |
|
for idx in column_pointer_maps[col] |
|
] + list(range(left + c_enc_length, right + c_enc_length)) |
|
for i, (left, right) in enumerate(zip(t_boundaries, t_boundaries[1:])) |
|
} |
|
|
|
|
|
|
|
|
|
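        # Jointly update question/column/table encodings, e.g. with the
        # relation-aware transformer selected by update_config.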
q_enc_new, c_enc_new, t_enc_new = self.encs_update( |
|
desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries) |
|
|
|
memory = [] |
|
words_for_copying = [] |
|
if 'question' in self.include_in_memory: |
|
memory.append(q_enc_new) |
|
if 'question_for_copying' in desc: |
|
                assert q_enc_new.shape[1] == len(desc['question_for_copying'])
|
words_for_copying += desc['question_for_copying'] |
|
else: |
|
words_for_copying += [''] * q_enc_new.shape[1] |
|
if 'column' in self.include_in_memory: |
|
memory.append(c_enc_new) |
|
words_for_copying += [''] * c_enc_new.shape[1] |
|
if 'table' in self.include_in_memory: |
|
memory.append(t_enc_new) |
|
words_for_copying += [''] * t_enc_new.shape[1] |
|
memory = torch.cat(memory, dim=1) |
|
|
|
        return SpiderEncoderState(
            state=None,
            memory=memory,
            # Mirror the batched forward(): these fields are required by the
            # attrs class, so fill them in here as well.
            question_memory=q_enc_new,
            schema_memory=torch.cat((c_enc_new, t_enc_new), dim=1),
            words=words_for_copying,
            pointer_memories={
                'column': c_enc_new,
                'table': torch.cat((c_enc_new, t_enc_new), dim=1),
            },
            pointer_maps={
                'column': column_pointer_maps,
                'table': table_pointer_maps,
            },
            # This unbatched path does not compute alignment matrices.
            m2c_align_mat=None,
            m2t_align_mat=None,
        )
|
|
|
def forward(self, descs): |
|
|
|
|
|
|
|
|
|
|
|
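        # Batched question encoding: each question is wrapped in a singleton
        # list so every example contributes exactly one token sequence.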
qs = [[desc['question']] for desc in descs] |
|
q_enc, _ = self.question_encoder(qs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
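        # Encode the columns of every example in one batch; c_boundaries is a
        # per-example list of offsets delimiting each column's tokens.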
c_enc, c_boundaries = self.column_encoder([desc['columns'] for desc in descs]) |
|
|
|
column_pointer_maps = [ |
|
{ |
|
i: list(range(left, right)) |
|
for i, (left, right) in enumerate(zip(c_boundaries_for_item, c_boundaries_for_item[1:])) |
|
} |
|
for batch_idx, c_boundaries_for_item in enumerate(c_boundaries) |
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
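        # Encode the table names of every example, mirroring the column encoding.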
t_enc, t_boundaries = self.table_encoder([desc['tables'] for desc in descs]) |
|
|
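        # Note: unlike forward_unbatched, these pointer maps index only each
        # table's own span (no column indices folded in, no offset by the
        # column encoding length).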
table_pointer_maps = [ |
|
{ |
|
i: list(range(left, right)) |
|
for i, (left, right) in enumerate(zip(t_boundaries_for_item, t_boundaries_for_item[1:])) |
|
} |
|
for batch_idx, (desc, t_boundaries_for_item) in enumerate(zip(descs, t_boundaries)) |
|
] |
|
|
|
|
|
|
|
|
|
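        # Run the relational update either over the whole batch at once or
        # per example; only the per-example path also returns memory-to-column
        # and memory-to-table alignment matrices.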
if self.batch_encs_update: |
|
q_enc_new, c_enc_new, t_enc_new = self.encs_update( |
|
descs, q_enc, c_enc, c_boundaries, t_enc, t_boundaries) |
|
|
|
result = [] |
|
for batch_idx, desc in enumerate(descs): |
|
            if self.batch_encs_update:
                q_enc_new_item = q_enc_new.select(batch_idx).unsqueeze(0)
                c_enc_new_item = c_enc_new.select(batch_idx).unsqueeze(0)
                t_enc_new_item = t_enc_new.select(batch_idx).unsqueeze(0)
                # The batched update does not produce alignment matrices, so
                # use placeholders for the SpiderEncoderState fields below.
                align_mat_item = (None, None)
|
else: |
|
q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \ |
|
self.encs_update.forward_unbatched( |
|
desc, |
|
q_enc.select(batch_idx).unsqueeze(1), |
|
c_enc.select(batch_idx).unsqueeze(1), |
|
c_boundaries[batch_idx], |
|
t_enc.select(batch_idx).unsqueeze(1), |
|
t_boundaries[batch_idx]) |
|
|
|
memory = [] |
|
words_for_copying = [] |
|
if 'question' in self.include_in_memory: |
|
memory.append(q_enc_new_item) |
|
if 'question_for_copying' in desc: |
|
assert q_enc_new_item.shape[1] == len(desc['question_for_copying']) |
|
words_for_copying += desc['question_for_copying'] |
|
else: |
|
words_for_copying += [''] * q_enc_new_item.shape[1] |
|
if 'column' in self.include_in_memory: |
|
memory.append(c_enc_new_item) |
|
words_for_copying += [''] * c_enc_new_item.shape[1] |
|
if 'table' in self.include_in_memory: |
|
memory.append(t_enc_new_item) |
|
words_for_copying += [''] * t_enc_new_item.shape[1] |
|
memory = torch.cat(memory, dim=1) |
|
|
|
result.append(SpiderEncoderState( |
|
state=None, |
|
memory=memory, |
|
question_memory=q_enc_new_item, |
|
schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1), |
|
|
|
words=words_for_copying, |
|
pointer_memories={ |
|
'column': c_enc_new_item, |
|
'table': torch.cat((c_enc_new_item, t_enc_new_item), dim=1), |
|
}, |
|
pointer_maps={ |
|
'column': column_pointer_maps[batch_idx], |
|
'table': table_pointer_maps[batch_idx], |
|
}, |
|
m2c_align_mat=align_mat_item[0], |
|
m2t_align_mat=align_mat_item[1], |
|
)) |
|
return result |
|
|
|
|
|
class Bertokens: |
|
def __init__(self, pieces): |
|
self.pieces = pieces |
|
|
|
self.normalized_pieces = None |
|
self.idx_map = None |
|
|
|
self.normalize_toks() |
|
|
|
def normalize_toks(self): |
|
""" |
|
If the token is not a word piece, then find its lemma |
|
If it is, combine pieces into a word, and then find its lemma |
|
E.g., a ##b ##c will be normalized as "abc", "", "" |
|
NOTE: this is only used for schema linking |
|
""" |
|
self.startidx2pieces = dict() |
|
self.pieces2startidx = dict() |
|
cache_start = None |
|
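        # Iterate with an empty sentinel appended so that a trailing "##" run
        # is flushed into startidx2pieces on the final iteration.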
for i, piece in enumerate(self.pieces + [""]): |
|
if piece.startswith("##"): |
|
if cache_start is None: |
|
cache_start = i - 1 |
|
|
|
self.pieces2startidx[i] = cache_start |
|
self.pieces2startidx[i-1] = cache_start |
|
else: |
|
if cache_start is not None: |
|
self.startidx2pieces[cache_start] = i |
|
cache_start = None |
|
assert cache_start is None |
|
|
|
|
|
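        # Merge each run of word pieces (e.g. "a", "##b", "##c") back into a
        # single word, keyed by the index of its first piece.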
combined_word = {} |
|
for start, end in self.startidx2pieces.items(): |
|
assert end - start + 1 < 10 |
|
pieces = [self.pieces[start]] + [self.pieces[_id].strip("##") for _id in range(start+1, end)] |
|
word = "".join(pieces) |
|
combined_word[start] = word |
|
|
|
|
|
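        # Rebuild the token list without word-piece fragments and remember,
        # for every new token, the index of the original piece it came from.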
idx_map = {} |
|
new_toks = [] |
|
for i, piece in enumerate(self.pieces): |
|
if i in combined_word: |
|
idx_map[len(new_toks)] = i |
|
new_toks.append(combined_word[i]) |
|
elif i in self.pieces2startidx: |
|
|
|
pass |
|
else: |
|
idx_map[len(new_toks)] = i |
|
new_toks.append(piece) |
|
self.idx_map = idx_map |
|
|
|
|
|
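        # Lemmatize every reconstructed token with CoreNLP; the lowercased
        # lemmas are what compute_schema_linking matches against.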
normalized_toks = [] |
|
for i, tok in enumerate(new_toks): |
|
ann = corenlp.annotate(tok, annotators = ['tokenize', 'ssplit', 'lemma']) |
|
lemmas = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token] |
|
lemma_word = " ".join(lemmas) |
|
normalized_toks.append(lemma_word) |
|
|
|
self.normalized_pieces = normalized_toks |
|
|
|
def bert_schema_linking(self, columns, tables): |
|
        question_tokens = self.normalized_pieces
|
column_tokens = [c.normalized_pieces for c in columns] |
|
table_tokens = [t.normalized_pieces for t in tables] |
|
sc_link = compute_schema_linking(question_tokens, column_tokens, table_tokens) |
|
|
|
new_sc_link = {} |
|
for m_type in sc_link: |
|
_match = {} |
|
for ij_str in sc_link[m_type]: |
|
q_id_str, col_tab_id_str = ij_str.split(",") |
|
q_id, col_tab_id = int(q_id_str), int(col_tab_id_str) |
|
real_q_id = self.idx_map[q_id] |
|
_match[f"{real_q_id},{col_tab_id}"] = sc_link[m_type][ij_str] |
|
|
|
new_sc_link[m_type] = _match |
|
return new_sc_link |
|
|
|
|
|
class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc): |
|
|
|
def __init__( |
|
self, |
|
save_path, |
|
db_path, |
|
fix_issue_16_primary_keys=False, |
|
            include_table_name_in_column=False,
|
bert_version = "bert-base-uncased", |
|
compute_sc_link=True, |
|
compute_cv_link=False): |
|
|
|
self.data_dir = os.path.join(save_path, 'enc') |
|
self.db_path = db_path |
|
self.texts = collections.defaultdict(list) |
|
self.fix_issue_16_primary_keys = fix_issue_16_primary_keys |
|
self.include_table_name_in_column = include_table_name_in_column |
|
self.compute_sc_link = compute_sc_link |
|
self.compute_cv_link = compute_cv_link |
|
|
|
self.counted_db_ids = set() |
|
self.preprocessed_schemas = {} |
|
|
|
self.tokenizer = BertTokenizer.from_pretrained(bert_version) |
|
self.tokenizer.add_special_tokens({"additional_special_tokens": ["<col>"]}) |
|
|
|
column_types = ["text", "number", "time", "boolean", "others"] |
|
self.tokenizer.add_tokens([f"<type: {t}>" for t in column_types]) |
|
|
|
def _tokenize(self, presplit, unsplit): |
|
if self.tokenizer: |
|
toks = self.tokenizer.tokenize(unsplit) |
|
return toks |
|
return presplit |
|
|
|
|
|
def add_item(self, item, section, validation_info): |
|
preprocessed = self.preprocess_item(item, validation_info) |
|
self.texts[section].append(preprocessed) |
|
|
|
def preprocess_item(self, item, validation_info): |
|
question = self._tokenize(item.text, item.orig['question']) |
|
preproc_schema = self._preprocess_schema(item.schema) |
|
if self.compute_sc_link: |
|
question_bert_tokens = Bertokens(item.text) |
|
sc_link = question_bert_tokens.bert_schema_linking( |
|
preproc_schema.normalized_column_names, |
|
preproc_schema.normalized_table_names |
|
) |
|
else: |
|
sc_link = {"q_col_match": {}, "q_tab_match": {}} |
|
|
|
if self.compute_cv_link: |
|
question_bert_tokens = Bertokens(question) |
|
cv_link = compute_cell_value_linking( |
|
question_bert_tokens.normalized_pieces, item.schema, self.db_path) |
|
else: |
|
cv_link = {"num_date_match": {}, "cell_match": {}} |
|
|
|
return { |
|
'raw_question': item.orig['question'], |
|
'question': question, |
|
'db_id': item.schema.db_id, |
|
'sc_link': sc_link, |
|
'cv_link': cv_link, |
|
'columns': preproc_schema.column_names, |
|
'tables': preproc_schema.table_names, |
|
'table_bounds': preproc_schema.table_bounds, |
|
'column_to_table': preproc_schema.column_to_table, |
|
'table_to_columns': preproc_schema.table_to_columns, |
|
'foreign_keys': preproc_schema.foreign_keys, |
|
'foreign_keys_tables': preproc_schema.foreign_keys_tables, |
|
'primary_keys': preproc_schema.primary_keys, |
|
} |
|
|
|
def validate_item(self, item, section): |
|
question = self._tokenize(item.text, item.orig['question']) |
|
preproc_schema = self._preprocess_schema(item.schema) |
|
|
|
num_words = len(question) + 2 + \ |
|
sum(len(c) + 1 for c in preproc_schema.column_names) + \ |
|
sum(len(t) + 1 for t in preproc_schema.table_names) |
|
if num_words > 512: |
|
return False, None |
|
else: |
|
return True, None |
|
|
|
def _preprocess_schema(self, schema): |
|
if schema.db_id in self.preprocessed_schemas: |
|
return self.preprocessed_schemas[schema.db_id] |
|
result = preprocess_schema_uncached(schema, self._tokenize, |
|
self.include_table_name_in_column, |
|
self.fix_issue_16_primary_keys, bert=True) |
|
self.preprocessed_schemas[schema.db_id] = result |
|
return result |
|
|
|
|
|
def save(self): |
|
os.makedirs(self.data_dir, exist_ok=True) |
|
self.tokenizer.save_pretrained(self.data_dir) |
|
|
|
for section, texts in self.texts.items(): |
|
with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f: |
|
for text in texts: |
|
f.write(json.dumps(text) + '\n') |
|
|
|
def load(self): |
|
self.tokenizer = BertTokenizer.from_pretrained(self.data_dir) |
|
|
|
|
|
|
|
@registry.register('encoder', 'spider-bert') |
|
class SpiderEncoderBert(torch.nn.Module): |
|
|
|
Preproc = SpiderEncoderBertPreproc |
|
batched = True |
|
|
|
def __init__( |
|
self, |
|
device, |
|
preproc, |
|
update_config={}, |
|
bert_token_type=False, |
|
bert_version="bert-base-uncased", |
|
summarize_header="first", |
|
use_column_type=True, |
|
include_in_memory=('question', 'column', 'table')): |
|
super().__init__() |
|
self._device = device |
|
self.preproc = preproc |
|
self.bert_token_type = bert_token_type |
|
self.base_enc_hidden_size = 1024 if bert_version == "bert-large-uncased-whole-word-masking" else 768 |
|
|
|
assert summarize_header in ["first", "avg"] |
|
self.summarize_header = summarize_header |
|
self.enc_hidden_size = self.base_enc_hidden_size |
|
self.use_column_type = use_column_type |
|
|
|
self.include_in_memory = set(include_in_memory) |
|
update_modules = { |
|
'relational_transformer': |
|
spider_enc_modules.RelationalTransformerUpdate, |
|
'none': |
|
spider_enc_modules.NoOpUpdate, |
|
} |
|
|
|
self.encs_update = registry.instantiate( |
|
update_modules[update_config['name']], |
|
update_config, |
|
unused_keys={"name"}, |
|
device=self._device, |
|
hidden_size=self.enc_hidden_size, |
|
            sc_link=True,
|
) |
|
|
|
self.bert_model = BertModel.from_pretrained(bert_version) |
|
self.tokenizer = self.preproc.tokenizer |
|
self.bert_model.resize_token_embeddings(len(self.tokenizer)) |
|
|
|
|
|
def forward(self, descs): |
|
batch_token_lists = [] |
|
batch_id_to_retrieve_question = [] |
|
batch_id_to_retrieve_column = [] |
|
batch_id_to_retrieve_table = [] |
|
if self.summarize_header == "avg": |
|
batch_id_to_retrieve_column_2 = [] |
|
batch_id_to_retrieve_table_2 = [] |
|
long_seq_set = set() |
|
batch_id_map = {} |
|
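        # For each example, build a single BERT input of the form
        # [CLS] question [SEP] col_1 [SEP] ... col_n [SEP] tab_1 [SEP] ... tab_m [SEP]
        # and record the positions from which question/column/table
        # representations are gathered after the BERT pass.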
for batch_idx, desc in enumerate(descs): |
|
qs = self.pad_single_sentence_for_bert(desc['question'], cls=True) |
|
if self.use_column_type: |
|
cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']] |
|
else: |
|
cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']] |
|
tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']] |
|
|
|
token_list = qs + [c for col in cols for c in col] + \ |
|
[t for tab in tabs for t in tab] |
|
assert self.check_bert_seq(token_list) |
|
if len(token_list) > 512: |
|
long_seq_set.add(batch_idx) |
|
continue |
|
|
|
q_b = len(qs) |
|
col_b = q_b + sum(len(c) for c in cols) |
|
|
|
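            # q_b / col_b mark where the column and table blocks start.
            # question_indexes drops [CLS] and the trailing [SEP]; the column
            # and table indexes below point at the first token of each block
            # ("first" summarization), while the *_2 variants point at the
            # last non-[SEP] token so that "avg" can average the two ends.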
question_indexes = list(range(q_b))[1:-1] |
|
|
|
column_indexes = \ |
|
                np.cumsum([q_b] + [len(token_list) for token_list in cols[:-1]]).tolist()
|
table_indexes = \ |
|
np.cumsum([col_b] + [len(token_list) for token_list in tabs[:-1]]).tolist() |
|
if self.summarize_header == "avg": |
|
column_indexes_2 = \ |
|
np.cumsum([q_b - 2] + [len(token_list) for token_list in cols]).tolist()[1:] |
|
table_indexes_2 = \ |
|
np.cumsum([col_b - 2] + [len(token_list) for token_list in tabs]).tolist()[1:] |
|
|
|
indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list) |
|
batch_token_lists.append(indexed_token_list) |
|
|
|
question_rep_ids = torch.LongTensor(question_indexes).to(self._device) |
|
batch_id_to_retrieve_question.append(question_rep_ids) |
|
column_rep_ids = torch.LongTensor(column_indexes).to(self._device) |
|
batch_id_to_retrieve_column.append(column_rep_ids) |
|
table_rep_ids = torch.LongTensor(table_indexes).to(self._device) |
|
batch_id_to_retrieve_table.append(table_rep_ids) |
|
if self.summarize_header == "avg": |
|
assert(all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))) |
|
column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device) |
|
batch_id_to_retrieve_column_2.append(column_rep_ids_2) |
|
assert(all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))) |
|
table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device) |
|
batch_id_to_retrieve_table_2.append(table_rep_ids_2) |
|
|
|
batch_id_map[batch_idx] = len(batch_id_map) |
|
|
|
padded_token_lists, att_mask_lists, tok_type_lists = self.pad_sequence_for_bert_batch(batch_token_lists) |
|
tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device) |
|
att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device) |
|
|
|
if self.bert_token_type: |
|
tok_type_tensor = torch.LongTensor(tok_type_lists).to(self._device) |
|
bert_output = self.bert_model(tokens_tensor, |
|
attention_mask=att_masks_tensor, token_type_ids=tok_type_tensor)[0] |
|
else: |
|
bert_output = self.bert_model(tokens_tensor, |
|
attention_mask=att_masks_tensor)[0] |
|
|
|
enc_output = bert_output |
|
|
|
column_pointer_maps = [ |
|
{ |
|
i: [i] |
|
for i in range(len(desc['columns'])) |
|
} |
|
for desc in descs |
|
] |
|
table_pointer_maps = [ |
|
{ |
|
i: [i] |
|
for i in range(len(desc['tables'])) |
|
} |
|
for desc in descs |
|
] |
|
|
|
assert len(long_seq_set) == 0 |
|
|
|
result = [] |
|
for batch_idx, desc in enumerate(descs): |
|
c_boundary = list(range(len(desc["columns"]) + 1)) |
|
t_boundary = list(range(len(desc["tables"]) + 1)) |
|
|
|
if batch_idx in long_seq_set: |
|
q_enc, col_enc, tab_enc = self.encoder_long_seq(desc) |
|
else: |
|
bert_batch_idx = batch_id_map[batch_idx] |
|
q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]] |
|
col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]] |
|
tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]] |
|
|
|
if self.summarize_header == "avg": |
|
col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]] |
|
tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]] |
|
|
|
col_enc = (col_enc + col_enc_2) / 2.0 |
|
tab_enc = (tab_enc + tab_enc_2) / 2.0 |
|
|
|
assert q_enc.size()[0] == len(desc["question"]) |
|
assert col_enc.size()[0] == c_boundary[-1] |
|
assert tab_enc.size()[0] == t_boundary[-1] |
|
|
|
q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \ |
|
self.encs_update.forward_unbatched( |
|
desc, |
|
q_enc.unsqueeze(1), |
|
col_enc.unsqueeze(1), |
|
c_boundary, |
|
tab_enc.unsqueeze(1), |
|
t_boundary) |
|
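            # NOTE: debug dump of the inputs to encs_update; this writes
            # descs_{batch_idx}.pkl on every forward pass.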
import pickle |
|
pickle.dump({"desc": desc, "q_enc": q_enc, "col_enc": col_enc, "c_boundary": c_boundary, "tab_enc": tab_enc, |
|
"t_boundary": t_boundary}, open("descs_{}.pkl".format(batch_idx), "wb")) |
|
|
|
|
|
memory = [] |
|
if 'question' in self.include_in_memory: |
|
memory.append(q_enc_new_item) |
|
if 'column' in self.include_in_memory: |
|
memory.append(c_enc_new_item) |
|
if 'table' in self.include_in_memory: |
|
memory.append(t_enc_new_item) |
|
memory = torch.cat(memory, dim=1) |
|
|
|
result.append(SpiderEncoderState( |
|
state=None, |
|
memory=memory, |
|
question_memory=q_enc_new_item, |
|
schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1), |
|
|
|
words=desc['question'], |
|
pointer_memories={ |
|
'column': c_enc_new_item, |
|
'table': t_enc_new_item, |
|
}, |
|
pointer_maps={ |
|
'column': column_pointer_maps[batch_idx], |
|
'table': table_pointer_maps[batch_idx], |
|
}, |
|
m2c_align_mat=align_mat_item[0], |
|
m2t_align_mat=align_mat_item[1], |
|
)) |
|
return result |
|
|
|
    # Deprecated: long-sequence fallback; unused because over-length inputs
    # are filtered out during preprocessing.
|
def encoder_long_seq(self, desc): |
|
""" |
|
Since bert cannot handle sequence longer than 512, each column/table is encoded individually |
|
The representation of a column/table is the vector of the first token [CLS] |
|
""" |
|
qs = self.pad_single_sentence_for_bert(desc['question'], cls=True) |
|
cols = [self.pad_single_sentence_for_bert(c, cls=True) for c in desc['columns']] |
|
tabs = [self.pad_single_sentence_for_bert(t, cls=True) for t in desc['tables']] |
|
|
|
enc_q = self._bert_encode(qs) |
|
enc_col = self._bert_encode(cols) |
|
enc_tab = self._bert_encode(tabs) |
|
return enc_q, enc_col, enc_tab |
|
|
|
    # Deprecated: only used by the long-sequence fallback above.
|
def _bert_encode(self, toks): |
|
if not isinstance(toks[0], list): |
|
indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks) |
|
tokens_tensor = torch.tensor([indexed_tokens]).to(self._device) |
|
outputs = self.bert_model(tokens_tensor) |
|
return outputs[0][0, 1:-1] |
|
else: |
|
max_len = max([len(it) for it in toks]) |
|
tok_ids = [] |
|
for item_toks in toks: |
|
item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks)) |
|
indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks) |
|
tok_ids.append(indexed_tokens) |
|
|
|
tokens_tensor = torch.tensor(tok_ids).to(self._device) |
|
outputs = self.bert_model(tokens_tensor) |
|
return outputs[0][:,0,:] |
|
|
|
def check_bert_seq(self, toks): |
|
if toks[0] == self.tokenizer.cls_token and toks[-1] == self.tokenizer.sep_token: |
|
return True |
|
else: |
|
return False |
|
|
|
def pad_single_sentence_for_bert(self, toks, cls=True): |
|
if cls: |
|
return [self.tokenizer.cls_token] + toks + [self.tokenizer.sep_token] |
|
else: |
|
return toks + [self.tokenizer.sep_token] |
|
|
|
def pad_sequence_for_bert_batch(self, tokens_lists): |
|
pad_id = self.tokenizer.pad_token_id |
|
max_len = max([len(it) for it in tokens_lists]) |
|
assert max_len <= 512 |
|
toks_ids = [] |
|
att_masks = [] |
|
tok_type_lists = [] |
|
for item_toks in tokens_lists: |
|
padded_item_toks = item_toks + [pad_id] * (max_len - len(item_toks)) |
|
toks_ids.append(padded_item_toks) |
|
|
|
_att_mask = [1] * len(item_toks) + [0] * (max_len - len(item_toks)) |
|
att_masks.append(_att_mask) |
|
|
|
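            # Token type ids: everything up to and including the first [SEP]
            # (the question) is segment 0, the remaining schema tokens are segment 1.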
first_sep_id = padded_item_toks.index(self.tokenizer.sep_token_id) |
|
assert first_sep_id > 0 |
|
_tok_type_list = [0] * (first_sep_id + 1) + [1] * (max_len - first_sep_id - 1) |
|
tok_type_lists.append(_tok_type_list) |
|
return toks_ids, att_masks, tok_type_lists |
|
|
|
|
|
""" |
|
############################### |
|
BART models |
|
############################### |
|
""" |
|
|
|
class BartTokens: |
|
def __init__(self, text, tokenizer): |
|
self.text = text |
|
|
|
self.tokenizer = tokenizer |
|
self.normalized_pieces = None |
|
self.idx_map = None |
|
|
|
self.normalize_toks() |
|
|
|
def normalize_toks(self): |
|
tokens = nltk.word_tokenize(self.text.replace("'", " ' ").replace('"', ' " ')) |
|
self.idx_map = {} |
|
|
|
toks = [] |
|
for i, tok in enumerate(tokens): |
|
self.idx_map[i] = len(toks) |
|
toks.extend(self.tokenizer.tokenize(tok, add_prefix_space=True)) |
|
|
|
normalized_toks = [] |
|
for i, tok in enumerate(tokens): |
|
ann = corenlp.annotate(tok, annotators=["tokenize", "ssplit", "lemma"]) |
|
lemmas = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token] |
|
lemma_word = " ".join(lemmas) |
|
normalized_toks.append(lemma_word) |
|
self.normalized_pieces = normalized_toks |
|
|
|
def bart_schema_linking(self, columns, tables): |
|
question_tokens = self.normalized_pieces |
|
column_tokens = [c.normalized_pieces for c in columns] |
|
table_tokens = [t.normalized_pieces for t in tables] |
|
sc_link = compute_schema_linking(question_tokens, column_tokens, table_tokens) |
|
|
|
new_sc_link = {} |
|
for m_type in sc_link: |
|
_match = {} |
|
for ij_str in sc_link[m_type]: |
|
q_id_str, col_tab_id_str = ij_str.split(",") |
|
q_id, col_tab_id = int(q_id_str), int(col_tab_id_str) |
|
real_q_id = self.idx_map[q_id] |
|
_match[f"{real_q_id},{col_tab_id}"] = sc_link[m_type][ij_str] |
|
new_sc_link[m_type] = _match |
|
return new_sc_link |
|
|
|
def bart_cv_linking(self, schema, db_path): |
|
question_tokens = self.normalized_pieces |
|
cv_link = compute_cell_value_linking(question_tokens, schema, db_path) |
|
|
|
new_cv_link = {} |
|
for m_type in cv_link: |
|
if m_type != "normalized_token": |
|
_match = {} |
|
for ij_str in cv_link[m_type]: |
|
q_id_str, col_tab_id_str = ij_str.split(",") |
|
q_id, col_tab_id = int(q_id_str), int(col_tab_id_str) |
|
real_q_id = self.idx_map[q_id] |
|
_match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str] |
|
|
|
new_cv_link[m_type] = _match |
|
else: |
|
new_cv_link[m_type] = cv_link[m_type] |
|
return new_cv_link |
|
|
|
|
|
|
|
|
|
def preprocess_schema_uncached_bart(schema, |
|
tokenizer, |
|
tokenize_func, |
|
include_table_name_in_column, |
|
fix_issue_16_primary_keys, |
|
bart=False): |
|
"""If it's bert, we also cache the normalized version of |
|
question/column/table for schema linking""" |
|
r = PreprocessedSchema() |
|
|
|
if bart: assert not include_table_name_in_column |
|
|
|
last_table_id = None |
|
for i, column in enumerate(schema.columns): |
|
col_toks = tokenize_func( |
|
column.name, column.unsplit_name) |
|
|
|
|
|
type_tok = '<type: {}>'.format(column.type) |
|
if bart: |
|
|
|
column_name = col_toks + [type_tok] |
|
r.normalized_column_names.append(BartTokens(column.unsplit_name, tokenizer)) |
|
else: |
|
column_name = [type_tok] + col_toks |
|
|
|
if include_table_name_in_column: |
|
if column.table is None: |
|
table_name = ['<any-table>'] |
|
else: |
|
table_name = tokenize_func( |
|
column.table.name, column.table.unsplit_name) |
|
column_name += ['<table-sep>'] + table_name |
|
r.column_names.append(column_name) |
|
|
|
table_id = None if column.table is None else column.table.id |
|
r.column_to_table[str(i)] = table_id |
|
if table_id is not None: |
|
columns = r.table_to_columns.setdefault(str(table_id), []) |
|
columns.append(i) |
|
if last_table_id != table_id: |
|
r.table_bounds.append(i) |
|
last_table_id = table_id |
|
|
|
if column.foreign_key_for is not None: |
|
r.foreign_keys[str(column.id)] = column.foreign_key_for.id |
|
r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id) |
|
|
|
r.table_bounds.append(len(schema.columns)) |
|
assert len(r.table_bounds) == len(schema.tables) + 1 |
|
|
|
for i, table in enumerate(schema.tables): |
|
table_toks = tokenize_func( |
|
table.name, table.unsplit_name) |
|
r.table_names.append(table_toks) |
|
if bart: |
|
r.normalized_table_names.append(BartTokens(table.unsplit_name, tokenizer)) |
|
last_table = schema.tables[-1] |
|
|
|
r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables) |
|
r.primary_keys = [ |
|
column.id |
|
for table in schema.tables |
|
for column in table.primary_keys |
|
] if fix_issue_16_primary_keys else [ |
|
column.id |
|
for column in last_table.primary_keys |
|
for table in schema.tables |
|
] |
|
|
|
return r |
|
|
|
import nltk |
|
class SpiderEncoderBartPreproc(SpiderEncoderV2Preproc): |
|
|
|
def __init__( |
|
self, |
|
save_path, |
|
db_path, |
|
fix_issue_16_primary_keys=False, |
|
include_table_name_in_column=False, |
|
            bart_version="bart-large",
|
compute_sc_link=True, |
|
compute_cv_link=False): |
|
self.data_dir = os.path.join(save_path, 'enc') |
|
self.db_path = db_path |
|
self.texts = collections.defaultdict(list) |
|
self.fix_issue_16_primary_keys = fix_issue_16_primary_keys |
|
self.include_table_name_in_column = include_table_name_in_column |
|
self.compute_sc_link = compute_sc_link |
|
self.compute_cv_link = compute_cv_link |
|
|
|
self.counted_db_ids = set() |
|
self.preprocessed_schemas = {} |
|
|
|
self.tokenizer = BartTokenizer.from_pretrained(bart_version) |
|
|
|
column_types = ["text", "number", "time", "boolean", "others"] |
|
self.tokenizer.add_tokens([f"<type: {t}>" for t in column_types]) |
|
|
|
def _tokenize(self, presplit, unsplit): |
|
|
|
|
|
tokens = nltk.word_tokenize(unsplit.replace("'", " ' ").replace('"', ' " ')) |
|
toks = [] |
|
for token in tokens: |
|
toks.extend(self.tokenizer.tokenize(token, add_prefix_space=True)) |
|
return toks |
|
|
|
def add_item(self, item, section, validation_info): |
|
preprocessed = self.preprocess_item(item, validation_info) |
|
self.texts[section].append(preprocessed) |
|
|
|
def preprocess_item(self, item, validation_info): |
|
|
|
|
|
question = self._tokenize(item.text, item.orig['question']) |
|
preproc_schema = self._preprocess_schema(item.schema) |
|
question_bart_tokens = BartTokens(item.orig['question'], self.tokenizer) |
|
if self.compute_sc_link: |
|
|
|
sc_link = question_bart_tokens.bart_schema_linking( |
|
preproc_schema.normalized_column_names, |
|
preproc_schema.normalized_table_names |
|
) |
|
else: |
|
sc_link = {"q_col_match": {}, "q_tab_match": {}} |
|
|
|
if self.compute_cv_link: |
|
cv_link = question_bart_tokens.bart_cv_linking( |
|
item.schema, self.db_path) |
|
else: |
|
cv_link = {"num_date_match": {}, "cell_match": {}} |
|
|
|
return { |
|
'raw_question': item.orig['question'], |
|
'question': question, |
|
'db_id': item.schema.db_id, |
|
'sc_link': sc_link, |
|
'cv_link': cv_link, |
|
'columns': preproc_schema.column_names, |
|
'tables': preproc_schema.table_names, |
|
'table_bounds': preproc_schema.table_bounds, |
|
'column_to_table': preproc_schema.column_to_table, |
|
'table_to_columns': preproc_schema.table_to_columns, |
|
'foreign_keys': preproc_schema.foreign_keys, |
|
'foreign_keys_tables': preproc_schema.foreign_keys_tables, |
|
'primary_keys': preproc_schema.primary_keys, |
|
} |
|
|
|
def validate_item(self, item, section): |
|
question = self._tokenize(item.text, item.orig['question']) |
|
preproc_schema = self._preprocess_schema(item.schema) |
|
|
|
num_words = len(question) + 2 + \ |
|
sum(len(c) + 1 for c in preproc_schema.column_names) + \ |
|
sum(len(t) + 1 for t in preproc_schema.table_names) |
|
if num_words > 512: |
|
return False, None |
|
else: |
|
return True, None |
|
|
|
def _preprocess_schema(self, schema): |
|
if schema.db_id in self.preprocessed_schemas: |
|
return self.preprocessed_schemas[schema.db_id] |
|
result = preprocess_schema_uncached_bart(schema, self.tokenizer, self._tokenize, |
|
self.include_table_name_in_column, |
|
self.fix_issue_16_primary_keys, bart=True) |
|
self.preprocessed_schemas[schema.db_id] = result |
|
return result |
|
|
|
def save(self): |
|
os.makedirs(self.data_dir, exist_ok=True) |
|
self.tokenizer.save_pretrained(self.data_dir) |
|
|
|
for section, texts in self.texts.items(): |
|
with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f: |
|
for text in texts: |
|
f.write(json.dumps(text) + '\n') |
|
|
|
def load(self): |
|
self.tokenizer = BartTokenizer.from_pretrained(self.data_dir) |
|
|
|
|
|
@registry.register('encoder', 'spider-bart') |
|
class SpiderEncoderBart(torch.nn.Module): |
|
Preproc = SpiderEncoderBartPreproc |
|
batched = True |
|
|
|
def __init__( |
|
self, |
|
device, |
|
preproc, |
|
update_config={}, |
|
bart_version="facebook/bart-large", |
|
summarize_header="first", |
|
use_column_type=True, |
|
include_in_memory=('question', 'column', 'table')): |
|
super().__init__() |
|
self._device = device |
|
self.preproc = preproc |
|
self.base_enc_hidden_size = 1024 |
|
|
|
assert summarize_header in ["first", "avg"] |
|
self.summarize_header = summarize_header |
|
self.enc_hidden_size = self.base_enc_hidden_size |
|
self.use_column_type = use_column_type |
|
|
|
self.include_in_memory = set(include_in_memory) |
|
update_modules = { |
|
'relational_transformer': |
|
spider_enc_modules.RelationalTransformerUpdate, |
|
'none': |
|
spider_enc_modules.NoOpUpdate, |
|
} |
|
|
|
self.encs_update = registry.instantiate( |
|
update_modules[update_config['name']], |
|
update_config, |
|
unused_keys={"name"}, |
|
device=self._device, |
|
hidden_size=self.enc_hidden_size, |
|
sc_link=True, |
|
) |
|
|
|
self.bert_model = BartModel.from_pretrained(bart_version) |
|
print(next(self.bert_model.encoder.parameters())) |
|
|
|
def replace_model_with_pretrained(model, path, prefix): |
|
restore_state_dict = torch.load( |
|
path, map_location=lambda storage, location: storage) |
|
keep_keys = [] |
|
for key in restore_state_dict.keys(): |
|
if key.startswith(prefix): |
|
keep_keys.append(key) |
|
loaded_dict = {k.replace(prefix, ""): restore_state_dict[k] for k in keep_keys} |
|
model.load_state_dict(loaded_dict) |
|
print("Updated the model with {}".format(path)) |
|
|
|
|
|
self.tokenizer = self.preproc.tokenizer |
|
self.bert_model.resize_token_embeddings(50266) |
|
|
|
replace_model_with_pretrained(self.bert_model.encoder, os.path.join( |
|
"./pretrained_checkpoint", |
|
"pytorch_model.bin"), "bert.model.encoder.") |
|
self.bert_model.resize_token_embeddings(len(self.tokenizer)) |
|
self.bert_model = self.bert_model.encoder |
|
self.bert_model.decoder = None |
|
|
|
print(next(self.bert_model.parameters())) |
|
|
|
def forward(self, descs): |
|
batch_token_lists = [] |
|
batch_id_to_retrieve_question = [] |
|
batch_id_to_retrieve_column = [] |
|
batch_id_to_retrieve_table = [] |
|
if self.summarize_header == "avg": |
|
batch_id_to_retrieve_column_2 = [] |
|
batch_id_to_retrieve_table_2 = [] |
|
long_seq_set = set() |
|
batch_id_map = {} |
|
for batch_idx, desc in enumerate(descs): |
|
qs = self.pad_single_sentence_for_bert(desc['question'], cls=True) |
|
if self.use_column_type: |
|
cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']] |
|
else: |
|
cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']] |
|
tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']] |
|
|
|
token_list = qs + [c for col in cols for c in col] + \ |
|
[t for tab in tabs for t in tab] |
|
assert self.check_bert_seq(token_list) |
|
if len(token_list) > 512: |
|
long_seq_set.add(batch_idx) |
|
continue |
|
|
|
q_b = len(qs) |
|
col_b = q_b + sum(len(c) for c in cols) |
|
|
|
question_indexes = list(range(q_b))[1:-1] |
|
|
|
column_indexes = \ |
|
np.cumsum([q_b] + [len(token_list) for token_list in cols[:-1]]).tolist() |
|
table_indexes = \ |
|
np.cumsum([col_b] + [len(token_list) for token_list in tabs[:-1]]).tolist() |
|
if self.summarize_header == "avg": |
|
column_indexes_2 = \ |
|
np.cumsum([q_b - 2] + [len(token_list) for token_list in cols]).tolist()[1:] |
|
table_indexes_2 = \ |
|
np.cumsum([col_b - 2] + [len(token_list) for token_list in tabs]).tolist()[1:] |
|
|
|
indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list) |
|
batch_token_lists.append(indexed_token_list) |
|
|
|
question_rep_ids = torch.LongTensor(question_indexes).to(self._device) |
|
batch_id_to_retrieve_question.append(question_rep_ids) |
|
column_rep_ids = torch.LongTensor(column_indexes).to(self._device) |
|
batch_id_to_retrieve_column.append(column_rep_ids) |
|
table_rep_ids = torch.LongTensor(table_indexes).to(self._device) |
|
batch_id_to_retrieve_table.append(table_rep_ids) |
|
if self.summarize_header == "avg": |
|
assert (all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))) |
|
column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device) |
|
batch_id_to_retrieve_column_2.append(column_rep_ids_2) |
|
assert (all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))) |
|
table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device) |
|
batch_id_to_retrieve_table_2.append(table_rep_ids_2) |
|
|
|
batch_id_map[batch_idx] = len(batch_id_map) |
|
|
|
padded_token_lists, att_mask_lists, tok_type_lists = self.pad_sequence_for_bert_batch(batch_token_lists) |
|
tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device) |
|
att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device) |
|
|
|
|
|
bert_output = self.bert_model(tokens_tensor, |
|
attention_mask=att_masks_tensor)[0] |
|
|
|
enc_output = bert_output |
|
|
|
column_pointer_maps = [ |
|
{ |
|
i: [i] |
|
for i in range(len(desc['columns'])) |
|
} |
|
for desc in descs |
|
] |
|
table_pointer_maps = [ |
|
{ |
|
i: [i] |
|
for i in range(len(desc['tables'])) |
|
} |
|
for desc in descs |
|
] |
|
|
|
assert len(long_seq_set) == 0 |
|
|
|
result = [] |
|
for batch_idx, desc in enumerate(descs): |
|
c_boundary = list(range(len(desc["columns"]) + 1)) |
|
t_boundary = list(range(len(desc["tables"]) + 1)) |
|
|
|
if batch_idx in long_seq_set: |
|
q_enc, col_enc, tab_enc = self.encoder_long_seq(desc) |
|
else: |
|
bert_batch_idx = batch_id_map[batch_idx] |
|
q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]] |
|
col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]] |
|
tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]] |
|
|
|
if self.summarize_header == "avg": |
|
col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]] |
|
tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]] |
|
|
|
col_enc = (col_enc + col_enc_2) / 2.0 |
|
tab_enc = (tab_enc + tab_enc_2) / 2.0 |
|
|
|
assert q_enc.size()[0] == len(desc["question"]) |
|
assert col_enc.size()[0] == c_boundary[-1] |
|
assert tab_enc.size()[0] == t_boundary[-1] |
|
|
|
q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \ |
|
self.encs_update.forward_unbatched( |
|
desc, |
|
q_enc.unsqueeze(1), |
|
col_enc.unsqueeze(1), |
|
c_boundary, |
|
tab_enc.unsqueeze(1), |
|
t_boundary) |
|
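            # NOTE: debug dump of the inputs to encs_update; this writes
            # descs_{batch_idx}.pkl on every forward pass.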
import pickle |
|
pickle.dump({"desc": desc, "q_enc": q_enc, "col_enc": col_enc, "c_boundary": c_boundary, "tab_enc": tab_enc, |
|
"t_boundary": t_boundary}, open("descs_{}.pkl".format(batch_idx), "wb")) |
|
|
|
memory = [] |
|
if 'question' in self.include_in_memory: |
|
memory.append(q_enc_new_item) |
|
if 'column' in self.include_in_memory: |
|
memory.append(c_enc_new_item) |
|
if 'table' in self.include_in_memory: |
|
memory.append(t_enc_new_item) |
|
memory = torch.cat(memory, dim=1) |
|
|
|
result.append(SpiderEncoderState( |
|
state=None, |
|
memory=memory, |
|
question_memory=q_enc_new_item, |
|
schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1), |
|
|
|
words=desc['question'], |
|
pointer_memories={ |
|
'column': c_enc_new_item, |
|
'table': t_enc_new_item, |
|
}, |
|
pointer_maps={ |
|
'column': column_pointer_maps[batch_idx], |
|
'table': table_pointer_maps[batch_idx], |
|
}, |
|
m2c_align_mat=align_mat_item[0], |
|
m2t_align_mat=align_mat_item[1], |
|
)) |
|
return result |
|
|
|
    # Deprecated: long-sequence fallback; unused because over-length inputs
    # are filtered out during preprocessing.
|
def encoder_long_seq(self, desc): |
|
""" |
|
Since bert cannot handle sequence longer than 512, each column/table is encoded individually |
|
The representation of a column/table is the vector of the first token [CLS] |
|
""" |
|
qs = self.pad_single_sentence_for_bert(desc['question'], cls=True) |
|
cols = [self.pad_single_sentence_for_bert(c, cls=True) for c in desc['columns']] |
|
tabs = [self.pad_single_sentence_for_bert(t, cls=True) for t in desc['tables']] |
|
|
|
enc_q = self._bert_encode(qs) |
|
enc_col = self._bert_encode(cols) |
|
enc_tab = self._bert_encode(tabs) |
|
return enc_q, enc_col, enc_tab |
|
|
|
    # Deprecated: only used by the long-sequence fallback above.
|
def _bert_encode(self, toks): |
|
if not isinstance(toks[0], list): |
|
indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks) |
|
tokens_tensor = torch.tensor([indexed_tokens]).to(self._device) |
|
outputs = self.bert_model(tokens_tensor) |
|
return outputs[0][0, 1:-1] |
|
else: |
|
max_len = max([len(it) for it in toks]) |
|
tok_ids = [] |
|
for item_toks in toks: |
|
item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks)) |
|
indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks) |
|
tok_ids.append(indexed_tokens) |
|
|
|
tokens_tensor = torch.tensor(tok_ids).to(self._device) |
|
outputs = self.bert_model(tokens_tensor) |
|
return outputs[0][:, 0, :] |
|
|
|
def check_bert_seq(self, toks): |
|
if toks[0] == self.tokenizer.cls_token and toks[-1] == self.tokenizer.sep_token: |
|
return True |
|
else: |
|
return False |
|
|
|
def pad_single_sentence_for_bert(self, toks, cls=True): |
|
if cls: |
|
return [self.tokenizer.cls_token] + toks + [self.tokenizer.sep_token] |
|
else: |
|
return toks + [self.tokenizer.sep_token] |
|
|
|
def pad_sequence_for_bert_batch(self, tokens_lists): |
|
pad_id = self.tokenizer.pad_token_id |
|
max_len = max([len(it) for it in tokens_lists]) |
|
assert max_len <= 512 |
|
toks_ids = [] |
|
att_masks = [] |
|
tok_type_lists = [] |
|
for item_toks in tokens_lists: |
|
padded_item_toks = item_toks + [pad_id] * (max_len - len(item_toks)) |
|
toks_ids.append(padded_item_toks) |
|
|
|
_att_mask = [1] * len(item_toks) + [0] * (max_len - len(item_toks)) |
|
att_masks.append(_att_mask) |
|
|
|
first_sep_id = padded_item_toks.index(self.tokenizer.sep_token_id) |
|
assert first_sep_id > 0 |
|
_tok_type_list = [0] * (first_sep_id + 1) + [1] * (max_len - first_sep_id - 1) |
|
tok_type_lists.append(_tok_type_list) |
|
return toks_ids, att_masks, tok_type_lists |