import collections
import itertools
import json
import os

import attr
import nltk
import nltk.corpus
import numpy as np
import torch
import torchtext

from seq2struct.models import abstract_preproc, transformer
from seq2struct.models.spider import spider_enc_modules  #, spider_enc_modulesT5
from seq2struct.models.spider.spider_match_utils import (
    compute_schema_linking,
    compute_cell_value_linking
)
from seq2struct.resources import pretrained_embeddings
from seq2struct.utils import registry
from seq2struct.utils import serialization
from seq2struct.utils import vocab
from seq2struct import resources
#from seq2struct.resources import corenlp

#from transformers import BertModel, BertTokenizer, BartModel, BartTokenizer
from transformers import (
    BertModel, BertTokenizer,
    BartModel, BartTokenizer,
    MBartModel, MBart50Tokenizer,
    MT5Model, MT5Tokenizer,          # mT5 additions
    T5Tokenizer, T5ForConditionalGeneration,
)

import simplemma
from simplemma import text_lemmatizer

#langdata = simplemma.load_data('en')
#langdata = simplemma.load_data('pt', 'en')
langdata = simplemma.load_data('en', 'pt', 'es', 'fr')


@attr.s
class SpiderEncoderState:
    state = attr.ib()
    memory = attr.ib()
    question_memory = attr.ib()
    schema_memory = attr.ib()
    words = attr.ib()
    pointer_memories = attr.ib()
    pointer_maps = attr.ib()
    m2c_align_mat = attr.ib()
    m2t_align_mat = attr.ib()

    def find_word_occurrences(self, word):
        return [i for i, w in enumerate(self.words) if w == word]


@attr.s
class PreprocessedSchema:
    column_names = attr.ib(factory=list)
    table_names = attr.ib(factory=list)
    table_bounds = attr.ib(factory=list)
    column_to_table = attr.ib(factory=dict)
    table_to_columns = attr.ib(factory=dict)
    foreign_keys = attr.ib(factory=dict)
    foreign_keys_tables = attr.ib(factory=lambda: collections.defaultdict(set))
    primary_keys = attr.ib(factory=list)

    # only for bert version
    normalized_column_names = attr.ib(factory=list)
    normalized_table_names = attr.ib(factory=list)


def preprocess_schema_uncached(schema,
                               tokenize_func,
                               include_table_name_in_column,
                               fix_issue_16_primary_keys,
                               bert=False):
    """If it's bert, we also cache the normalized version of
    question/column/table for schema linking."""
    r = PreprocessedSchema()

    if bert:
        assert not include_table_name_in_column

    last_table_id = None
    for i, column in enumerate(schema.columns):
        col_toks = tokenize_func(column.name, column.unsplit_name)

        # assert column.type in ["text", "number", "time", "boolean", "others"]
        type_tok = '<type: {}>'.format(column.type)
        if bert:
            # for bert, we take the representation of the first word
            column_name = col_toks + [type_tok]
            r.normalized_column_names.append(Bertokens(col_toks))
        else:
            column_name = [type_tok] + col_toks

        if include_table_name_in_column:
            if column.table is None:
                table_name = ['<any-table>']
            else:
                table_name = tokenize_func(column.table.name, column.table.unsplit_name)
            column_name += ['<table-sep>'] + table_name
        r.column_names.append(column_name)

        table_id = None if column.table is None else column.table.id
        r.column_to_table[str(i)] = table_id
        if table_id is not None:
            columns = r.table_to_columns.setdefault(str(table_id), [])
            columns.append(i)
        if last_table_id != table_id:
            r.table_bounds.append(i)
            last_table_id = table_id

        if column.foreign_key_for is not None:
            r.foreign_keys[str(column.id)] = column.foreign_key_for.id
            r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id)

    r.table_bounds.append(len(schema.columns))
    assert len(r.table_bounds) == len(schema.tables) + 1

    for i, table in enumerate(schema.tables):
        table_toks = tokenize_func(table.name, table.unsplit_name)
        r.table_names.append(table_toks)
        if bert:
            r.normalized_table_names.append(Bertokens(table_toks))

    last_table = schema.tables[-1]
    r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables)
    r.primary_keys = [
        column.id
        for table in schema.tables
        for column in table.primary_keys
    ] if fix_issue_16_primary_keys else [
        column.id
        for column in last_table.primary_keys
        for table in schema.tables
    ]

    return r


class SpiderEncoderV2Preproc(abstract_preproc.AbstractPreproc):

    def __init__(
            self,
            save_path,
            min_freq=3,
            max_count=5000,
            include_table_name_in_column=True,
            word_emb=None,
            count_tokens_in_word_emb_for_vocab=False,
            # https://github.com/rshin/seq2struct/issues/16
            fix_issue_16_primary_keys=False,
            compute_sc_link=False,
            compute_cv_link=False,
            db_path=None):
        if word_emb is None:
            self.word_emb = None
        else:
            self.word_emb = registry.construct('word_emb', word_emb)

        self.data_dir = os.path.join(save_path, 'enc')
        self.include_table_name_in_column = include_table_name_in_column
        self.count_tokens_in_word_emb_for_vocab = count_tokens_in_word_emb_for_vocab
        self.fix_issue_16_primary_keys = fix_issue_16_primary_keys
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link
        self.texts = collections.defaultdict(list)
        self.db_path = db_path
        if self.compute_cv_link:
            assert self.db_path is not None

        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.vocab_path = os.path.join(save_path, 'enc_vocab.json')
        self.vocab_word_freq_path = os.path.join(save_path, 'enc_word_freq.json')
        self.vocab = None
        self.counted_db_ids = set()
        self.preprocessed_schemas = {}

    def validate_item(self, item, section):
        return True, None

    def add_item(self, item, section, validation_info):
        preprocessed = self.preprocess_item(item, validation_info)
        self.texts[section].append(preprocessed)

        if section == 'train':
            if item.schema.db_id in self.counted_db_ids:
                to_count = preprocessed['question']
            else:
                self.counted_db_ids.add(item.schema.db_id)
                to_count = itertools.chain(
                    preprocessed['question'],
                    *preprocessed['columns'],
                    *preprocessed['tables'])

            for token in to_count:
                count_token = (
                    self.word_emb is None or
                    self.count_tokens_in_word_emb_for_vocab or
                    self.word_emb.lookup(token) is None)
                if count_token:
                    self.vocab_builder.add_word(token)

    def clear_items(self):
        self.texts = collections.defaultdict(list)

    def preprocess_item(self, item, validation_info):
        question, question_for_copying = self._tokenize_for_copying(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        # NOTE: the remainder of this class was corrupted in the source; the
        # bodies below are an assumed reconstruction that follows the upstream
        # seq2struct implementation this file extends.
        if self.compute_sc_link:
            assert preproc_schema.column_names[0][0].startswith("<type:")
            column_names_without_types = [col[1:] for col in preproc_schema.column_names]
            sc_link = compute_schema_linking(
                question, column_names_without_types, preproc_schema.table_names)
        else:
            sc_link = {"q_col_match": {}, "q_tab_match": {}}

        if self.compute_cv_link:
            cv_link = compute_cell_value_linking(question, item.schema, self.db_path)
        else:
            cv_link = {"num_date_match": {}, "cell_match": {}}

        return {
            'raw_question': item.orig['question'],
            'question': question,
            'question_for_copying': question_for_copying,
            'db_id': item.schema.db_id,
            'sc_link': sc_link,
            'cv_link': cv_link,
            'columns': preproc_schema.column_names,
            'tables': preproc_schema.table_names,
            'table_bounds': preproc_schema.table_bounds,
            'column_to_table': preproc_schema.column_to_table,
            'table_to_columns': preproc_schema.table_to_columns,
            'foreign_keys': preproc_schema.foreign_keys,
            'foreign_keys_tables': preproc_schema.foreign_keys_tables,
            'primary_keys': preproc_schema.primary_keys,
        }

    def _preprocess_schema(self, schema):
        if schema.db_id in self.preprocessed_schemas:
            return self.preprocessed_schemas[schema.db_id]
        result = preprocess_schema_uncached(schema, self._tokenize,
                                            self.include_table_name_in_column,
                                            self.fix_issue_16_primary_keys)
        self.preprocessed_schemas[schema.db_id] = result
        return result

    def _tokenize(self, presplit, unsplit):
        if self.word_emb:
            return self.word_emb.tokenize(unsplit)
        return presplit

    def _tokenize_for_copying(self, presplit, unsplit):
        if self.word_emb:
            return self.word_emb.tokenize_for_copying(unsplit)
        return presplit, presplit

    def save(self):
        os.makedirs(self.data_dir, exist_ok=True)
        self.vocab = self.vocab_builder.finish()
        self.vocab.save(self.vocab_path)
        self.vocab_builder.save(self.vocab_word_freq_path)

        for section, texts in self.texts.items():
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f:
                for text in texts:
                    f.write(json.dumps(text) + '\n')

    def load(self):
        self.vocab = vocab.Vocab.load(self.vocab_path)
        self.vocab_builder.load(self.vocab_word_freq_path)

    def dataset(self, section):
        return [
            json.loads(line)
            for line in open(os.path.join(self.data_dir, section + '.jsonl'))]
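# The Bertokens helper used above and by the BERT preproc below was lost from
# this file. The class that follows is a minimal, hypothetical reconstruction,
# assuming it mirrors BartTokens/T5Tokens further down: it regroups BERT word
# pieces ("##"-prefixed continuations) into words, records an idx_map from word
# index to first-piece index, and lemmatizes with simplemma for schema linking.
class Bertokens:
    def __init__(self, pieces):
        self.pieces = pieces
        self.normalized_pieces = None
        self.idx_map = None
        self.normalize_toks()

    def normalize_toks(self):
        self.idx_map = {}
        words = []
        for i, piece in enumerate(self.pieces):
            if piece.startswith("##") and words:
                words[-1] += piece[2:]  # glue continuation pieces onto the current word
            else:
                self.idx_map[len(words)] = i  # word index -> first piece index
                words.append(piece)
        self.normalized_pieces = [simplemma.lemmatize(w, langdata) for w in words]

    def bert_schema_linking(self, columns, tables):
        question_tokens = self.normalized_pieces
        column_tokens = [c.normalized_pieces for c in columns]
        table_tokens = [t.normalized_pieces for t in tables]
        sc_link = compute_schema_linking(question_tokens, column_tokens, table_tokens)
        # remap question-word indices back to word-piece indices, exactly as
        # BartTokens.bart_schema_linking does below
        new_sc_link = {}
        for m_type in sc_link:
            _match = {}
            for ij_str in sc_link[m_type]:
                q_id_str, col_tab_id_str = ij_str.split(",")
                real_q_id = self.idx_map[int(q_id_str)]
                _match[f"{real_q_id},{col_tab_id_str}"] = sc_link[m_type][ij_str]
            new_sc_link[m_type] = _match
        return new_sc_link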
class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
    # NOTE: the head of this __init__ was corrupted in the source; it is
    # reconstructed here following the parallel BART/T5 preproc classes below.
    def __init__(
            self,
            save_path,
            db_path,
            fix_issue_16_primary_keys=False,
            include_table_name_in_column=False,
            bert_version="bert-base-uncased",
            compute_sc_link=True,
            compute_cv_link=False):
        self.data_dir = os.path.join(save_path, 'enc')
        self.db_path = db_path
        self.texts = collections.defaultdict(list)
        self.fix_issue_16_primary_keys = fix_issue_16_primary_keys
        self.include_table_name_in_column = include_table_name_in_column
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link

        self.counted_db_ids = set()
        self.preprocessed_schemas = {}

        self.tokenizer = BertTokenizer.from_pretrained(bert_version)

        # TODO: should get types from the data
        column_types = ["text", "number", "time", "boolean", "others"]
        self.tokenizer.add_tokens([f"<type: {t}>" for t in column_types])

    def _tokenize(self, presplit, unsplit):
        if self.tokenizer:
            toks = self.tokenizer.tokenize(unsplit)
            return toks
        return presplit

    def add_item(self, item, section, validation_info):
        preprocessed = self.preprocess_item(item, validation_info)
        self.texts[section].append(preprocessed)

    def preprocess_item(self, item, validation_info):
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        if self.compute_sc_link:
            question_bert_tokens = Bertokens(item.text)
            sc_link = question_bert_tokens.bert_schema_linking(
                preproc_schema.normalized_column_names,
                preproc_schema.normalized_table_names
            )
        else:
            sc_link = {"q_col_match": {}, "q_tab_match": {}}

        if self.compute_cv_link:
            question_bert_tokens = Bertokens(question)
            cv_link = compute_cell_value_linking(
                question_bert_tokens.normalized_pieces, item.schema, self.db_path)
        else:
            cv_link = {"num_date_match": {}, "cell_match": {}}

        return {
            'raw_question': item.orig['question'],
            'question': question,
            'db_id': item.schema.db_id,
            'sc_link': sc_link,
            'cv_link': cv_link,
            'columns': preproc_schema.column_names,
            'tables': preproc_schema.table_names,
            'table_bounds': preproc_schema.table_bounds,
            'column_to_table': preproc_schema.column_to_table,
            'table_to_columns': preproc_schema.table_to_columns,
            'foreign_keys': preproc_schema.foreign_keys,
            'foreign_keys_tables': preproc_schema.foreign_keys_tables,
            'primary_keys': preproc_schema.primary_keys,
        }

    def validate_item(self, item, section):
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        num_words = len(question) + 2 + \
            sum(len(c) + 1 for c in preproc_schema.column_names) + \
            sum(len(t) + 1 for t in preproc_schema.table_names)
        if num_words > 512:
            return False, None  # remove long sequences
        else:
            return True, None

    def _preprocess_schema(self, schema):
        if schema.db_id in self.preprocessed_schemas:
            return self.preprocessed_schemas[schema.db_id]
        result = preprocess_schema_uncached(schema, self._tokenize,
                                            self.include_table_name_in_column,
                                            self.fix_issue_16_primary_keys,
                                            bert=True)
        self.preprocessed_schemas[schema.db_id] = result
        return result

    def save(self):
        os.makedirs(self.data_dir, exist_ok=True)
        self.tokenizer.save_pretrained(self.data_dir)
        for section, texts in self.texts.items():
            # write UTF-8 so accented characters survive round-tripping
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w', encoding='utf8') as f:
                for text in texts:
                    f.write(json.dumps(text, ensure_ascii=False) + '\n')

    def load(self):
        self.tokenizer = BertTokenizer.from_pretrained(self.data_dir)


@registry.register('encoder', 'spider-bert')
class SpiderEncoderBert(torch.nn.Module):
    Preproc = SpiderEncoderBertPreproc
    batched = True

    def __init__(
            self,
            device,
            preproc,
            update_config={},
            bert_token_type=False,
            bert_version="bert-base-uncased",
            summarize_header="first",
            use_column_type=True,
            include_in_memory=('question', 'column', 'table')):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.bert_token_type = bert_token_type
        self.base_enc_hidden_size = 1024 if bert_version in (
            "bert-large-uncased-whole-word-masking",
            "neuralmind/bert-large-portuguese-cased") else 768

        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size
        self.use_column_type = use_column_type

        self.include_in_memory = set(include_in_memory)
        update_modules = {
            'relational_transformer': spider_enc_modules.RelationalTransformerUpdate,
            'none': spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=self.enc_hidden_size,
            sc_link=True,
        )

        self.bert_model = BertModel.from_pretrained(bert_version)
        self.tokenizer = self.preproc.tokenizer
        self.bert_model.resize_token_embeddings(len(self.tokenizer))  # several tokens added

    def forward(self, descs):
        batch_token_lists = []
        batch_id_to_retrieve_question = []
        batch_id_to_retrieve_column = []
        batch_id_to_retrieve_table = []
        if self.summarize_header == "avg":
            batch_id_to_retrieve_column_2 = []
            batch_id_to_retrieve_table_2 = []
        long_seq_set = set()
        batch_id_map = {}  # some long examples are not included

        for batch_idx, desc in enumerate(descs):
            qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
            if self.use_column_type:
                cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']]
            else:
                cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']]
            tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']]

            token_list = qs + [c for col in cols for c in col] + \
                         [t for tab in tabs for t in tab]
            assert self.check_bert_seq(token_list)
            if len(token_list) > 512:
                long_seq_set.add(batch_idx)
                continue

            q_b = len(qs)
            col_b = q_b + sum(len(c) for c in cols)
            # leave out [CLS] and [SEP]
            question_indexes = list(range(q_b))[1:-1]
            # use the first representation for column/table
            column_indexes = \
                np.cumsum([q_b] + [len(token_list) for token_list in cols[:-1]]).tolist()
            table_indexes = \
                np.cumsum([col_b] + [len(token_list) for token_list in tabs[:-1]]).tolist()
            if self.summarize_header == "avg":
                column_indexes_2 = \
                    np.cumsum([q_b - 2] + [len(token_list) for token_list in cols]).tolist()[1:]
                table_indexes_2 = \
                    np.cumsum([col_b - 2] + [len(token_list) for token_list in tabs]).tolist()[1:]

            indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list)
            batch_token_lists.append(indexed_token_list)

            question_rep_ids = torch.LongTensor(question_indexes).to(self._device)
            batch_id_to_retrieve_question.append(question_rep_ids)
            column_rep_ids = torch.LongTensor(column_indexes).to(self._device)
            batch_id_to_retrieve_column.append(column_rep_ids)
            table_rep_ids = torch.LongTensor(table_indexes).to(self._device)
            batch_id_to_retrieve_table.append(table_rep_ids)
            if self.summarize_header == "avg":
                assert all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))
                column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device)
                batch_id_to_retrieve_column_2.append(column_rep_ids_2)
                assert all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))
                table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device)
                batch_id_to_retrieve_table_2.append(table_rep_ids_2)

            batch_id_map[batch_idx] = len(batch_id_map)

        padded_token_lists, att_mask_lists, tok_type_lists = \
            self.pad_sequence_for_bert_batch(batch_token_lists)
        tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
        att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)

        if self.bert_token_type:
            tok_type_tensor = torch.LongTensor(tok_type_lists).to(self._device)
            bert_output = self.bert_model(tokens_tensor,
                                          attention_mask=att_masks_tensor,
                                          token_type_ids=tok_type_tensor)[0]
        else:
            bert_output = self.bert_model(tokens_tensor,
                                          attention_mask=att_masks_tensor)[0]

        enc_output = bert_output

        column_pointer_maps = [
            {i: [i] for i in range(len(desc['columns']))}
            for desc in descs
        ]
        table_pointer_maps = [
            {i: [i] for i in range(len(desc['tables']))}
            for desc in descs
        ]

        assert len(long_seq_set) == 0  # remove them for now

        result = []
        for batch_idx, desc in enumerate(descs):
            c_boundary = list(range(len(desc["columns"]) + 1))
            t_boundary = list(range(len(desc["tables"]) + 1))

            if batch_idx in long_seq_set:
                q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
            else:
                bert_batch_idx = batch_id_map[batch_idx]
                q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]]
                col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]]
                tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]]

                if self.summarize_header == "avg":
                    col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]]
                    tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]]

                    col_enc = (col_enc + col_enc_2) / 2.0  # avg of first and last token
                    tab_enc = (tab_enc + tab_enc_2) / 2.0  # avg of first and last token

            assert q_enc.size()[0] == len(desc["question"])
            assert col_enc.size()[0] == c_boundary[-1]
            assert tab_enc.size()[0] == t_boundary[-1]

            q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                self.encs_update.forward_unbatched(
                    desc,
                    q_enc.unsqueeze(1),
                    col_enc.unsqueeze(1),
                    c_boundary,
                    tab_enc.unsqueeze(1),
                    t_boundary)

            # debug dump of the encoder inputs, left disabled: it would rewrite
            # descs_*.pkl on every forward pass
            # import pickle
            # pickle.dump({"desc": desc, "q_enc": q_enc, "col_enc": col_enc,
            #              "c_boundary": c_boundary, "tab_enc": tab_enc,
            #              "t_boundary": t_boundary},
            #             open("descs_{}.pkl".format(batch_idx), "wb"))

            memory = []
            if 'question' in self.include_in_memory:
                memory.append(q_enc_new_item)
            if 'column' in self.include_in_memory:
                memory.append(c_enc_new_item)
            if 'table' in self.include_in_memory:
                memory.append(t_enc_new_item)
            memory = torch.cat(memory, dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                # TODO: words should match memory
                words=desc['question'],
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': t_enc_new_item,
                },
                pointer_maps={
                    'column': column_pointer_maps[batch_idx],
                    'table': table_pointer_maps[batch_idx],
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result

    # NOTE: @DeprecationWarning replaces the decorated function with a warning
    # instance, so the two methods below are effectively disabled; they are never
    # reached because long_seq_set is asserted empty above.
    @DeprecationWarning
    def encoder_long_seq(self, desc):
        """
        Since bert cannot handle sequences longer than 512, each column/table is
        encoded individually. The representation of a column/table is the vector
        of the first token [CLS].
        """
        qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
        cols = [self.pad_single_sentence_for_bert(c, cls=True) for c in desc['columns']]
        tabs = [self.pad_single_sentence_for_bert(t, cls=True) for t in desc['tables']]

        enc_q = self._bert_encode(qs)
        enc_col = self._bert_encode(cols)
        enc_tab = self._bert_encode(tabs)
        return enc_q, enc_col, enc_tab

    @DeprecationWarning
    def _bert_encode(self, toks):
        if not isinstance(toks[0], list):  # encode question words
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks)
            tokens_tensor = torch.tensor([indexed_tokens]).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][0, 1:-1]  # remove [CLS] and [SEP]
        else:
            max_len = max([len(it) for it in toks])
            tok_ids = []
            for item_toks in toks:
                item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks))
                indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks)
                tok_ids.append(indexed_tokens)

            tokens_tensor = torch.tensor(tok_ids).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][:, 0, :]

    def check_bert_seq(self, toks):
        if toks[0] == self.tokenizer.cls_token and toks[-1] == self.tokenizer.sep_token:
            return True
        else:
            return False

    def pad_single_sentence_for_bert(self, toks, cls=True):
        if cls:
            return [self.tokenizer.cls_token] + toks + [self.tokenizer.sep_token]
        else:
            return toks + [self.tokenizer.sep_token]

    def pad_sequence_for_bert_batch(self, tokens_lists):
        pad_id = self.tokenizer.pad_token_id
        max_len = max([len(it) for it in tokens_lists])
        assert max_len <= 512
        toks_ids = []
        att_masks = []
        tok_type_lists = []
        for item_toks in tokens_lists:
            padded_item_toks = item_toks + [pad_id] * (max_len - len(item_toks))
            toks_ids.append(padded_item_toks)

            _att_mask = [1] * len(item_toks) + [0] * (max_len - len(item_toks))
            att_masks.append(_att_mask)

            first_sep_id = padded_item_toks.index(self.tokenizer.sep_token_id)
            assert first_sep_id > 0
            _tok_type_list = [0] * (first_sep_id + 1) + [1] * (max_len - first_sep_id - 1)
            tok_type_lists.append(_tok_type_list)
        return toks_ids, att_masks, tok_type_lists


"""
###############################
BART models
###############################
"""


class BartTokens:
    def __init__(self, text, bart_version, tokenizer):
        self.text = text
        # pieces are the tokenizer's subword tokens
        self.tokenizer = tokenizer
        self.normalized_pieces = None
        self.idx_map = None
        self.bart_version = bart_version
        self.normalize_toks()

    def normalize_toks(self):
        tokens = nltk.word_tokenize(self.text.replace("'", " ' ").replace('"', ' " '))
        self.idx_map = {}  # maps word index to the index of its first subword piece
        toks = []
        if self.bart_version not in ("facebook/bart-large",
                                     "facebook/mbart-large-50-many-to-many-mmt"):
            assert False, "Model version not defined."
        if self.bart_version == "facebook/bart-large":
            for i, tok in enumerate(tokens):
                self.idx_map[i] = len(toks)
                toks.extend(self.tokenizer.tokenize(tok, add_prefix_space=True))
        if self.bart_version == "facebook/mbart-large-50-many-to-many-mmt":
            for i, tok in enumerate(tokens):
                self.idx_map[i] = len(toks)
                toks.extend(self.tokenizer.tokenize(tok))

        normalized_toks = []
        for i, tok in enumerate(tokens):
            normalized_toks.append(simplemma.lemmatize(tok, langdata))
        # CoreNLP-based lemmatization, kept for reference:
        # for i, tok in enumerate(tokens):
        #     ann = corenlp.annotate(tok, annotators=["tokenize", "ssplit", "lemma"])
        #     lemmas = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
        #     normalized_toks.append(" ".join(lemmas))
        self.normalized_pieces = normalized_toks

    def bart_schema_linking(self, columns, tables):
        question_tokens = self.normalized_pieces
        column_tokens = [c.normalized_pieces for c in columns]
        table_tokens = [t.normalized_pieces for t in tables]
        sc_link = compute_schema_linking(question_tokens, column_tokens, table_tokens)

        new_sc_link = {}
        for m_type in sc_link:
            _match = {}
            for ij_str in sc_link[m_type]:
                q_id_str, col_tab_id_str = ij_str.split(",")
                q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
                real_q_id = self.idx_map[q_id]
                _match[f"{real_q_id},{col_tab_id}"] = sc_link[m_type][ij_str]
            new_sc_link[m_type] = _match
        return new_sc_link

    def bart_cv_linking(self, schema, db_path):
        question_tokens = self.normalized_pieces
        cv_link = compute_cell_value_linking(question_tokens, schema, db_path)

        new_cv_link = {}
        for m_type in cv_link:
            if m_type != "normalized_token":
                _match = {}
                for ij_str in cv_link[m_type]:
                    q_id_str, col_tab_id_str = ij_str.split(",")
                    q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
                    real_q_id = self.idx_map[q_id]
                    _match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str]
                new_cv_link[m_type] = _match
            else:
                new_cv_link[m_type] = cv_link[m_type]
        return new_cv_link
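# Toy, self-contained illustration of the idx_map remapping performed by
# BartTokens.bart_schema_linking above. compute_schema_linking emits matches
# keyed by "<question_word_idx>,<col_or_tab_idx>", while the encoder consumes
# subword positions, so each word index is translated to the index of the
# word's first subword piece. The piece counts and the "CEM" (exact-match) tag
# below are illustrative values only.
def _demo_remap_sc_link():
    # Suppose "singer names" tokenizes to pieces ["sing", "er", "names"]:
    # word 0 -> piece 0, word 1 -> piece 2.
    idx_map = {0: 0, 1: 2}
    sc_link = {"q_col_match": {"1,5": "CEM"}}  # word 1 matches column 5
    remapped = {}
    for m_type, matches in sc_link.items():
        remapped[m_type] = {
            f"{idx_map[int(ij.split(',')[0])]},{ij.split(',')[1]}": tag
            for ij, tag in matches.items()
        }
    assert remapped == {"q_col_match": {"2,5": "CEM"}}
    return remapped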
"number", "time", "boolean", "others"] type_tok = ''.format(column.type) if bart: # for bert, we take the representation of the first word column_name = col_toks + [type_tok] r.normalized_column_names.append(BartTokens(column.unsplit_name, bart_version, tokenizer)) else: column_name = [type_tok] + col_toks if include_table_name_in_column: if column.table is None: table_name = [''] else: table_name = tokenize_func( column.table.name, column.table.unsplit_name) column_name += [''] + table_name r.column_names.append(column_name) table_id = None if column.table is None else column.table.id r.column_to_table[str(i)] = table_id if table_id is not None: columns = r.table_to_columns.setdefault(str(table_id), []) columns.append(i) if last_table_id != table_id: r.table_bounds.append(i) last_table_id = table_id if column.foreign_key_for is not None: r.foreign_keys[str(column.id)] = column.foreign_key_for.id r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id) r.table_bounds.append(len(schema.columns)) assert len(r.table_bounds) == len(schema.tables) + 1 for i, table in enumerate(schema.tables): table_toks = tokenize_func( table.name, table.unsplit_name) r.table_names.append(table_toks) if bart: r.normalized_table_names.append(BartTokens(table.unsplit_name, bart_version, tokenizer)) last_table = schema.tables[-1] r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables) r.primary_keys = [ column.id for table in schema.tables for column in table.primary_keys ] if fix_issue_16_primary_keys else [ column.id for column in last_table.primary_keys for table in schema.tables ] return r import nltk class SpiderEncoderBartPreproc(SpiderEncoderV2Preproc): # Why in the BERT model, we set the include_table_name_in_column as False? def __init__( self, save_path, db_path, fix_issue_16_primary_keys=False, include_table_name_in_column=False, bart_version = "bart-large", pretrained_checkpoint = "pretrained_checkpoint/pytorch_model.bin", compute_sc_link=True, compute_cv_link=False): self.data_dir = os.path.join(save_path, 'enc') self.db_path = db_path self.texts = collections.defaultdict(list) self.fix_issue_16_primary_keys = fix_issue_16_primary_keys self.include_table_name_in_column = include_table_name_in_column self.compute_sc_link = compute_sc_link self.compute_cv_link = compute_cv_link self.bart_version = bart_version self.counted_db_ids = set() self.preprocessed_schemas = {} #Adicionado para tratamento para model MBART50 print(f"SpiderEncoderBartPreproc Model: {bart_version}") print(f"SpiderEncoderBartPreproc Pretrained Checkpoint: {pretrained_checkpoint}") if(bart_version != "facebook/bart-large" and bart_version != "facebook/mbart-large-50-many-to-many-mmt"): assert False, "Model version not defined." if bart_version == "facebook/bart-large": self.tokenizer = BartTokenizer.from_pretrained(bart_version) if bart_version == "facebook/mbart-large-50-many-to-many-mmt": self.tokenizer = MBart50Tokenizer.from_pretrained(bart_version) #self.tokenizer = BartTokenizer.from_pretrained(bart_version) column_types = ["text", "number", "time", "boolean", "others"] self.tokenizer.add_tokens([f"" for t in column_types]) def _tokenize(self, presplit, unsplit): # I want to keep this tokenization consistent with BartTokens. # Presplit is required here. 
tokens = nltk.word_tokenize(unsplit.replace("'", " ' ").replace('"', ' " ')) toks = [] if(self.bart_version != "facebook/bart-large" and self.bart_version != "facebook/mbart-large-50-many-to-many-mmt"): assert False, "Model version not defined." if self.bart_version == "facebook/bart-large": for token in tokens: toks.extend(self.tokenizer.tokenize(token, add_prefix_space=True)) if self.bart_version == "facebook/mbart-large-50-many-to-many-mmt": for token in tokens: toks.extend(self.tokenizer.tokenize(token)) # for token in tokens: # toks.extend(self.tokenizer.tokenize(token, add_prefix_space=True)) return toks def add_item(self, item, section, validation_info): preprocessed = self.preprocess_item(item, validation_info) self.texts[section].append(preprocessed) def preprocess_item(self, item, validation_info): # For bart, there is a punctuation issue if we want to merge it back to words. # So here I will use nltk to further tokenize the sentence first. question = self._tokenize(item.text, item.orig['question']) preproc_schema = self._preprocess_schema(item.schema) question_bart_tokens = BartTokens(item.orig['question'], self.bart_version, self.tokenizer) if self.compute_sc_link: # We do not want to transform pieces back to word. sc_link = question_bart_tokens.bart_schema_linking( preproc_schema.normalized_column_names, preproc_schema.normalized_table_names ) else: sc_link = {"q_col_match": {}, "q_tab_match": {}} if self.compute_cv_link: cv_link = question_bart_tokens.bart_cv_linking( item.schema, self.db_path) else: cv_link = {"num_date_match": {}, "cell_match": {}} return { 'raw_question': item.orig['question'], 'question': question, 'db_id': item.schema.db_id, 'sc_link': sc_link, 'cv_link': cv_link, 'columns': preproc_schema.column_names, 'tables': preproc_schema.table_names, 'table_bounds': preproc_schema.table_bounds, 'column_to_table': preproc_schema.column_to_table, 'table_to_columns': preproc_schema.table_to_columns, 'foreign_keys': preproc_schema.foreign_keys, 'foreign_keys_tables': preproc_schema.foreign_keys_tables, 'primary_keys': preproc_schema.primary_keys, } def validate_item(self, item, section): question = self._tokenize(item.text, item.orig['question']) preproc_schema = self._preprocess_schema(item.schema) # 2 is for cls and sep special tokens. 
+1 is for sep num_words = len(question) + 2 + \ sum(len(c) + 1 for c in preproc_schema.column_names) + \ sum(len(t) + 1 for t in preproc_schema.table_names) if num_words > 512: return False, None # remove long sequences else: return True, None def _preprocess_schema(self, schema): if schema.db_id in self.preprocessed_schemas: return self.preprocessed_schemas[schema.db_id] result = preprocess_schema_uncached_bart(schema, self.tokenizer, self._tokenize, self.include_table_name_in_column, self.fix_issue_16_primary_keys, self.bart_version, bart=True) self.preprocessed_schemas[schema.db_id] = result return result def save(self): os.makedirs(self.data_dir, exist_ok=True) self.tokenizer.save_pretrained(self.data_dir) for section, texts in self.texts.items(): with open(os.path.join(self.data_dir, section + '.jsonl'), 'w', encoding='utf8') as f:#UTF-8 #with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f: for text in texts: f.write(json.dumps(text, ensure_ascii=False) + '\n')#UTF-8 #f.write(json.dumps(text) + '\n') def load(self): #Adicionado para tratamento para model MBART50 print(f"BART load Model: {self.bart_version}") if(self.bart_version != "facebook/bart-large" and self.bart_version != "facebook/mbart-large-50-many-to-many-mmt"): assert False, "Model version not defined." if self.bart_version == "facebook/bart-large": self.tokenizer = BartTokenizer.from_pretrained(self.data_dir) if self.bart_version == "facebook/mbart-large-50-many-to-many-mmt": self.tokenizer = MBart50Tokenizer.from_pretrained(self.data_dir) #self.tokenizer = BartTokenizer.from_pretrained(self.data_dir) @registry.register('encoder', 'spider-bart') class SpiderEncoderBart(torch.nn.Module): Preproc = SpiderEncoderBartPreproc batched = True def __init__( self, device, preproc, update_config={}, bart_version="facebook/bart-large", pretrained_checkpoint = "pretrained_checkpoint/pytorch_model.bin", summarize_header="first", use_column_type=True, include_in_memory=('question', 'column', 'table')): super().__init__() self._device = device self.preproc = preproc self.base_enc_hidden_size = 1024 assert summarize_header in ["first", "avg"] self.summarize_header = summarize_header self.enc_hidden_size = self.base_enc_hidden_size self.use_column_type = use_column_type self.include_in_memory = set(include_in_memory) update_modules = { 'relational_transformer': spider_enc_modules.RelationalTransformerUpdate, 'none': spider_enc_modules.NoOpUpdate, } self.encs_update = registry.instantiate( update_modules[update_config['name']], update_config, unused_keys={"name"}, device=self._device, hidden_size=self.enc_hidden_size, sc_link=True, ) #Adicionado para tratamento para model MBART50 print(f"SpiderEncoderBart Model: {bart_version}") if(bart_version != "facebook/bart-large" and bart_version != "facebook/mbart-large-50-many-to-many-mmt"): assert False, "Model version not defined." 
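# Hedged sketch of the tokenizer/embedding contract relied on throughout this
# file: each preproc adds "<type: ...>" tokens to its tokenizer, so every
# encoder must resize the model's input embeddings to len(tokenizer) before the
# new ids can be looked up. The model name matches the ones this file uses; the
# function itself is illustrative, not part of the pipeline.
def _demo_added_type_tokens():
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    column_types = ["text", "number", "time", "boolean", "others"]
    num_added = tokenizer.add_tokens([f"<type: {t}>" for t in column_types])
    model = BartModel.from_pretrained("facebook/bart-large")
    model.resize_token_embeddings(len(tokenizer))  # grow the embedding matrix for the new ids
    type_id = tokenizer.convert_tokens_to_ids("<type: text>")
    # the new id is now a valid row of the (resized) input embedding
    assert type_id < model.get_input_embeddings().num_embeddings
    return num_added, type_id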
@registry.register('encoder', 'spider-bart')
class SpiderEncoderBart(torch.nn.Module):
    Preproc = SpiderEncoderBartPreproc
    batched = True

    def __init__(
            self,
            device,
            preproc,
            update_config={},
            bart_version="facebook/bart-large",
            pretrained_checkpoint="pretrained_checkpoint/pytorch_model.bin",
            summarize_header="first",
            use_column_type=True,
            include_in_memory=('question', 'column', 'table')):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.base_enc_hidden_size = 1024

        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size
        self.use_column_type = use_column_type

        self.include_in_memory = set(include_in_memory)
        update_modules = {
            'relational_transformer': spider_enc_modules.RelationalTransformerUpdate,
            'none': spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=self.enc_hidden_size,
            sc_link=True,
        )

        # Added to handle the MBART50 model
        print(f"SpiderEncoderBart Model: {bart_version}")
        if bart_version not in ("facebook/bart-large",
                                "facebook/mbart-large-50-many-to-many-mmt"):
            assert False, "Model version not defined."
        if bart_version == "facebook/bart-large":
            self.bert_model = BartModel.from_pretrained(bart_version)
        if bart_version == "facebook/mbart-large-50-many-to-many-mmt":
            self.bert_model = MBartModel.from_pretrained(bart_version)
        print(next(self.bert_model.encoder.parameters()))  # weights before checkpoint loading

        def replace_model_with_pretrained(model, path, prefix):
            restore_state_dict = torch.load(
                path, map_location=lambda storage, location: storage)
            keep_keys = []
            for key in restore_state_dict.keys():
                if key.startswith(prefix):
                    keep_keys.append(key)
            loaded_dict = {k.replace(prefix, ""): restore_state_dict[k] for k in keep_keys}
            model.load_state_dict(loaded_dict)
            print("Updated the model with {}".format(path))

        self.tokenizer = self.preproc.tokenizer
        if bart_version == "facebook/bart-large":
            print(f"SpiderEncoderBart Pretrained Checkpoint: {pretrained_checkpoint}")
            self.bert_model.resize_token_embeddings(50266)  # several tokens added
            # "bert.model.encoder." is the prefix used in the BART GAP checkpoint
            replace_model_with_pretrained(self.bert_model.encoder, pretrained_checkpoint,
                                          "bert.model.encoder.")
        if bart_version == "facebook/mbart-large-50-many-to-many-mmt":
            # disabling GAP (Generation-Augmented Pre-Training) for MBART50
            print("No GAP - Generation-Augmented Pre-Training")
            # BART-specific resize, do not use it for MBART50:
            # self.bert_model.resize_token_embeddings(50266)  # several tokens added
            # "model.encoder." is the prefix used by MBART50; to find the prefix for
            # another model, print the keys inside replace_model_with_pretrained:
            # replace_model_with_pretrained(self.bert_model.encoder, pretrained_checkpoint,
            #                               "model.encoder.")

        self.bert_model.resize_token_embeddings(len(self.tokenizer))
        self.bert_model = self.bert_model.encoder
        self.bert_model.decoder = None
        print(next(self.bert_model.parameters()))  # weights after checkpoint loading

    def forward(self, descs):
        batch_token_lists = []
        batch_id_to_retrieve_question = []
        batch_id_to_retrieve_column = []
        batch_id_to_retrieve_table = []
        if self.summarize_header == "avg":
            batch_id_to_retrieve_column_2 = []
            batch_id_to_retrieve_table_2 = []
        long_seq_set = set()
        batch_id_map = {}  # some long examples are not included

        for batch_idx, desc in enumerate(descs):
            qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
            if self.use_column_type:
                cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']]
            else:
                cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']]
            tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']]

            token_list = qs + [c for col in cols for c in col] + \
                         [t for tab in tabs for t in tab]
            assert self.check_bert_seq(token_list)
            if len(token_list) > 512:
                long_seq_set.add(batch_idx)
                continue

            q_b = len(qs)
            col_b = q_b + sum(len(c) for c in cols)
            # leave out [CLS] and [SEP]
            question_indexes = list(range(q_b))[1:-1]
            # use the first representation for column/table
            column_indexes = \
                np.cumsum([q_b] + [len(token_list) for token_list in cols[:-1]]).tolist()
            table_indexes = \
                np.cumsum([col_b] + [len(token_list) for token_list in tabs[:-1]]).tolist()
            if self.summarize_header == "avg":
                column_indexes_2 = \
                    np.cumsum([q_b - 2] + [len(token_list) for token_list in cols]).tolist()[1:]
                table_indexes_2 = \
                    np.cumsum([col_b - 2] + [len(token_list) for token_list in tabs]).tolist()[1:]

            indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list)
            batch_token_lists.append(indexed_token_list)

            question_rep_ids = torch.LongTensor(question_indexes).to(self._device)
            batch_id_to_retrieve_question.append(question_rep_ids)
            column_rep_ids = torch.LongTensor(column_indexes).to(self._device)
            batch_id_to_retrieve_column.append(column_rep_ids)
            table_rep_ids = torch.LongTensor(table_indexes).to(self._device)
            batch_id_to_retrieve_table.append(table_rep_ids)
            if self.summarize_header == "avg":
                assert all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))
                column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device)
                batch_id_to_retrieve_column_2.append(column_rep_ids_2)
                assert all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))
                table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device)
                batch_id_to_retrieve_table_2.append(table_rep_ids_2)

            batch_id_map[batch_idx] = len(batch_id_map)

        padded_token_lists, att_mask_lists, tok_type_lists = \
            self.pad_sequence_for_bert_batch(batch_token_lists)
        tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
        att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)

        bert_output = self.bert_model(tokens_tensor, attention_mask=att_masks_tensor)[0]

        enc_output = bert_output

        column_pointer_maps = [
            {i: [i] for i in range(len(desc['columns']))}
            for desc in descs
        ]
        table_pointer_maps = [
            {i: [i] for i in range(len(desc['tables']))}
            for desc in descs
        ]

        assert len(long_seq_set) == 0  # remove them for now

        result = []
        for batch_idx, desc in enumerate(descs):
            c_boundary = list(range(len(desc["columns"]) + 1))
            t_boundary = list(range(len(desc["tables"]) + 1))

            if batch_idx in long_seq_set:
                q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
            else:
                bert_batch_idx = batch_id_map[batch_idx]
                q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]]
                col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]]
                tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]]

                if self.summarize_header == "avg":
                    col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]]
                    tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]]

                    col_enc = (col_enc + col_enc_2) / 2.0  # avg of first and last token
                    tab_enc = (tab_enc + tab_enc_2) / 2.0  # avg of first and last token

            assert q_enc.size()[0] == len(desc["question"])
            assert col_enc.size()[0] == c_boundary[-1]
            assert tab_enc.size()[0] == t_boundary[-1]

            q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                self.encs_update.forward_unbatched(
                    desc,
                    q_enc.unsqueeze(1),
                    col_enc.unsqueeze(1),
                    c_boundary,
                    tab_enc.unsqueeze(1),
                    t_boundary)

            # debug dump of the encoder inputs, left disabled: it would rewrite
            # descs_*.pkl on every forward pass
            # import pickle
            # pickle.dump({"desc": desc, "q_enc": q_enc, "col_enc": col_enc,
            #              "c_boundary": c_boundary, "tab_enc": tab_enc,
            #              "t_boundary": t_boundary},
            #             open("descs_{}.pkl".format(batch_idx), "wb"))

            memory = []
            if 'question' in self.include_in_memory:
                memory.append(q_enc_new_item)
            if 'column' in self.include_in_memory:
                memory.append(c_enc_new_item)
            if 'table' in self.include_in_memory:
                memory.append(t_enc_new_item)
            memory = torch.cat(memory, dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                # TODO: words should match memory
                words=desc['question'],
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': t_enc_new_item,
                },
                pointer_maps={
                    'column': column_pointer_maps[batch_idx],
                    'table': table_pointer_maps[batch_idx],
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result

    # NOTE: as in SpiderEncoderBert, @DeprecationWarning effectively disables the
    # two methods below; they are never reached because long_seq_set is asserted
    # empty above.
    @DeprecationWarning
    def encoder_long_seq(self, desc):
        """
        Since bert cannot handle sequences longer than 512, each column/table is
        encoded individually. The representation of a column/table is the vector
        of the first token [CLS].
        """
        qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
        cols = [self.pad_single_sentence_for_bert(c, cls=True) for c in desc['columns']]
        tabs = [self.pad_single_sentence_for_bert(t, cls=True) for t in desc['tables']]

        enc_q = self._bert_encode(qs)
        enc_col = self._bert_encode(cols)
        enc_tab = self._bert_encode(tabs)
        return enc_q, enc_col, enc_tab

    @DeprecationWarning
    def _bert_encode(self, toks):
        if not isinstance(toks[0], list):  # encode question words
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks)
            tokens_tensor = torch.tensor([indexed_tokens]).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][0, 1:-1]  # remove [CLS] and [SEP]
        else:
            max_len = max([len(it) for it in toks])
            tok_ids = []
            for item_toks in toks:
                item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks))
                indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks)
                tok_ids.append(indexed_tokens)

            tokens_tensor = torch.tensor(tok_ids).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][:, 0, :]

    def check_bert_seq(self, toks):
        if toks[0] == self.tokenizer.cls_token and toks[-1] == self.tokenizer.sep_token:
            return True
        else:
            return False

    def pad_single_sentence_for_bert(self, toks, cls=True):
        if cls:
            return [self.tokenizer.cls_token] + toks + [self.tokenizer.sep_token]
        else:
            return toks + [self.tokenizer.sep_token]

    def pad_sequence_for_bert_batch(self, tokens_lists):
        pad_id = self.tokenizer.pad_token_id
        max_len = max([len(it) for it in tokens_lists])
        assert max_len <= 512
        toks_ids = []
        att_masks = []
        tok_type_lists = []
        for item_toks in tokens_lists:
            padded_item_toks = item_toks + [pad_id] * (max_len - len(item_toks))
            toks_ids.append(padded_item_toks)

            _att_mask = [1] * len(item_toks) + [0] * (max_len - len(item_toks))
            att_masks.append(_att_mask)

            first_sep_id = padded_item_toks.index(self.tokenizer.sep_token_id)
            assert first_sep_id > 0
            _tok_type_list = [0] * (first_sep_id + 1) + [1] * (max_len - first_sep_id - 1)
            tok_type_lists.append(_tok_type_list)
        return toks_ids, att_masks, tok_type_lists
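# Standalone version of the prefix-stripping loader defined inside
# SpiderEncoderBart.__init__ above, shown because the prefix differs per
# checkpoint ("bert.model.encoder." for the BART GAP checkpoint,
# "model.encoder." for MBART50, "encoder." for mT5). A quick way to discover
# the right prefix is to print the checkpoint's keys first. Note the use of
# k[len(prefix):] rather than str.replace, which would also rewrite any later
# occurrence of the prefix inside a key.
def load_encoder_from_checkpoint(model, path, prefix):
    state = torch.load(path, map_location="cpu")
    # keep only keys under `prefix`, stripping it so the remaining names line up
    # with the bare encoder module
    encoder_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    model.load_state_dict(encoder_state)
    return model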
"""
###############################
T5 models
###############################
"""


class T5Tokens:
    def __init__(self, text, t5_version, tokenizer):
        self.text = text
        # pieces are the tokenizer's subword tokens
        self.tokenizer = tokenizer
        self.normalized_pieces = None
        self.idx_map = None
        self.t5_version = t5_version
        self.normalize_toks()

    def normalize_toks(self):
        tokens = nltk.word_tokenize(self.text.replace("'", " ' ").replace('"', ' " '))
        self.idx_map = {}  # maps word index to the index of its first subword piece
        toks = []
        # Added to handle the mT5 model
        if self.t5_version not in ("facebook/bart-large",
                                   "facebook/mbart-large-50-many-to-many-mmt",
                                   "google/mt5-large",
                                   "google/t5-v1_1-large"):
            assert False, "Model version not defined."
        if self.t5_version == "facebook/bart-large":
            for i, tok in enumerate(tokens):
                self.idx_map[i] = len(toks)
                toks.extend(self.tokenizer.tokenize(tok, add_prefix_space=True))
        # MBART50, mT5 and T5 v1.1 all tokenize identically here
        if self.t5_version in ("facebook/mbart-large-50-many-to-many-mmt",
                               "google/mt5-large",
                               "google/t5-v1_1-large"):
            for i, tok in enumerate(tokens):
                self.idx_map[i] = len(toks)
                toks.extend(self.tokenizer.tokenize(tok))

        normalized_toks = []
        for i, tok in enumerate(tokens):
            normalized_toks.append(simplemma.lemmatize(tok, langdata))
        # CoreNLP-based lemmatization, kept for reference:
        # for i, tok in enumerate(tokens):
        #     ann = corenlp.annotate(tok, annotators=["tokenize", "ssplit", "lemma"])
        #     lemmas = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
        #     normalized_toks.append(" ".join(lemmas))
        self.normalized_pieces = normalized_toks

    def t5_schema_linking(self, columns, tables):
        question_tokens = self.normalized_pieces
        column_tokens = [c.normalized_pieces for c in columns]
        table_tokens = [t.normalized_pieces for t in tables]
        sc_link = compute_schema_linking(question_tokens, column_tokens, table_tokens)

        new_sc_link = {}
        for m_type in sc_link:
            _match = {}
            for ij_str in sc_link[m_type]:
                q_id_str, col_tab_id_str = ij_str.split(",")
                q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
                real_q_id = self.idx_map[q_id]
                _match[f"{real_q_id},{col_tab_id}"] = sc_link[m_type][ij_str]
            new_sc_link[m_type] = _match
        return new_sc_link

    def t5_cv_linking(self, schema, db_path):
        question_tokens = self.normalized_pieces
        cv_link = compute_cell_value_linking(question_tokens, schema, db_path)

        new_cv_link = {}
        for m_type in cv_link:
            if m_type != "normalized_token":
                _match = {}
                for ij_str in cv_link[m_type]:
                    q_id_str, col_tab_id_str = ij_str.split(",")
                    q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
                    real_q_id = self.idx_map[q_id]
                    _match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str]
                new_cv_link[m_type] = _match
            else:
                new_cv_link[m_type] = cv_link[m_type]
        return new_cv_link


def preprocess_schema_uncached_t5(schema,
                                  tokenizer,
                                  tokenize_func,
                                  include_table_name_in_column,
                                  fix_issue_16_primary_keys,
                                  t5_version,
                                  t5=False):
    """If it's t5, we also cache the normalized version of
    question/column/table for schema linking."""
    r = PreprocessedSchema()

    if t5:
        assert not include_table_name_in_column

    last_table_id = None
    for i, column in enumerate(schema.columns):
        col_toks = tokenize_func(column.name, column.unsplit_name)

        # assert column.type in ["text", "number", "time", "boolean", "others"]
        type_tok = '<type: {}>'.format(column.type)
        if t5:
            # as with bert, we take the representation of the first word
            column_name = col_toks + [type_tok]
            r.normalized_column_names.append(
                T5Tokens(column.unsplit_name, t5_version, tokenizer))
        else:
            column_name = [type_tok] + col_toks

        if include_table_name_in_column:
            if column.table is None:
                table_name = ['<any-table>']
            else:
                table_name = tokenize_func(column.table.name, column.table.unsplit_name)
            column_name += ['<table-sep>'] + table_name
        r.column_names.append(column_name)

        table_id = None if column.table is None else column.table.id
        r.column_to_table[str(i)] = table_id
        if table_id is not None:
            columns = r.table_to_columns.setdefault(str(table_id), [])
            columns.append(i)
        if last_table_id != table_id:
            r.table_bounds.append(i)
            last_table_id = table_id

        if column.foreign_key_for is not None:
            r.foreign_keys[str(column.id)] = column.foreign_key_for.id
            r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id)

    r.table_bounds.append(len(schema.columns))
    assert len(r.table_bounds) == len(schema.tables) + 1

    for i, table in enumerate(schema.tables):
        table_toks = tokenize_func(table.name, table.unsplit_name)
        r.table_names.append(table_toks)
        if t5:
            r.normalized_table_names.append(
                T5Tokens(table.unsplit_name, t5_version, tokenizer))

    last_table = schema.tables[-1]
    r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables)
    r.primary_keys = [
        column.id
        for table in schema.tables
        for column in table.primary_keys
    ] if fix_issue_16_primary_keys else [
        column.id
        for column in last_table.primary_keys
        for table in schema.tables
    ]

    return r


class SpiderEncoderT5Preproc(SpiderEncoderV2Preproc):
    # Why does the BERT-style model set include_table_name_in_column to False?

    def __init__(
            self,
            save_path,
            db_path,
            fix_issue_16_primary_keys=False,
            include_table_name_in_column=False,
            t5_version="google/mt5-large",
            pretrained_checkpoint="pretrained_checkpoint/pytorch_model.bin",
            compute_sc_link=True,
            compute_cv_link=False):
        self.data_dir = os.path.join(save_path, 'enc')
        self.db_path = db_path
        self.texts = collections.defaultdict(list)
        self.fix_issue_16_primary_keys = fix_issue_16_primary_keys
        self.include_table_name_in_column = include_table_name_in_column
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link
        self.t5_version = t5_version

        self.counted_db_ids = set()
        self.preprocessed_schemas = {}

        # Added to handle the mT5 model
        print(f"SpiderEncoderT5Preproc Model: {t5_version}")
        print("No GAP - Generation-Augmented Pre-Training")
        #print(f"SpiderEncoderT5Preproc Pretrained Checkpoint (not used): {pretrained_checkpoint}")
        if t5_version == "facebook/bart-large":
            self.tokenizer = BartTokenizer.from_pretrained(t5_version)
        if t5_version == "facebook/mbart-large-50-many-to-many-mmt":
            self.tokenizer = MBart50Tokenizer.from_pretrained(t5_version)
        if t5_version == "google/mt5-large":
            self.tokenizer = MT5Tokenizer.from_pretrained(t5_version)
        if t5_version == "google/t5-v1_1-large":
            self.tokenizer = T5Tokenizer.from_pretrained(t5_version)

        column_types = ["text", "number", "time", "boolean", "others"]
        self.tokenizer.add_tokens([f"<type: {t}>" for t in column_types])

    def _tokenize(self, presplit, unsplit):
        # Keep this tokenization consistent with T5Tokens.
        # Presplit is required here.
        tokens = nltk.word_tokenize(unsplit.replace("'", " ' ").replace('"', ' " '))
        toks = []
        if self.t5_version not in ("facebook/bart-large",
                                   "facebook/mbart-large-50-many-to-many-mmt",
                                   "google/mt5-large",
                                   "google/t5-v1_1-large"):
            assert False, "Model version not defined."
        if self.t5_version == "facebook/bart-large":
            for token in tokens:
                toks.extend(self.tokenizer.tokenize(token, add_prefix_space=True))
        # MBART50, mT5 and T5 v1.1 all tokenize identically here
        if self.t5_version in ("facebook/mbart-large-50-many-to-many-mmt",
                               "google/mt5-large",
                               "google/t5-v1_1-large"):
            for token in tokens:
                toks.extend(self.tokenizer.tokenize(token))
        return toks

    def add_item(self, item, section, validation_info):
        preprocessed = self.preprocess_item(item, validation_info)
        self.texts[section].append(preprocessed)

    def preprocess_item(self, item, validation_info):
        # As with BART, there is a punctuation issue if we want to merge pieces
        # back into words, so nltk is used to tokenize the sentence first.
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        question_t5_tokens = T5Tokens(item.orig['question'], self.t5_version, self.tokenizer)
        if self.compute_sc_link:
            # We do not want to transform pieces back into words.
            sc_link = question_t5_tokens.t5_schema_linking(
                preproc_schema.normalized_column_names,
                preproc_schema.normalized_table_names
            )
        else:
            sc_link = {"q_col_match": {}, "q_tab_match": {}}

        if self.compute_cv_link:
            cv_link = question_t5_tokens.t5_cv_linking(item.schema, self.db_path)
        else:
            cv_link = {"num_date_match": {}, "cell_match": {}}

        return {
            'raw_question': item.orig['question'],
            'question': question,
            'db_id': item.schema.db_id,
            'sc_link': sc_link,
            'cv_link': cv_link,
            'columns': preproc_schema.column_names,
            'tables': preproc_schema.table_names,
            'table_bounds': preproc_schema.table_bounds,
            'column_to_table': preproc_schema.column_to_table,
            'table_to_columns': preproc_schema.table_to_columns,
            'foreign_keys': preproc_schema.foreign_keys,
            'foreign_keys_tables': preproc_schema.foreign_keys_tables,
            'primary_keys': preproc_schema.primary_keys,
        }

    def validate_item(self, item, section):
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        # 2 is for the CLS and SEP special tokens; +1 is for the SEP appended
        # to each column/table.
        num_words = len(question) + 2 + \
            sum(len(c) + 1 for c in preproc_schema.column_names) + \
            sum(len(t) + 1 for t in preproc_schema.table_names)
        if num_words > 512:
            return False, None  # remove long sequences
        else:
            return True, None

    def _preprocess_schema(self, schema):
        if schema.db_id in self.preprocessed_schemas:
            return self.preprocessed_schemas[schema.db_id]
        result = preprocess_schema_uncached_t5(schema, self.tokenizer, self._tokenize,
                                               self.include_table_name_in_column,
                                               self.fix_issue_16_primary_keys,
                                               self.t5_version,
                                               t5=True)
        self.preprocessed_schemas[schema.db_id] = result
        return result

    def save(self):
        os.makedirs(self.data_dir, exist_ok=True)
        self.tokenizer.save_pretrained(self.data_dir)
        for section, texts in self.texts.items():
            # write UTF-8 so accented characters survive round-tripping
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w', encoding='utf8') as f:
                for text in texts:
                    f.write(json.dumps(text, ensure_ascii=False) + '\n')

    def load(self):
        # Added to handle the mT5 model
        print(f"T5 load Model: {self.t5_version}")
        if self.t5_version not in ("facebook/bart-large",
                                   "facebook/mbart-large-50-many-to-many-mmt",
                                   "google/mt5-large",
                                   "google/t5-v1_1-large"):
            assert False, "Model version not defined."
        if self.t5_version == "facebook/bart-large":
            self.tokenizer = BartTokenizer.from_pretrained(self.data_dir)
        if self.t5_version == "facebook/mbart-large-50-many-to-many-mmt":
            self.tokenizer = MBart50Tokenizer.from_pretrained(self.data_dir)
        if self.t5_version == "google/mt5-large":
            self.tokenizer = MT5Tokenizer.from_pretrained(self.data_dir)
        if self.t5_version == "google/t5-v1_1-large":
            self.tokenizer = T5Tokenizer.from_pretrained(self.data_dir)
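# Toy check of the cumulative-sum index arithmetic used by the forward() methods
# in this file: the first-subword position of each column header inside the
# concatenated [question | columns | tables] sequence is recovered with
# np.cumsum. Piece counts are illustrative; the T5 layout (no CLS/SEP, hence the
# q_len - 1 offset for last pieces) is shown, whereas the BERT/BART variants use
# q_b - 2 because each header additionally ends with a SEP token.
def _demo_first_token_indexes():
    q_len = 4                # question pieces (no CLS/SEP in the T5 variant)
    col_lens = [3, 2, 4]     # pieces per column header
    first_col_idx = np.cumsum([q_len] + col_lens[:-1]).tolist()
    assert first_col_idx == [4, 7, 9]   # columns start right after the question
    # "avg" summarization also needs each header's last piece:
    last_col_idx = np.cumsum([q_len - 1] + col_lens).tolist()[1:]
    assert last_col_idx == [6, 8, 12]
    return first_col_idx, last_col_idx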
@registry.register('encoder', 'spider-t5')
class SpiderEncoderT5(torch.nn.Module):
    Preproc = SpiderEncoderT5Preproc
    batched = True

    def __init__(
            self,
            device,
            preproc,
            update_config={},
            t5_version="google/mt5-large",
            pretrained_checkpoint="pretrained_checkpoint/pytorch_model.bin",
            summarize_header="first",
            use_column_type=True,
            include_in_memory=('question', 'column', 'table')):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.base_enc_hidden_size = 1024

        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size
        self.use_column_type = use_column_type

        self.include_in_memory = set(include_in_memory)
        update_modules = {
            'relational_transformer': spider_enc_modules.RelationalTransformerUpdate,
            'none': spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=self.enc_hidden_size,
            sc_link=True,
        )

        # Added to handle the mT5 model
        print(f"SpiderEncoderT5 Model: {t5_version}")
        print("No GAP - Generation-Augmented Pre-Training")
        #print(f"SpiderEncoderT5 Pretrained Checkpoint (not used): {pretrained_checkpoint}")
        if t5_version not in ("facebook/bart-large",
                              "facebook/mbart-large-50-many-to-many-mmt",
                              "google/mt5-large",
                              "google/t5-v1_1-large"):
            assert False, "Model version not defined."
        if t5_version == "facebook/bart-large":
            self.bert_model = BartModel.from_pretrained(t5_version)
        if t5_version == "facebook/mbart-large-50-many-to-many-mmt":
            self.bert_model = MBartModel.from_pretrained(t5_version)
        if t5_version == "google/mt5-large":
            self.bert_model = MT5Model.from_pretrained(t5_version)
        if t5_version == "google/t5-v1_1-large":
            self.bert_model = T5ForConditionalGeneration.from_pretrained(t5_version)
        print(next(self.bert_model.encoder.parameters()))

        # GAP (Generation-Augmented Pre-Training) checkpoint loading is disabled
        # here; see replace_model_with_pretrained in SpiderEncoderBart.__init__.
        # Checkpoint prefixes: "bert.model.encoder." for BART, "model.encoder."
        # for MBART50, "encoder." for mT5; to find the prefix for another model,
        # print the keys inside replace_model_with_pretrained.

        self.tokenizer = self.preproc.tokenizer
        self.bert_model.resize_token_embeddings(len(self.tokenizer))
        self.bert_model = self.bert_model.encoder
        self.bert_model.decoder = None
        print(next(self.bert_model.parameters()))

    def forward(self, descs):
        batch_token_lists = []
        batch_id_to_retrieve_question = []
        batch_id_to_retrieve_column = []
        batch_id_to_retrieve_table = []
        if self.summarize_header == "avg":
            batch_id_to_retrieve_column_2 = []
            batch_id_to_retrieve_table_2 = []
        long_seq_set = set()
        batch_id_map = {}  # some long examples are not included

        for batch_idx, desc in enumerate(descs):
            # mT5 has no CLS/SEP, so no special tokens are added here
            qs = desc['question']
            if self.use_column_type:
                cols = [c for c in desc['columns']]
            else:
                cols = [c[:-1] for c in desc['columns']]
            tabs = [t for t in desc['tables']]

            token_list = qs + [c for col in cols for c in col] + \
                         [t for tab in tabs for t in tab]
            # assert self.check_bert_seq(token_list)  # no CLS/SEP to check for mT5
            if len(token_list) > 512:
                long_seq_set.add(batch_idx)
                continue

            q_b = len(qs)
            col_b = q_b + sum(len(c) for c in cols)
            # question_indexes = list(range(q_b))[1:-1]  # BART has CLS and SEP
            question_indexes = list(range(q_b))  # mT5 has no CLS/SEP
            # use the first representation for column/table
            column_indexes = \
                np.cumsum([q_b] + [len(token_list) for token_list in cols[:-1]]).tolist()
            table_indexes = \
                np.cumsum([col_b] + [len(token_list) for token_list in tabs[:-1]]).tolist()
            if self.summarize_header == "avg":
                column_indexes_2 = \
                    np.cumsum([q_b - 1] + [len(token_list) for token_list in cols]).tolist()[1:]
                table_indexes_2 = \
                    np.cumsum([col_b - 1] + [len(token_list) for token_list in tabs]).tolist()[1:]

            indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list)
            batch_token_lists.append(indexed_token_list)

            question_rep_ids = torch.LongTensor(question_indexes).to(self._device)
            batch_id_to_retrieve_question.append(question_rep_ids)
            column_rep_ids = torch.LongTensor(column_indexes).to(self._device)
            batch_id_to_retrieve_column.append(column_rep_ids)
            table_rep_ids = torch.LongTensor(table_indexes).to(self._device)
            batch_id_to_retrieve_table.append(table_rep_ids)
            if self.summarize_header == "avg":
                assert all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2))
                column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device)
                batch_id_to_retrieve_column_2.append(column_rep_ids_2)
                assert all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2))
                table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device)
                batch_id_to_retrieve_table_2.append(table_rep_ids_2)

            batch_id_map[batch_idx] = len(batch_id_map)

        padded_token_lists, att_mask_lists, tok_type_lists = \
            self.pad_sequence_for_t5_batch(batch_token_lists)
        tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
        att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)

        bert_output = self.bert_model(tokens_tensor, attention_mask=att_masks_tensor)[0]

        enc_output = bert_output

        column_pointer_maps = [
            {i: [i] for i in range(len(desc['columns']))}
            for desc in descs
        ]
        table_pointer_maps = [
            {i: [i] for i in range(len(desc['tables']))}
            for desc in descs
        ]

        assert len(long_seq_set) == 0  # remove them for now

        result = []
        for batch_idx, desc in enumerate(descs):
            c_boundary = list(range(len(desc["columns"]) + 1))
            t_boundary = list(range(len(desc["tables"]) + 1))

            if batch_idx in long_seq_set:
                q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
            else:
                bert_batch_idx = batch_id_map[batch_idx]
                q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]]
                col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]]
                tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]]

                if self.summarize_header == "avg":
                    col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]]
                    tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]]

                    col_enc = (col_enc + col_enc_2) / 2.0  # avg of first and last token
                    tab_enc = (tab_enc + tab_enc_2) / 2.0  # avg of first and last token

            assert q_enc.size()[0] == len(desc["question"])
            assert col_enc.size()[0] == c_boundary[-1]
            assert tab_enc.size()[0] == t_boundary[-1]

            q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                self.encs_update.forward_unbatched(
                    desc,
                    q_enc.unsqueeze(1),
                    col_enc.unsqueeze(1),
                    c_boundary,
                    tab_enc.unsqueeze(1),
                    t_boundary)

            # debug dump of the encoder inputs, left disabled: it would rewrite
            # descs_*.pkl on every forward pass
            # import pickle
            # pickle.dump({"desc": desc, "q_enc": q_enc, "col_enc": col_enc,
            #              "c_boundary": c_boundary, "tab_enc": tab_enc,
            #              "t_boundary": t_boundary},
            #             open("descs_{}.pkl".format(batch_idx), "wb"))

            memory = []
            if 'question' in self.include_in_memory:
                memory.append(q_enc_new_item)
            if 'column' in self.include_in_memory:
                memory.append(c_enc_new_item)
            if 'table' in self.include_in_memory:
                memory.append(t_enc_new_item)
            memory = torch.cat(memory, dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                # TODO: words should match memory
                words=desc['question'],
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': t_enc_new_item,
                },
                pointer_maps={
                    'column': column_pointer_maps[batch_idx],
                    'table': table_pointer_maps[batch_idx],
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result

    def encoder_long_seq(self, desc):
        """
        Since the encoder cannot handle sequences longer than 512, each
        column/table is encoded individually. The representation of a
        column/table is the vector of its first token.
        """
        qs = desc['question']
        cols = [c for c in desc['columns']]
        tabs = [t for t in desc['tables']]

        enc_q = self._t5_encode(qs)
        enc_col = self._t5_encode(cols)
        enc_tab = self._t5_encode(tabs)
        return enc_q, enc_col, enc_tab

    def _t5_encode(self, toks):
        if not isinstance(toks[0], list):  # encode question words
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks)
            tokens_tensor = torch.tensor([indexed_tokens]).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            # slicing kept from the BERT version; T5 has no [CLS]/[SEP]
            return outputs[0][0, 1:-1]
        else:
            max_len = max([len(it) for it in toks])
            tok_ids = []
            for item_toks in toks:
                item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks))
                indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks)
                tok_ids.append(indexed_tokens)

            tokens_tensor = torch.tensor(tok_ids).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][:, 0, :]

    def pad_sequence_for_t5_batch(self, tokens_lists):
        pad_id = self.tokenizer.pad_token_id
        max_len = max([len(it) for it in tokens_lists])
        assert max_len <= 512
        toks_ids = []
        att_masks = []
        tok_type_lists = []
        for item_toks in tokens_lists:
            padded_item_toks = item_toks + [pad_id] * (max_len - len(item_toks))
            toks_ids.append(padded_item_toks)

            _att_mask = [1] * len(item_toks) + [0] * (max_len - len(item_toks))
            att_masks.append(_att_mask)

            # T5 has no SEP token, so token types cannot be derived from the
            # first SEP position as in the BERT version; a trivial
            # [0, 1, 1, ...] pattern is used instead.
            _tok_type_list = [0] + [1] * (max_len - 1)
            tok_type_lists.append(_tok_type_list)
        return toks_ids, att_masks, tok_type_lists