antonlabate
ver 1.3
d758c99
import itertools
import operator
import numpy as np
import torch
from torch import nn
import torchtext
from seq2struct.models import variational_lstm
from seq2struct.models import transformer
from seq2struct.utils import batched_sequence
def clamp(value, abs_max):
    """Clip ``value`` into the closed interval ``[-abs_max, abs_max]``."""
    lower_bounded = max(-abs_max, value)
    return min(abs_max, lower_bounded)
def get_attn_mask(seq_lengths):
    """Build a square attention mask per batch element from sequence lengths.

    Example: ``seq_lengths`` of ``[3, 1, 2]`` produces
        [[[1, 1, 1],
          [1, 1, 1],
          [1, 1, 1]],
         [[1, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
         [[1, 1, 0],
          [1, 1, 0],
          [0, 0, 0]]]

    Args:
        seq_lengths: sequence of lengths, one per batch element.

    Returns:
        A CPU int64 tensor of shape (batch, max_length, max_length) where the
        ``[b, :n, :n]`` block is 1 for ``n = seq_lengths[b]`` and 0 elsewhere.
    """
    # int(...) so the size is a plain Python int rather than numpy.int64.
    max_length = int(max(seq_lengths))
    attn_mask = torch.zeros(
        len(seq_lengths), max_length, max_length, dtype=torch.long, device="cpu")
    # Iterating a tensor yields views along dim 0, so writes land in attn_mask.
    for mask_for_item, length in zip(attn_mask, seq_lengths):
        mask_for_item[:length, :length] = 1
    return attn_mask
