# gap-text2sql-main/relogic/pretrainkit/models/relationalsemparse/relational_transformer.py
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import entmax
import numpy as np
import itertools

def clamp(value, abs_max):
    value = max(-abs_max, value)
    value = min(abs_max, value)
    return value

# Adapted from
# https://github.com/tensorflow/tensor2tensor/blob/0b156ac533ab53f65f44966381f6e147c7371eee/tensor2tensor/layers/common_attention.py
def relative_attention_logits(query, key, relation):
    # We can't reuse the same logic as tensor2tensor because we don't share
    # relation vectors across the batch.
    # In this version, relation vectors are shared across heads.
    # query: [batch, heads, num queries, depth].
    # key: [batch, heads, num kvs, depth].
    # relation: [batch, num queries, num kvs, depth].

    # qk_matmul is [batch, heads, num queries, num kvs]
    qk_matmul = torch.matmul(query, key.transpose(-2, -1))

    # q_t is [batch, num queries, heads, depth]
    q_t = query.permute(0, 2, 1, 3)

    # r_t is [batch, num queries, depth, num kvs]
    r_t = relation.transpose(-2, -1)

    #   [batch, num queries, heads, depth]
    # * [batch, num queries, depth, num kvs]
    # = [batch, num queries, heads, num kvs]
    # For each batch and query, we have a query vector per head.
    # We take its dot product with the relation vector for each kv.
    q_tr_t_matmul = torch.matmul(q_t, r_t)

    # q_tr_tmatmul_t is [batch, heads, num queries, num kvs]
    q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3)

    # [batch, heads, num queries, num kvs]
    return (qk_matmul + q_tr_tmatmul_t) / math.sqrt(query.shape[-1])

# Sharing relation vectors across batch and heads:
# query: [batch, heads, num queries, depth].
# key: [batch, heads, num kvs, depth].
# relation: [num queries, num kvs, depth].
#
# Then take
# query reshaped
#   [num queries, batch * heads, depth]
# relation.transpose(-2, -1)
#   [num queries, depth, num kvs]
# and multiply them together.
#
# Without sharing relation vectors across heads:
# query: [batch, heads, num queries, depth].
# key: [batch, heads, num kvs, depth].
# relation: [batch, heads, num queries, num kvs, depth].
#
# Then take
# query.unsqueeze(3)
#   [batch, heads, num queries, 1, depth]
# relation.transpose(-2, -1)
#   [batch, heads, num queries, depth, num kvs]
# and multiply them together:
#   [batch, heads, num queries, 1, depth]
# * [batch, heads, num queries, depth, num kvs]
# = [batch, heads, num queries, 1, num kvs]
# and squeeze
# [batch, heads, num queries, num kvs]
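
# A minimal shape-check sketch (the `_demo_*` helpers and toy sizes used in this
# file are hypothetical additions for illustration): with batch=2, heads=4,
# 3 queries/kvs and depth=16, the relation-aware logits come out as
# [batch, heads, num queries, num kvs], as the comments above describe.
def _demo_relative_attention_logits():
    query = torch.randn(2, 4, 3, 16)     # [batch, heads, num queries, depth]
    key = torch.randn(2, 4, 3, 16)       # [batch, heads, num kvs, depth]
    relation = torch.randn(2, 3, 3, 16)  # [batch, num queries, num kvs, depth]
    logits = relative_attention_logits(query, key, relation)
    assert logits.shape == (2, 4, 3, 3)  # [batch, heads, num queries, num kvs]
    return logits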

def relative_attention_values(weight, value, relation):
    # In this version, relation vectors are shared across heads.
    # weight: [batch, heads, num queries, num kvs].
    # value: [batch, heads, num kvs, depth].
    # relation: [batch, num queries, num kvs, depth].

    # wv_matmul is [batch, heads, num queries, depth]
    wv_matmul = torch.matmul(weight, value)

    # w_t is [batch, num queries, heads, num kvs]
    w_t = weight.permute(0, 2, 1, 3)

    #   [batch, num queries, heads, num kvs]
    # * [batch, num queries, num kvs, depth]
    # = [batch, num queries, heads, depth]
    w_tr_matmul = torch.matmul(w_t, relation)

    # w_tr_matmul_t is [batch, heads, num queries, depth]
    w_tr_matmul_t = w_tr_matmul.permute(0, 2, 1, 3)

    return wv_matmul + w_tr_matmul_t
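
# Companion sketch for relative_attention_values (again a hypothetical `_demo_*`
# helper with toy sizes): a uniform attention weight over 3 kvs yields an output
# of shape [batch, heads, num queries, depth], matching wv_matmul above.
def _demo_relative_attention_values():
    weight = torch.full((2, 4, 3, 3), 1.0 / 3)  # [batch, heads, num queries, num kvs]
    value = torch.randn(2, 4, 3, 16)            # [batch, heads, num kvs, depth]
    relation = torch.randn(2, 3, 3, 16)         # [batch, num queries, num kvs, depth]
    out = relative_attention_values(weight, value, relation)
    assert out.shape == (2, 4, 3, 16)
    return out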

# Adapted from The Annotated Transformer
def clones(module_fn, N):
    return nn.ModuleList([module_fn() for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # return torch.matmul(p_attn, value), scores.squeeze(1).squeeze(1)
    return torch.matmul(p_attn, value), p_attn

def sparse_attention(query, key, value, alpha, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention' with a sparse (entmax) normalizer"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    if alpha == 2:
        p_attn = entmax.sparsemax(scores, -1)
    elif alpha == 1.5:
        p_attn = entmax.entmax15(scores, -1)
    else:
        raise NotImplementedError
    if dropout is not None:
        p_attn = dropout(p_attn)
    # return torch.matmul(p_attn, value), scores.squeeze(1).squeeze(1)
    return torch.matmul(p_attn, value), p_attn
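
# Sketch contrasting the dense and sparse attention helpers above (hypothetical
# `_demo_*` helper, toy sizes): entmax.sparsemax (alpha=2) can assign exact zeros
# to low-scoring keys, while softmax keeps every weight strictly positive.
def _demo_sparse_vs_dense_attention():
    query = torch.randn(1, 1, 2, 8)  # [batch, heads, num queries, depth]
    key = torch.randn(1, 1, 5, 8)    # [batch, heads, num kvs, depth]
    value = torch.randn(1, 1, 5, 8)  # [batch, heads, num kvs, depth]
    _, p_dense = attention(query, key, value)
    _, p_sparse = sparse_attention(query, key, value, alpha=2)
    assert (p_dense > 0).all()       # softmax weights are never exactly zero
    return p_dense, p_sparse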

# Adapted from The Annotated Transformer
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(lambda: nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        if query.dim() == 3:
            x = x.squeeze(1)
        return self.linears[-1](x)
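
# Usage sketch for MultiHeadedAttention (hypothetical `_demo_*` helper, toy
# sizes): inputs are [batch, seq, d_model], the output keeps that shape, and
# .attn stores the per-head attention weights from the last forward pass.
def _demo_multi_headed_attention():
    mha = MultiHeadedAttention(h=4, d_model=32, dropout=0.0)
    x = torch.randn(2, 5, 32)
    out = mha(x, x, x)
    assert out.shape == (2, 5, 32)
    assert mha.attn.shape == (2, 4, 5, 5)  # [batch, heads, num queries, num kvs]
    return out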

# Adapted from The Annotated Transformer
def attention_with_relations(query, key, value, relation_k, relation_v, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = relative_attention_logits(query, key, relation_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn_orig = F.softmax(scores, dim=-1)
    # Apply dropout to the attention weights if provided; otherwise use them as-is.
    p_attn = dropout(p_attn_orig) if dropout is not None else p_attn_orig
    return relative_attention_values(p_attn, value, relation_v), p_attn_orig

class PointerWithRelations(nn.Module):
    # Single-head relation-aware attention used as a pointer: forward() returns
    # the attention distribution over the keys (for batch element 0).
    def __init__(self, hidden_size, num_relation_kinds, dropout=0.2):
        super(PointerWithRelations, self).__init__()
        self.hidden_size = hidden_size
        self.linears = clones(lambda: nn.Linear(hidden_size, hidden_size), 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

        self.relation_k_emb = nn.Embedding(num_relation_kinds, self.hidden_size)
        self.relation_v_emb = nn.Embedding(num_relation_kinds, self.hidden_size)

    def forward(self, query, key, value, relation, mask=None):
        relation_k = self.relation_k_emb(relation)
        relation_v = self.relation_v_emb(relation)

        if mask is not None:
            mask = mask.unsqueeze(0)
        nbatches = query.size(0)

        query, key, value = \
            [l(x).view(nbatches, -1, 1, self.hidden_size).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        _, self.attn = attention_with_relations(
            query,
            key,
            value,
            relation_k,
            relation_v,
            mask=mask,
            dropout=self.dropout)

        return self.attn[0, 0]

# Adapted from The Annotated Transformer
class MultiHeadedAttentionWithRelations(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttentionWithRelations, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(lambda: nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, relation_k, relation_v, mask=None):
        # query shape: [batch, num queries, d_model]
        # key shape: [batch, num kv, d_model]
        # value shape: [batch, num kv, d_model]
        # relation_k shape: [batch, num queries, num kv, (d_model // h)]
        # relation_v shape: [batch, num queries, num kv, (d_model // h)]
        # mask shape: [batch, num queries, num kv]
        if mask is not None:
            # Same mask applied to all h heads.
            # mask shape: [batch, 1, num queries, num kv]
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        # x shape: [batch, heads, num queries, depth]
        x, self.attn = attention_with_relations(
            query,
            key,
            value,
            relation_k,
            relation_v,
            mask=mask,
            dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
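
# Usage sketch for MultiHeadedAttentionWithRelations (hypothetical `_demo_*`
# helper): the random relation_k/relation_v tensors below stand in for the
# learned relation embeddings that are looked up per relation id further down.
def _demo_multi_headed_attention_with_relations():
    h, d_model = 4, 32
    d_k = d_model // h
    mha = MultiHeadedAttentionWithRelations(h, d_model, dropout=0.0)
    x = torch.randn(2, 5, d_model)          # [batch, seq, d_model]
    relation_k = torch.randn(2, 5, 5, d_k)  # [batch, num queries, num kv, d_k]
    relation_v = torch.randn(2, 5, 5, d_k)
    out = mha(x, x, x, relation_k, relation_v)
    assert out.shape == (2, 5, d_model)
    return out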

# Adapted from The Annotated Transformer
class Encoder(nn.Module):
    "Core encoder is a stack of N layers"
    def __init__(self, layer, layer_size, N, tie_layers=False):
        super(Encoder, self).__init__()
        if tie_layers:
            self.layer = layer()
            self.layers = [self.layer for _ in range(N)]
        else:
            self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer_size)

        # TODO initialize using xavier

    def forward(self, x, relation, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, relation, mask)
        return self.norm(x)

# Adapted from The Annotated Transformer
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

# Adapted from The Annotated Transformer
class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, num_relation_kinds, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(lambda: SublayerConnection(size, dropout), 2)
        self.size = size

        self.relation_k_emb = nn.Embedding(num_relation_kinds, self.self_attn.d_k)
        self.relation_v_emb = nn.Embedding(num_relation_kinds, self.self_attn.d_k)

    def forward(self, x, relation, mask):
        "Follow Figure 1 (left) for connections."
        relation_k = self.relation_k_emb(relation)
        relation_v = self.relation_v_emb(relation)

        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, relation_k, relation_v, mask))
        return self.sublayer[1](x, self.feed_forward)

# Adapted from The Annotated Transformer
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

# Module-level relation-type vocabulary (relation name -> id) and a relation-aware
# encoder built from it with the hyperparameters below.
relation_ids = {
    ('qq_dist', -2): 0, ('qq_dist', -1): 1, ('qq_dist', 0): 2,
    ('qq_dist', 1): 3, ('qq_dist', 2): 4, 'qc_default': 5,
    'qt_default': 6, 'cq_default': 7, 'cc_default': 8,
    'cc_foreign_key_forward': 9, 'cc_foreign_key_backward': 10,
    'cc_table_match': 11, ('cc_dist', -2): 12, ('cc_dist', -1): 13,
    ('cc_dist', 0): 14, ('cc_dist', 1): 15, ('cc_dist', 2): 16,
    'ct_default': 17, 'ct_foreign_key': 18, 'ct_primary_key': 19,
    'ct_table_match': 20, 'ct_any_table': 21, 'tq_default': 22,
    'tc_default': 23, 'tc_primary_key': 24, 'tc_table_match': 25,
    'tc_any_table': 26, 'tc_foreign_key': 27, 'tt_default': 28,
    'tt_foreign_key_forward': 29, 'tt_foreign_key_backward': 30,
    'tt_foreign_key_both': 31, ('tt_dist', -2): 32, ('tt_dist', -1): 33,
    ('tt_dist', 0): 34, ('tt_dist', 1): 35, ('tt_dist', 2): 36, 'qcCEM': 37,
    'cqCEM': 38, 'qtTEM': 39, 'tqTEM': 40, 'qcCPM': 41, 'cqCPM': 42,
    'qtTPM': 43, 'tqTPM': 44, 'qcNUMBER': 45, 'cqNUMBER': 46, 'qcTIME': 47,
    'cqTIME': 48, 'qcCELLMATCH': 49, 'cqCELLMATCH': 50}

num_heads = 8
hidden_size = 1024
ff_size = 4096
dropout = 0.1
num_layers = 8
tie_layers = False

encoder = Encoder(
    lambda: EncoderLayer(
        hidden_size,
        MultiHeadedAttentionWithRelations(
            num_heads,
            hidden_size,
            dropout),
        PositionwiseFeedForward(
            hidden_size,
            ff_size,
            dropout),
        len(relation_ids),
        dropout),
    hidden_size,
    num_layers,
    tie_layers)
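
# End-to-end sketch for the module-level encoder configured above (hypothetical
# `_demo_*` helper, toy sequence length): x is a batch of hidden states and
# `relation` an integer matrix of relation-type ids drawn from relation_ids;
# the encoder returns updated states of the same shape.
def _demo_relational_encoder():
    seq_len = 6
    x = torch.randn(1, seq_len, hidden_size)
    relation = torch.randint(len(relation_ids), (1, seq_len, seq_len))
    out = encoder(x, relation, mask=None)
    assert out.shape == (1, seq_len, hidden_size)
    return out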

class RelationalTransformerUpdate(torch.nn.Module):
    """Relation-aware Transformer encoder that jointly updates question, column,
    and table encodings using the relation types registered below."""

    def __init__(self, num_layers, num_heads, hidden_size,
                 ff_size=None,
                 dropout=0.1,
                 merge_types=False,
                 tie_layers=False,
                 qq_max_dist=2,
                 # qc_token_match=True,
                 # qt_token_match=True,
                 # cq_token_match=True,
                 cc_foreign_key=True,
                 cc_table_match=True,
                 cc_max_dist=2,
                 ct_foreign_key=True,
                 ct_table_match=True,
                 # tq_token_match=True,
                 tc_table_match=True,
                 tc_foreign_key=True,
                 tt_max_dist=2,
                 tt_foreign_key=True,
                 sc_link=False,
                 cv_link=False,
                 ):
        super().__init__()
        self.num_heads = num_heads

        self.qq_max_dist = qq_max_dist
        # self.qc_token_match = qc_token_match
        # self.qt_token_match = qt_token_match
        # self.cq_token_match = cq_token_match
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        # self.tq_token_match = tq_token_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        self.relation_ids = {}

        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)

        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))
        add_rel_dist('qq_dist', qq_max_dist)

        add_relation('qc_default')
        # if qc_token_match:
        #     add_relation('qc_token_match')

        add_relation('qt_default')
        # if qt_token_match:
        #     add_relation('qt_token_match')

        add_relation('cq_default')
        # if cq_token_match:
        #     add_relation('cq_token_match')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')
        # if cq_token_match:
        #     add_relation('tq_token_match')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # schema linking relations
        # forward_backward
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')

        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")
        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key
            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            add_relation('xx_default')
            self.relation_ids['qc_default'] = self.relation_ids['xx_default']
            self.relation_ids['qt_default'] = self.relation_ids['xx_default']
            self.relation_ids['cq_default'] = self.relation_ids['xx_default']
            self.relation_ids['cc_default'] = self.relation_ids['xx_default']
            self.relation_ids['ct_default'] = self.relation_ids['xx_default']
            self.relation_ids['tq_default'] = self.relation_ids['xx_default']
            self.relation_ids['tc_default'] = self.relation_ids['xx_default']
            self.relation_ids['tt_default'] = self.relation_ids['xx_default']

            if sc_link:
                self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
                self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
            if cv_link:
                self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
                self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
                self.relation_ids["qcCELLMATCH"] = self.relation_ids['xx_default']
                self.relation_ids["cqCELLMATCH"] = self.relation_ids['xx_default']

            # Merge all distance relations into the shared qq_dist ids.
            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]
        if ff_size is None:
            ff_size = hidden_size * 4
        self.encoder = Encoder(
            lambda: EncoderLayer(
                hidden_size,
                MultiHeadedAttentionWithRelations(
                    num_heads,
                    hidden_size,
                    dropout),
                PositionwiseFeedForward(
                    hidden_size,
                    ff_size,
                    dropout),
                len(self.relation_ids),
                dropout),
            hidden_size,
            num_layers,
            tie_layers)

        self.align_attn = PointerWithRelations(hidden_size,
                                               len(self.relation_ids), dropout)

    def create_align_mask(self, num_head, q_length, c_length, t_length):
        # Mask of size [num_heads, all_length, all_length]: the last head may only
        # attend across the question/column boundary; all other heads are unmasked.
        all_length = q_length + c_length + t_length
        mask_1 = torch.ones(num_head - 1, all_length, all_length, device=next(self.parameters()).device)
        mask_2 = torch.zeros(1, all_length, all_length, device=next(self.parameters()).device)
        for i in range(q_length):
            for j in range(q_length, q_length + c_length):
                mask_2[0, i, j] = 1
                mask_2[0, j, i] = 1
        mask = torch.cat([mask_1, mask_2], 0)
        return mask

    def forward_unbatched(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        # enc shape: total len x batch (=1) x recurrent size
        enc = torch.cat((q_enc, c_enc, t_enc), dim=0)

        # enc shape: batch (=1) x total len x recurrent size
        enc = enc.transpose(0, 1)

        # Catalogue which things are where
        relations = self.compute_relations(
            desc,
            enc_length=enc.shape[1],
            q_enc_length=q_enc.shape[0],
            c_enc_length=c_enc.shape[0],
            c_boundaries=c_boundaries,
            t_boundaries=t_boundaries)

        relations_t = torch.LongTensor(relations).to(next(self.parameters()).device)
        enc_new = self.encoder(enc, relations_t, mask=None)

        # Split updated_enc again
        c_base = q_enc.shape[0]
        t_base = q_enc.shape[0] + c_enc.shape[0]
        q_enc_new = enc_new[:, :c_base]
        c_enc_new = enc_new[:, c_base:t_base]
        t_enc_new = enc_new[:, t_base:]

        m2c_align_mat = self.align_attn(enc_new, enc_new[:, c_base:t_base],
                                        enc_new[:, c_base:t_base], relations_t[:, c_base:t_base])
        m2t_align_mat = self.align_attn(enc_new, enc_new[:, t_base:],
                                        enc_new[:, t_base:], relations_t[:, t_base:])

        return q_enc_new, c_enc_new, t_enc_new, (m2c_align_mat, m2t_align_mat)

    def compute_relations(self, desc, enc_length, q_enc_length, c_enc_length, c_boundaries, t_boundaries):
        sc_link = desc.get('sc_link', {'q_col_match': {}, 'q_tab_match': {}})
        cv_link = desc.get('cv_link', {'num_date_match': {}, 'cell_match': {}})

        # Catalogue which things are where
        loc_types = {}
        for i in range(q_enc_length):
            loc_types[i] = ('question',)

        c_base = q_enc_length
        for c_id, (c_start, c_end) in enumerate(zip(c_boundaries, c_boundaries[1:])):
            for i in range(c_start + c_base, c_end + c_base):
                loc_types[i] = ('column', c_id)
        t_base = q_enc_length + c_enc_length
        for t_id, (t_start, t_end) in enumerate(zip(t_boundaries, t_boundaries[1:])):
            for i in range(t_start + t_base, t_end + t_base):
                loc_types[i] = ('table', t_id)

        relations = np.empty((enc_length, enc_length), dtype=np.int64)

        for i, j in itertools.product(range(enc_length), repeat=2):
            def set_relation(name):
                relations[i, j] = self.relation_ids[name]

            i_type, j_type = loc_types[i], loc_types[j]
            if i_type[0] == 'question':
                if j_type[0] == 'question':
                    set_relation(('qq_dist', clamp(j - i, self.qq_max_dist)))
                elif j_type[0] == 'column':
                    # set_relation('qc_default')
                    j_real = j - c_base
                    if f"{i},{j_real}" in sc_link["q_col_match"]:
                        set_relation("qc" + sc_link["q_col_match"][f"{i},{j_real}"])
                    elif f"{i},{j_real}" in cv_link["cell_match"]:
                        set_relation("qc" + cv_link["cell_match"][f"{i},{j_real}"])
                    elif f"{i},{j_real}" in cv_link["num_date_match"]:
                        set_relation("qc" + cv_link["num_date_match"][f"{i},{j_real}"])
                    else:
                        set_relation('qc_default')
                elif j_type[0] == 'table':
                    # set_relation('qt_default')
                    j_real = j - t_base
                    if f"{i},{j_real}" in sc_link["q_tab_match"]:
                        set_relation("qt" + sc_link["q_tab_match"][f"{i},{j_real}"])
                    else:
                        set_relation('qt_default')

            elif i_type[0] == 'column':
                if j_type[0] == 'question':
                    # set_relation('cq_default')
                    i_real = i - c_base
                    if f"{j},{i_real}" in sc_link["q_col_match"]:
                        set_relation("cq" + sc_link["q_col_match"][f"{j},{i_real}"])
                    elif f"{j},{i_real}" in cv_link["cell_match"]:
                        set_relation("cq" + cv_link["cell_match"][f"{j},{i_real}"])
                    elif f"{j},{i_real}" in cv_link["num_date_match"]:
                        set_relation("cq" + cv_link["num_date_match"][f"{j},{i_real}"])
                    else:
                        set_relation('cq_default')
                elif j_type[0] == 'column':
                    col1, col2 = i_type[1], j_type[1]
                    if col1 == col2:
                        set_relation(('cc_dist', clamp(j - i, self.cc_max_dist)))
                    else:
                        set_relation('cc_default')
                        if self.cc_foreign_key:
                            if desc['foreign_keys'].get(str(col1)) == col2:
                                set_relation('cc_foreign_key_forward')
                            if desc['foreign_keys'].get(str(col2)) == col1:
                                set_relation('cc_foreign_key_backward')
                        if (self.cc_table_match and
                                desc['column_to_table'][str(col1)] == desc['column_to_table'][str(col2)]):
                            set_relation('cc_table_match')
                elif j_type[0] == 'table':
                    col, table = i_type[1], j_type[1]
                    set_relation('ct_default')
                    if self.ct_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('ct_foreign_key')
                    if self.ct_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('ct_primary_key')
                            else:
                                set_relation('ct_table_match')
                        elif col_table is None:
                            set_relation('ct_any_table')

            elif i_type[0] == 'table':
                if j_type[0] == 'question':
                    # set_relation('tq_default')
                    i_real = i - t_base
                    if f"{j},{i_real}" in sc_link["q_tab_match"]:
                        set_relation("tq" + sc_link["q_tab_match"][f"{j},{i_real}"])
                    else:
                        set_relation('tq_default')
                elif j_type[0] == 'column':
                    table, col = i_type[1], j_type[1]
                    set_relation('tc_default')
                    if self.tc_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('tc_foreign_key')
                    if self.tc_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('tc_primary_key')
                            else:
                                set_relation('tc_table_match')
                        elif col_table is None:
                            set_relation('tc_any_table')
                elif j_type[0] == 'table':
                    table1, table2 = i_type[1], j_type[1]
                    if table1 == table2:
                        set_relation(('tt_dist', clamp(j - i, self.tt_max_dist)))
                    else:
                        set_relation('tt_default')
                        if self.tt_foreign_key:
                            forward = table2 in desc['foreign_keys_tables'].get(str(table1), ())
                            backward = table1 in desc['foreign_keys_tables'].get(str(table2), ())
                            if forward and backward:
                                set_relation('tt_foreign_key_both')
                            elif forward:
                                set_relation('tt_foreign_key_forward')
                            elif backward:
                                set_relation('tt_foreign_key_backward')
        return relations

    @classmethod
    def match_foreign_key(cls, desc, col, table):
        foreign_key_for = desc['foreign_keys'].get(str(col))
        if foreign_key_for is None:
            return False

        foreign_table = desc['column_to_table'][str(foreign_key_for)]
        return desc['column_to_table'][str(col)] == foreign_table
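
# Sketch of compute_relations on a tiny hypothetical schema (all ids and the
# `_demo_*` helper below are made up for illustration): a 3-token question, two
# single-token columns that both belong to table 0 (column 0 being its primary
# key), and one single-token table. The result is an [enc_length, enc_length]
# matrix of relation-type ids that the encoder's relation embeddings consume.
def _demo_compute_relations():
    model = RelationalTransformerUpdate(num_layers=1, num_heads=2, hidden_size=16)
    desc = {
        'foreign_keys': {},         # no column-level foreign keys
        'foreign_keys_tables': {},  # no table-level foreign keys
        'column_to_table': {'0': 0, '1': 0},
        'primary_keys': [0],
    }
    relations = model.compute_relations(
        desc,
        enc_length=6,  # 3 question tokens + 2 column tokens + 1 table token
        q_enc_length=3,
        c_enc_length=2,
        c_boundaries=[0, 1, 2],
        t_boundaries=[0, 1])
    assert relations.shape == (6, 6)
    assert relations[0, 3] == model.relation_ids['qc_default']
    return relations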