class LookupEmbeddings(torch.nn.Module):
    """Token embeddings backed by a learned table and an optional pretrained embedder.

    ``vocab`` must support ``len()``, iteration over its words,
    ``index(token)`` and ``indices(tokens)``. ``embedder`` (optional) must
    expose ``dim``, ``contains(word)`` and ``lookup(word)``.

    When an embedder is given, each row of the learned table whose word the
    embedder contains is initialized with that word's pretrained vector. In
    ``forward``, tokens the embedder contains are taken from the embedder
    directly unless they appear in ``learnable_words``. Every other token
    uses the learned table.
    """

    def __init__(self, device, vocab, embedder, emb_size, learnable_words=None):
        super().__init__()
        self._device = device
        self.vocab = vocab
        self.embedder = embedder
        self.emb_size = emb_size
        # Set unconditionally so _embed_token can always read it. The None
        # default avoids sharing one mutable list across instances.
        self.learnable_words = [] if learnable_words is None else learnable_words
        self.embedding = torch.nn.Embedding(
            num_embeddings=len(self.vocab),
            embedding_dim=emb_size)
        if self.embedder:
            assert emb_size == self.embedder.dim
            # Rows for words known to the embedder start from the pretrained
            # vector. All other rows keep nn.Embedding's default init.
            init_embed_list = []
            for i, word in enumerate(self.vocab):
                if self.embedder.contains(word):
                    init_embed_list.append(self.embedder.lookup(word))
                else:
                    init_embed_list.append(self.embedding.weight[i])
            init_embed_weight = torch.stack(init_embed_list, 0)
            self.embedding.weight = nn.Parameter(init_embed_weight)

    def forward_unbatched(self, token_lists):
        """Embed the descriptions of a single example with the learned table.

        Args:
            token_lists: list of token lists, one per description (column
                name, table name, ...).

        Returns:
            ``(all_embs, boundaries)``. ``all_embs`` has shape
            (sum of desc lengths, 1, emb_size). ``boundaries`` is the
            cumulative offsets, e.g. desc lengths [2, 3, 4] give
            [0, 2, 5, 9].
        """
        embs = []
        for tokens in token_lists:
            # token_indices shape: batch (=1) x length
            token_indices = torch.tensor(
                self.vocab.indices(tokens), device=self._device).unsqueeze(0)
            # emb shape: batch (=1) x length x emb_size
            emb = self.embedding(token_indices)
            # emb shape: desc length x batch (=1) x emb_size
            emb = emb.transpose(0, 1)
            embs.append(emb)
        all_embs = torch.cat(embs, dim=0)
        boundaries = np.cumsum([0] + [emb.shape[0] for emb in embs])
        return all_embs, boundaries

    def _compute_boundaries(self, token_lists):
        """Return, per batch item, the cumulative desc-length offsets (numpy arrays).

        ``token_lists`` is [batch, num descs, desc length] nested lists.
        """
        return [
            np.cumsum([0] + [len(token_list) for token_list in token_lists_for_item])
            for token_lists_for_item in token_lists]

    def _embed_token(self, token, batch_idx):
        """Return the embedding vector for one token.

        Uses the learned table when there is no embedder, when the token is
        listed in ``learnable_words``, or when the embedder does not contain
        it. Otherwise uses the embedder's vector, moved to ``self._device``.
        ``batch_idx`` is unused and is accepted for the ``item_to_tensor``
        callback signature.
        """
        if (not self.embedder
                or token in self.learnable_words
                or not self.embedder.contains(token)):
            return self.embedding.weight[self.vocab.index(token)]
        emb = self.embedder.lookup(token)
        return emb.to(self._device)

    def forward(self, token_lists):
        """Embed a batch of examples token by token via ``_embed_token``.

        Args:
            token_lists: [batch, num descs, desc length] nested lists of tokens.

        Returns:
            ``(all_embs, boundaries)``. ``all_embs`` is a PackedSequencePlus
            of shape [batch, sum of desc lengths, emb_size]. ``boundaries``
            is as returned by ``_compute_boundaries``.
        """
        all_embs = batched_sequence.PackedSequencePlus.from_lists(
            lists=[
                [
                    token
                    for token_list in token_lists_for_item
                    for token in token_list
                ]
                for token_lists_for_item in token_lists
            ],
            item_shape=(self.emb_size,),
            device=self._device,
            item_to_tensor=self._embed_token)
        all_embs = all_embs.apply(lambda d: d.to(self._device))
        return all_embs, self._compute_boundaries(token_lists)

    def _embed_words_learned(self, token_lists):
        """Like ``forward`` but always uses the learned table, via index lookup.

        Args:
            token_lists: [batch, num descs, desc length] nested lists of tokens.

        Returns:
            ``(all_embs, boundaries)``, where ``all_embs`` is a
            PackedSequencePlus of shape [batch, sum of desc lengths, emb_size].
        """
        indices = batched_sequence.PackedSequencePlus.from_lists(
            lists=[
                [
                    token
                    for token_list in token_lists_for_item
                    for token in token_list
                ]
                for token_lists_for_item in token_lists
            ],
            item_shape=(1,),  # For compatibility with old PyTorch versions
            tensor_type=torch.LongTensor,
            item_to_tensor=lambda token, batch_idx, out: out.fill_(self.vocab.index(token))
        )
        indices = indices.apply(lambda d: d.to(self._device))
        # Drop the trailing singleton index dim before the embedding lookup.
        all_embs = indices.apply(lambda x: self.embedding(x.squeeze(-1)))
        return all_embs, self._compute_boundaries(token_lists)
class EmbLinear(torch.nn.Module):
    """Project every embedding in a packed batch with one shared linear layer.

    The boundaries that accompany the embeddings are passed through untouched.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, output_size)

    def forward(self, input_):
        # input_ is an (embeddings, boundaries) pair; the embeddings object
        # must provide .apply(fn), presumably applying fn to its data tensor.
        embeddings, desc_boundaries = input_
        projected = embeddings.apply(self.linear)
        return projected, desc_boundaries
class BiLSTM(torch.nn.Module):
    """Bidirectional LSTM run independently over each description.

    A "description" is one contiguous span of tokens (e.g. one column name or
    one table name) delimited by boundaries.

    Args:
        input_size: dimensionality of the input embeddings.
        output_size: dimensionality of the output. Each direction gets
            ``output_size // 2`` hidden units.
        dropout: dropout probability.
        summarize: if True, each description is reduced to a single vector
            (the concatenated final forward/backward hidden states). If False,
            the per-token LSTM outputs are returned.
        use_native: use ``torch.nn.LSTM`` with dropout on the inputs instead of
            ``variational_lstm.LSTM``.
    """

    def __init__(self, input_size, output_size, dropout, summarize, use_native=False):
        super().__init__()
        hidden_size = int(output_size // 2)
        if use_native:
            # nn.LSTM's own `dropout` only acts between stacked layers; with a
            # single layer it does nothing and emits a UserWarning. Dropout is
            # applied to the inputs through self.dropout instead.
            self.lstm = torch.nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                bidirectional=True)
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.lstm = variational_lstm.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                bidirectional=True,
                dropout=dropout)
        self.summarize = summarize
        self.use_native = use_native

    def forward_unbatched(self, input_):
        """Run the LSTM over each description of a single example.

        Args:
            input_: ``(all_embs, boundaries)``. ``all_embs`` has shape
                (sum of desc lengths, 1, input_size). ``boundaries`` holds the
                cumulative offsets of each description.

        Returns:
            ``(outputs, new_boundaries)``. With ``summarize`` the outputs have
            shape (num descs, 1, output_size) and boundaries are [0, 1, ..., n].
            Otherwise they have shape (sum of desc lengths, 1, output_size)
            with the same offsets as the input.
        """
        all_embs, boundaries = input_
        new_boundaries = [0]
        outputs = []
        for left, right in zip(boundaries, boundaries[1:]):
            # h, c shape: num_layers (=1) * num_directions (=2) x 1 x output_size / 2
            # output shape: desc length x 1 x output_size
            if self.use_native:
                inp = self.dropout(all_embs[left:right])
                output, (h, c) = self.lstm(inp)
            else:
                output, (h, c) = self.lstm(all_embs[left:right])
            if self.summarize:
                # Concatenate final forward and backward states: 1 x 1 x output_size
                seq_emb = torch.cat((h[0], h[1]), dim=-1).unsqueeze(0)
                new_boundaries.append(new_boundaries[-1] + 1)
            else:
                seq_emb = output
                new_boundaries.append(new_boundaries[-1] + output.shape[0])
            outputs.append(seq_emb)
        return torch.cat(outputs, dim=0), new_boundaries

    def forward(self, input_):
        """Run the LSTM over every description of every example in a batch.

        Args:
            input_: ``(all_embs, boundaries)``. ``all_embs`` is a
                PackedSequencePlus of shape [batch, sum of desc lengths,
                input_size]. ``boundaries`` is a list (per example) of
                cumulative offsets, shape [batch, num descs + 1].

        Returns:
            ``(new_all_embs, new_boundaries)``. With ``summarize``,
            new_all_embs is [batch, num descs, output_size] and boundaries are
            ``range(num descs + 1)`` per example. Otherwise it is
            [batch, sum of desc lengths, output_size] with unchanged boundaries.
        """
        all_embs, boundaries = input_

        # One entry per description: (batch_idx, desc_idx, length).
        desc_lengths = []
        batch_desc_to_flat_map = {}
        for batch_idx, boundaries_for_item in enumerate(boundaries):
            for desc_idx, (left, right) in enumerate(zip(boundaries_for_item, boundaries_for_item[1:])):
                desc_lengths.append((batch_idx, desc_idx, right - left))
                batch_desc_to_flat_map[batch_idx, desc_idx] = len(batch_desc_to_flat_map)

        # Rearrange all_embs into a PackedSequencePlus of shape
        # [batch * num descs, desc length, input_size] so every description is
        # its own sequence. remapped_ps_indices records, for each packed row of
        # the rearranged sequence, its row in all_embs.ps.data. It is filled as
        # a side effect of from_gather below.
        remapped_ps_indices = []

        def rearranged_all_embs_map_index(desc_lengths_idx, seq_idx):
            batch_idx, desc_idx, _ = desc_lengths[desc_lengths_idx]
            return batch_idx, boundaries[batch_idx][desc_idx] + seq_idx

        def rearranged_all_embs_gather_from_indices(indices):
            batch_indices, seq_indices = zip(*indices)
            remapped_ps_indices[:] = all_embs.raw_index(batch_indices, seq_indices)
            return all_embs.ps.data[torch.LongTensor(remapped_ps_indices)]

        rearranged_all_embs = batched_sequence.PackedSequencePlus.from_gather(
            lengths=[length for _, _, length in desc_lengths],
            map_index=rearranged_all_embs_map_index,
            gather_from_indices=rearranged_all_embs_gather_from_indices)
        # Inverse permutation: maps rows of all_embs.ps.data back to rows of
        # the rearranged packed data.
        rev_remapped_ps_indices = tuple(
            x[0] for x in sorted(
                enumerate(remapped_ps_indices), key=operator.itemgetter(1)))

        # output: PackedSequence, [batch * num descs, desc length, output_size]
        # h, c: [num_layers (=1) * num_directions (=2), batch * num descs, output_size / 2]
        if self.use_native:
            rearranged_all_embs = rearranged_all_embs.apply(self.dropout)
        output, (h, c) = self.lstm(rearranged_all_embs.ps)
        if self.summarize:
            # h shape: [batch * num descs, output_size]
            h = torch.cat((h[0], h[1]), dim=-1)
            # new_all_embs: PackedSequencePlus, [batch, num descs, output_size]
            new_all_embs = batched_sequence.PackedSequencePlus.from_gather(
                lengths=[len(boundaries_for_item) - 1 for boundaries_for_item in boundaries],
                map_index=lambda batch_idx, desc_idx: rearranged_all_embs.sort_to_orig[batch_desc_to_flat_map[batch_idx, desc_idx]],
                gather_from_indices=lambda indices: h[torch.LongTensor(indices)])
            new_boundaries = [
                list(range(len(boundaries_for_item)))
                for boundaries_for_item in boundaries
            ]
        else:
            # Put the per-token outputs back in all_embs' packed order.
            new_all_embs = all_embs.apply(
                lambda _: output.data[torch.LongTensor(rev_remapped_ps_indices)])
            new_boundaries = boundaries
        return new_all_embs, new_boundaries
class RelationalTransformerUpdate(torch.nn.Module):
    """Relation-aware transformer over question, column and table encodings.

    Question, column and table encodings are concatenated into one sequence.
    ``compute_relations`` assigns a discrete relation id to every ordered pair
    of positions, from the schema description and optional schema/value
    linking results. The encoder attends using those relation ids.

    Relation ids are assigned in registration order in ``__init__``, so
    reordering the registrations changes which id each relation receives.
    """

    def __init__(self, device, num_layers, num_heads, hidden_size,
                 ff_size=None,
                 dropout=0.1,
                 merge_types=False,
                 tie_layers=False,
                 qq_max_dist=2,
                 cc_foreign_key=True,
                 cc_table_match=True,
                 cc_max_dist=2,
                 ct_foreign_key=True,
                 ct_table_match=True,
                 tc_table_match=True,
                 tc_foreign_key=True,
                 tt_max_dist=2,
                 tt_foreign_key=True,
                 sc_link=False,
                 cv_link=False,
                 ):
        super().__init__()
        self._device = device
        self.num_heads = num_heads
        self.qq_max_dist = qq_max_dist
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        # Relation name (a string, or a (name, distance) tuple) -> id.
        self.relation_ids = {}

        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)

        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))

        add_rel_dist('qq_dist', qq_max_dist)

        add_relation('qc_default')
        add_relation('qt_default')
        add_relation('cq_default')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # Schema-linking relations, registered in both directions. The names
        # are "qc"/"cq"/"qt"/"tq" + the value stored in desc['sc_link'].
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')

        # Value-linking relations: "qc"/"cq" + the value stored in desc['cv_link'].
        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")

        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key

            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            add_relation('xx_default')
            xx_default = self.relation_ids['xx_default']
            merged_names = [
                'qc_default', 'qt_default', 'cq_default', 'cc_default',
                'ct_default', 'tq_default', 'tc_default', 'tt_default',
            ]
            if sc_link:
                merged_names += [
                    'qcCEM', 'qcCPM', 'qtTEM', 'qtTPM',
                    'cqCEM', 'cqCPM', 'tqTEM', 'tqTPM',
                ]
            if cv_link:
                merged_names += [
                    "qcNUMBER", "cqNUMBER", "qcTIME", "cqTIME",
                    "qcCELLMATCH", "cqCELLMATCH",
                ]
            for name in merged_names:
                self.relation_ids[name] = xx_default

            # Column-column and table-table distance relations share the
            # question-question distance ids. (The table case previously
            # assigned 'tt_dist' to itself and so was never merged.)
            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

        if ff_size is None:
            ff_size = hidden_size * 4
        self.encoder = transformer.Encoder(
            lambda: transformer.EncoderLayer(
                hidden_size,
                transformer.MultiHeadedAttentionWithRelations(
                    num_heads,
                    hidden_size,
                    dropout),
                transformer.PositionwiseFeedForward(
                    hidden_size,
                    ff_size,
                    dropout),
                len(self.relation_ids),
                dropout),
            hidden_size,
            num_layers,
            tie_layers)

        self.align_attn = transformer.PointerWithRelations(hidden_size,
                                                           len(self.relation_ids), dropout)

    def create_align_mask(self, num_head, q_length, c_length, t_length):
        """Build a (num_head, all_len, all_len) float mask, all_len = q + c + t.

        The first ``num_head - 1`` heads are all ones. The last head is 1 only
        between question positions and column positions (both directions) and
        0 elsewhere.
        """
        all_length = q_length + c_length + t_length
        mask_1 = torch.ones(num_head - 1, all_length, all_length, device=self._device)
        mask_2 = torch.zeros(1, all_length, all_length, device=self._device)
        c_end = q_length + c_length
        mask_2[0, :q_length, q_length:c_end] = 1
        mask_2[0, q_length:c_end, :q_length] = 1
        return torch.cat([mask_1, mask_2], 0)

    def forward_unbatched(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        """Update the encodings of a single example.

        Args:
            desc: schema/linking description dict (see ``compute_relations``).
            q_enc, c_enc, t_enc: tensors of shape (len, 1, hidden), concatenated
                along dim 0.
            c_boundaries, t_boundaries: cumulative offsets of each column /
                table inside c_enc / t_enc.

        Returns:
            ``(q_enc_new, c_enc_new, t_enc_new, (m2c_align_mat, m2t_align_mat))``.
            The encodings are batch-first, (1, len, hidden).
        """
        # enc shape: total len x batch (=1) x recurrent size
        enc = torch.cat((q_enc, c_enc, t_enc), dim=0)
        # enc shape: batch (=1) x total len x recurrent size
        enc = enc.transpose(0, 1)

        relations = self.compute_relations(
            desc,
            enc_length=enc.shape[1],
            q_enc_length=q_enc.shape[0],
            c_enc_length=c_enc.shape[0],
            c_boundaries=c_boundaries,
            t_boundaries=t_boundaries)

        relations_t = torch.LongTensor(relations).to(self._device)
        enc_new = self.encoder(enc, relations_t, mask=None)

        # Split the updated sequence back into its three parts.
        c_base = q_enc.shape[0]
        t_base = q_enc.shape[0] + c_enc.shape[0]
        q_enc_new = enc_new[:, :c_base]
        c_enc_new = enc_new[:, c_base:t_base]
        t_enc_new = enc_new[:, t_base:]

        m2c_align_mat = self.align_attn(enc_new, enc_new[:, c_base:t_base],
                                        enc_new[:, c_base:t_base], relations_t[:, c_base:t_base])
        m2t_align_mat = self.align_attn(enc_new, enc_new[:, t_base:],
                                        enc_new[:, t_base:], relations_t[:, t_base:])
        return q_enc_new, c_enc_new, t_enc_new, (m2c_align_mat, m2t_align_mat)

    def forward(self, descs, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        """Update the encodings of a batch of examples.

        ``q_enc``, ``c_enc`` and ``t_enc`` are PackedSequencePlus objects.
        Returns ``(q_enc_new, c_enc_new, t_enc_new)`` as PackedSequencePlus
        with the same per-example lengths as the inputs.
        """
        # TODO: Update to also compute m2c_align_mat and m2t_align_mat
        # enc: PackedSequencePlus with shape [batch, total len, recurrent size]
        enc = batched_sequence.PackedSequencePlus.cat_seqs((q_enc, c_enc, t_enc))

        q_enc_lengths = list(q_enc.orig_lengths())
        c_enc_lengths = list(c_enc.orig_lengths())
        t_enc_lengths = list(t_enc.orig_lengths())
        enc_lengths = list(enc.orig_lengths())
        max_enc_length = max(enc_lengths)

        all_relations = []
        for batch_idx, desc in enumerate(descs):
            enc_length = enc_lengths[batch_idx]
            relations_for_item = self.compute_relations(
                desc,
                enc_length,
                q_enc_lengths[batch_idx],
                c_enc_lengths[batch_idx],
                c_boundaries[batch_idx],
                t_boundaries[batch_idx])
            # Zero-pad both axes up to the longest example in the batch.
            all_relations.append(np.pad(relations_for_item, ((0, max_enc_length - enc_length),), 'constant'))
        relations_t = torch.from_numpy(np.stack(all_relations)).to(self._device)

        # mask shape: [batch, total len, total len]
        mask = get_attn_mask(enc_lengths).to(self._device)
        # enc_new: shape [batch, total len, recurrent size]
        enc_padded, _ = enc.pad(batch_first=True)
        enc_new = self.encoder(enc_padded, relations_t, mask=mask)

        # Split enc_new back into question / column / table sequences.
        def gather_from_enc_new(indices):
            batch_indices, seq_indices = zip(*indices)
            return enc_new[torch.LongTensor(batch_indices), torch.LongTensor(seq_indices)]

        q_enc_new = batched_sequence.PackedSequencePlus.from_gather(
            lengths=q_enc_lengths,
            map_index=lambda batch_idx, seq_idx: (batch_idx, seq_idx),
            gather_from_indices=gather_from_enc_new)
        c_enc_new = batched_sequence.PackedSequencePlus.from_gather(
            lengths=c_enc_lengths,
            map_index=lambda batch_idx, seq_idx: (batch_idx, q_enc_lengths[batch_idx] + seq_idx),
            gather_from_indices=gather_from_enc_new)
        t_enc_new = batched_sequence.PackedSequencePlus.from_gather(
            lengths=t_enc_lengths,
            map_index=lambda batch_idx, seq_idx: (batch_idx, q_enc_lengths[batch_idx] + c_enc_lengths[batch_idx] + seq_idx),
            gather_from_indices=gather_from_enc_new)
        return q_enc_new, c_enc_new, t_enc_new

    def compute_relations(self, desc, enc_length, q_enc_length, c_enc_length, c_boundaries, t_boundaries):
        """Return an (enc_length, enc_length) int64 array of relation ids.

        Entry ``[i, j]`` is the id of the relation from position ``i`` to
        position ``j``. Positions are the question tokens, then the column
        tokens, then the table tokens.

        ``desc`` keys read here are 'foreign_keys', 'column_to_table',
        'primary_keys', 'foreign_keys_tables', and the optional 'sc_link' /
        'cv_link'. The linking dicts are keyed by strings
        "<question position>,<offset within column/table encodings>".
        """
        sc_link = desc.get('sc_link', {'q_col_match': {}, 'q_tab_match': {}})
        cv_link = desc.get('cv_link', {'num_date_match': {}, 'cell_match': {}})

        # Position -> ('question',) | ('column', column id) | ('table', table id)
        loc_types = {}
        for i in range(q_enc_length):
            loc_types[i] = ('question',)

        c_base = q_enc_length
        for c_id, (c_start, c_end) in enumerate(zip(c_boundaries, c_boundaries[1:])):
            for i in range(c_start + c_base, c_end + c_base):
                loc_types[i] = ('column', c_id)
        t_base = q_enc_length + c_enc_length
        for t_id, (t_start, t_end) in enumerate(zip(t_boundaries, t_boundaries[1:])):
            for i in range(t_start + t_base, t_end + t_base):
                loc_types[i] = ('table', t_id)

        relations = np.empty((enc_length, enc_length), dtype=np.int64)

        for i, j in itertools.product(range(enc_length), repeat=2):
            def set_relation(name):
                relations[i, j] = self.relation_ids[name]

            i_type, j_type = loc_types[i], loc_types[j]
            if i_type[0] == 'question':
                if j_type[0] == 'question':
                    set_relation(('qq_dist', clamp(j - i, self.qq_max_dist)))
                elif j_type[0] == 'column':
                    key = f"{i},{j - c_base}"
                    if key in sc_link["q_col_match"]:
                        set_relation("qc" + sc_link["q_col_match"][key])
                    elif key in cv_link["cell_match"]:
                        set_relation("qc" + cv_link["cell_match"][key])
                    elif key in cv_link["num_date_match"]:
                        set_relation("qc" + cv_link["num_date_match"][key])
                    else:
                        set_relation('qc_default')
                elif j_type[0] == 'table':
                    key = f"{i},{j - t_base}"
                    if key in sc_link["q_tab_match"]:
                        set_relation("qt" + sc_link["q_tab_match"][key])
                    else:
                        set_relation('qt_default')

            elif i_type[0] == 'column':
                if j_type[0] == 'question':
                    key = f"{j},{i - c_base}"
                    if key in sc_link["q_col_match"]:
                        set_relation("cq" + sc_link["q_col_match"][key])
                    elif key in cv_link["cell_match"]:
                        set_relation("cq" + cv_link["cell_match"][key])
                    elif key in cv_link["num_date_match"]:
                        set_relation("cq" + cv_link["num_date_match"][key])
                    else:
                        set_relation('cq_default')
                elif j_type[0] == 'column':
                    col1, col2 = i_type[1], j_type[1]
                    if col1 == col2:
                        set_relation(('cc_dist', clamp(j - i, self.cc_max_dist)))
                    else:
                        # Later matches override earlier ones.
                        set_relation('cc_default')
                        if self.cc_foreign_key:
                            if desc['foreign_keys'].get(str(col1)) == col2:
                                set_relation('cc_foreign_key_forward')
                            if desc['foreign_keys'].get(str(col2)) == col1:
                                set_relation('cc_foreign_key_backward')
                        if (self.cc_table_match and
                                desc['column_to_table'][str(col1)] == desc['column_to_table'][str(col2)]):
                            set_relation('cc_table_match')
                elif j_type[0] == 'table':
                    col, table = i_type[1], j_type[1]
                    set_relation('ct_default')
                    if self.ct_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('ct_foreign_key')
                    if self.ct_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('ct_primary_key')
                            else:
                                set_relation('ct_table_match')
                        elif col_table is None:
                            set_relation('ct_any_table')

            elif i_type[0] == 'table':
                if j_type[0] == 'question':
                    key = f"{j},{i - t_base}"
                    if key in sc_link["q_tab_match"]:
                        set_relation("tq" + sc_link["q_tab_match"][key])
                    else:
                        set_relation('tq_default')
                elif j_type[0] == 'column':
                    table, col = i_type[1], j_type[1]
                    set_relation('tc_default')
                    if self.tc_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('tc_foreign_key')
                    if self.tc_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('tc_primary_key')
                            else:
                                set_relation('tc_table_match')
                        elif col_table is None:
                            set_relation('tc_any_table')
                elif j_type[0] == 'table':
                    table1, table2 = i_type[1], j_type[1]
                    if table1 == table2:
                        set_relation(('tt_dist', clamp(j - i, self.tt_max_dist)))
                    else:
                        set_relation('tt_default')
                        if self.tt_foreign_key:
                            forward = table2 in desc['foreign_keys_tables'].get(str(table1), ())
                            backward = table1 in desc['foreign_keys_tables'].get(str(table2), ())
                            if forward and backward:
                                set_relation('tt_foreign_key_both')
                            elif forward:
                                set_relation('tt_foreign_key_forward')
                            elif backward:
                                set_relation('tt_foreign_key_backward')
        return relations

    @classmethod
    def match_foreign_key(cls, desc, col, table):
        """Return True if column ``col`` is a foreign key into a column of ``table``.

        ``desc['foreign_keys']`` maps str(column id) to the referenced column
        id. ``desc['column_to_table']`` maps str(column id) to its table id.

        NOTE: previously this compared the column's own table with the
        referenced column's table and ignored ``table``. Models trained with
        the old behavior will see different ct/tc_foreign_key relations.
        """
        foreign_key_for = desc['foreign_keys'].get(str(col))
        if foreign_key_for is None:
            return False
        foreign_table = desc['column_to_table'][str(foreign_key_for)]
        return foreign_table == table
class NoOpUpdate:
    """Update step that leaves the question/column/table encodings unchanged.

    Exposes the same call signatures as RelationalTransformerUpdate so it can
    be swapped in when no cross-encoding update is wanted.
    """

    def __init__(self, device, hidden_size):
        # Accepted only for interface compatibility; nothing is stored.
        del device, hidden_size

    def __call__(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        return (q_enc, c_enc, t_enc)

    def forward_unbatched(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        """Mirror RelationalTransformerUpdate.forward_unbatched without updating.

        Returns the three encodings with dims 0 and 1 swapped, followed by a
        ``(None, None)`` placeholder for the alignment matrices.
        """
        swapped = tuple(enc.transpose(0, 1) for enc in (q_enc, c_enc, t_enc))
        return swapped + ((None, None),)