import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import entmax
import numpy as np
import itertools


def clamp(value, abs_max):
    value = max(-abs_max, value)
    value = min(abs_max, value)
    return value


# Adapted from
# https://github.com/tensorflow/tensor2tensor/blob/0b156ac533ab53f65f44966381f6e147c7371eee/tensor2tensor/layers/common_attention.py
def relative_attention_logits(query, key, relation):
    # We can't reuse the same logic as tensor2tensor because we don't share
    # relation vectors across the batch.
    # In this version, relation vectors are shared across heads.
    # query: [batch, heads, num queries, depth].
    # key: [batch, heads, num kvs, depth].
    # relation: [batch, num queries, num kvs, depth].

    # qk_matmul is [batch, heads, num queries, num kvs]
    qk_matmul = torch.matmul(query, key.transpose(-2, -1))

    # q_t is [batch, num queries, heads, depth]
    q_t = query.permute(0, 2, 1, 3)

    # r_t is [batch, num queries, depth, num kvs]
    r_t = relation.transpose(-2, -1)

    #   [batch, num queries, heads, depth]
    # * [batch, num queries, depth, num kvs]
    # = [batch, num queries, heads, num kvs]
    # For each batch and query, we have a query vector per head.
    # We take its dot product with the relation vector for each kv.
    q_tr_t_matmul = torch.matmul(q_t, r_t)

    # q_tr_tmatmul_t is [batch, heads, num queries, num kvs]
    q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3)

    # [batch, heads, num queries, num kvs]
    return (qk_matmul + q_tr_tmatmul_t) / math.sqrt(query.shape[-1])

# Sharing relation vectors across batch and heads:
# query: [batch, heads, num queries, depth].
# key: [batch, heads, num kvs, depth].
# relation: [num queries, num kvs, depth].
#
# Then take
# key reshaped
#   [num queries, batch * heads, depth]
# relation.transpose(-2, -1)
#   [num queries, depth, num kvs]
# and multiply them together.
#
# Without sharing relation vectors across heads:
# query: [batch, heads, num queries, depth].
# key: [batch, heads, num kvs, depth].
# relation: [batch, heads, num queries, num kvs, depth].
#
# Then take
# key.unsqueeze(3)
#   [batch, heads, num queries, 1, depth]
# relation.transpose(-2, -1)
#   [batch, heads, num queries, depth, num kvs]
# and multiply them together:
#   [batch, heads, num queries, 1, depth]
# * [batch, heads, num queries, depth, num kvs]
# = [batch, heads, num queries, 1, num kvs]
# and squeeze to
#   [batch, heads, num queries, num kvs]


def relative_attention_values(weight, value, relation):
    # In this version, relation vectors are shared across heads.
    # weight: [batch, heads, num queries, num kvs].
    # value: [batch, heads, num kvs, depth].
    # relation: [batch, num queries, num kvs, depth].

    # wv_matmul is [batch, heads, num queries, depth]
    wv_matmul = torch.matmul(weight, value)

    # w_t is [batch, num queries, heads, num kvs]
    w_t = weight.permute(0, 2, 1, 3)

    #   [batch, num queries, heads, num kvs]
    # * [batch, num queries, num kvs, depth]
    # = [batch, num queries, heads, depth]
    w_tr_matmul = torch.matmul(w_t, relation)

    # w_tr_matmul_t is [batch, heads, num queries, depth]
    w_tr_matmul_t = w_tr_matmul.permute(0, 2, 1, 3)

    return wv_matmul + w_tr_matmul_t
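

# Illustrative shape check (a minimal sketch, not used by the model code below):
# it exercises relative_attention_logits / relative_attention_values on random
# tensors to confirm the shape algebra documented in the comments above. All
# sizes are arbitrary example values.
def _demo_relative_attention_shapes():
    batch, heads, num_q, num_kv, depth = 2, 4, 5, 7, 8
    query = torch.randn(batch, heads, num_q, depth)
    key = torch.randn(batch, heads, num_kv, depth)
    value = torch.randn(batch, heads, num_kv, depth)
    # One relation vector per (query, kv) pair, shared across heads.
    relation = torch.randn(batch, num_q, num_kv, depth)

    logits = relative_attention_logits(query, key, relation)
    assert logits.shape == (batch, heads, num_q, num_kv)

    weights = F.softmax(logits, dim=-1)
    out = relative_attention_values(weights, value, relation)
    assert out.shape == (batch, heads, num_q, depth)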


# Adapted from The Annotated Transformer
def clones(module_fn, N):
    return nn.ModuleList([module_fn() for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # return torch.matmul(p_attn, value), scores.squeeze(1).squeeze(1)
    return torch.matmul(p_attn, value), p_attn


def sparse_attention(query, key, value, alpha, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention' with a sparse entmax mapping"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    if alpha == 2:
        p_attn = entmax.sparsemax(scores, -1)
    elif alpha == 1.5:
        p_attn = entmax.entmax15(scores, -1)
    else:
        raise NotImplementedError
    if dropout is not None:
        p_attn = dropout(p_attn)
    # return torch.matmul(p_attn, value), scores.squeeze(1).squeeze(1)
    return torch.matmul(p_attn, value), p_attn


# Adapted from The Annotated Transformer
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(lambda: nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        # Note: query was re-assigned to the 4-D projected tensor above, so this
        # branch never fires here.
        if query.dim() == 3:
            x = x.squeeze(1)
        return self.linears[-1](x)
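

# Minimal usage sketch (illustrative only, not called by the module): compares
# standard softmax attention with the entmax-based sparse_attention above. With
# alpha=2 (sparsemax) or alpha=1.5 (entmax15), many attention weights become
# exactly zero, whereas softmax weights are strictly positive. Tensor sizes are
# arbitrary example values.
def _demo_sparse_attention():
    batch, heads, num_q, num_kv, d_k = 1, 2, 3, 6, 4
    query = torch.randn(batch, heads, num_q, d_k)
    key = torch.randn(batch, heads, num_kv, d_k)
    value = torch.randn(batch, heads, num_kv, d_k)

    _, p_soft = attention(query, key, value)
    _, p_sparse = sparse_attention(query, key, value, alpha=2)

    # softmax never produces exact zeros; sparsemax typically does.
    assert (p_soft > 0).all()
    print("zero weights under sparsemax:", int((p_sparse == 0).sum()))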


# Adapted from The Annotated Transformer
def attention_with_relations(query, key, value, relation_k, relation_v, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = relative_attention_logits(query, key, relation_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn_orig = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn_orig)
    else:
        p_attn = p_attn_orig
    return relative_attention_values(p_attn, value, relation_v), p_attn_orig


class PointerWithRelations(nn.Module):
    def __init__(self, hidden_size, num_relation_kinds, dropout=0.2):
        super(PointerWithRelations, self).__init__()
        self.hidden_size = hidden_size
        self.linears = clones(lambda: nn.Linear(hidden_size, hidden_size), 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

        self.relation_k_emb = nn.Embedding(num_relation_kinds, self.hidden_size)
        self.relation_v_emb = nn.Embedding(num_relation_kinds, self.hidden_size)

    def forward(self, query, key, value, relation, mask=None):
        relation_k = self.relation_k_emb(relation)
        relation_v = self.relation_v_emb(relation)

        if mask is not None:
            mask = mask.unsqueeze(0)
        nbatches = query.size(0)

        query, key, value = \
            [l(x).view(nbatches, -1, 1, self.hidden_size).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        _, self.attn = attention_with_relations(
            query, key, value,
            relation_k, relation_v,
            mask=mask, dropout=self.dropout)

        return self.attn[0, 0]


# Adapted from The Annotated Transformer
class MultiHeadedAttentionWithRelations(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttentionWithRelations, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(lambda: nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, relation_k, relation_v, mask=None):
        # query shape: [batch, num queries, d_model]
        # key shape: [batch, num kv, d_model]
        # value shape: [batch, num kv, d_model]
        # relation_k shape: [batch, num queries, num kv, (d_model // h)]
        # relation_v shape: [batch, num queries, num kv, (d_model // h)]
        # mask shape: [batch, num queries, num kv]
        if mask is not None:
            # Same mask applied to all h heads.
            # mask shape: [batch, 1, num queries, num kv]
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        # x shape: [batch, heads, num queries, depth]
        x, self.attn = attention_with_relations(
            query, key, value,
            relation_k, relation_v,
            mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
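

# Minimal usage sketch (illustrative; real relation ids come from
# RelationalTransformerUpdate.compute_relations below). Relation embeddings are
# shared across heads, so they have size d_model // h, matching the per-head
# depth expected by relative_attention_logits / relative_attention_values.
# All dimensions below are arbitrary example values.
def _demo_multi_headed_attention_with_relations():
    h, d_model, num_relation_kinds = 4, 32, 10
    batch, seq_len = 2, 6

    attn_layer = MultiHeadedAttentionWithRelations(h, d_model, dropout=0.0)
    relation_k_emb = nn.Embedding(num_relation_kinds, d_model // h)
    relation_v_emb = nn.Embedding(num_relation_kinds, d_model // h)

    x = torch.randn(batch, seq_len, d_model)
    # One relation id per (query, key) token pair.
    relation_ids_t = torch.randint(num_relation_kinds, (batch, seq_len, seq_len))

    out = attn_layer(x, x, x,
                     relation_k_emb(relation_ids_t),
                     relation_v_emb(relation_ids_t))
    assert out.shape == (batch, seq_len, d_model)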


# Adapted from The Annotated Transformer
class Encoder(nn.Module):
    "Core encoder is a stack of N layers"

    def __init__(self, layer, layer_size, N, tie_layers=False):
        super(Encoder, self).__init__()
        if tie_layers:
            self.layer = layer()
            self.layers = [self.layer for _ in range(N)]
        else:
            self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer_size)

        # TODO initialize using xavier

    def forward(self, x, relation, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, relation, mask)
        return self.norm(x)


# Adapted from The Annotated Transformer
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))


# Adapted from The Annotated Transformer
class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"

    def __init__(self, size, self_attn, feed_forward, num_relation_kinds, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(lambda: SublayerConnection(size, dropout), 2)
        self.size = size

        self.relation_k_emb = nn.Embedding(num_relation_kinds, self.self_attn.d_k)
        self.relation_v_emb = nn.Embedding(num_relation_kinds, self.self_attn.d_k)

    def forward(self, x, relation, mask):
        "Follow Figure 1 (left) for connections."
        relation_k = self.relation_k_emb(relation)
        relation_v = self.relation_v_emb(relation)

        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, relation_k, relation_v, mask))
        return self.sublayer[1](x, self.feed_forward)


# Adapted from The Annotated Transformer
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
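

# Illustrative sketch of the tie_layers option (not used elsewhere in this
# module): with tie_layers=True the Encoder reuses a single EncoderLayer
# instance at every depth, so all "layers" share parameters. Dimensions are
# small arbitrary example values.
def _demo_tied_encoder_layers():
    d_model, n_heads, n_kinds = 16, 2, 5

    def make_layer():
        return EncoderLayer(
            d_model,
            MultiHeadedAttentionWithRelations(n_heads, d_model, dropout=0.0),
            PositionwiseFeedForward(d_model, 4 * d_model, dropout=0.0),
            n_kinds,
            dropout=0.0)

    tied = Encoder(make_layer, d_model, N=3, tie_layers=True)
    untied = Encoder(make_layer, d_model, N=3, tie_layers=False)

    # With tie_layers=True every entry in .layers is the same module object.
    assert tied.layers[0] is tied.layers[2]
    assert untied.layers[0] is not untied.layers[2]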


relation_ids = {
    ('qq_dist', -2): 0,
    ('qq_dist', -1): 1,
    ('qq_dist', 0): 2,
    ('qq_dist', 1): 3,
    ('qq_dist', 2): 4,
    'qc_default': 5,
    'qt_default': 6,
    'cq_default': 7,
    'cc_default': 8,
    'cc_foreign_key_forward': 9,
    'cc_foreign_key_backward': 10,
    'cc_table_match': 11,
    ('cc_dist', -2): 12,
    ('cc_dist', -1): 13,
    ('cc_dist', 0): 14,
    ('cc_dist', 1): 15,
    ('cc_dist', 2): 16,
    'ct_default': 17,
    'ct_foreign_key': 18,
    'ct_primary_key': 19,
    'ct_table_match': 20,
    'ct_any_table': 21,
    'tq_default': 22,
    'tc_default': 23,
    'tc_primary_key': 24,
    'tc_table_match': 25,
    'tc_any_table': 26,
    'tc_foreign_key': 27,
    'tt_default': 28,
    'tt_foreign_key_forward': 29,
    'tt_foreign_key_backward': 30,
    'tt_foreign_key_both': 31,
    ('tt_dist', -2): 32,
    ('tt_dist', -1): 33,
    ('tt_dist', 0): 34,
    ('tt_dist', 1): 35,
    ('tt_dist', 2): 36,
    'qcCEM': 37,
    'cqCEM': 38,
    'qtTEM': 39,
    'tqTEM': 40,
    'qcCPM': 41,
    'cqCPM': 42,
    'qtTPM': 43,
    'tqTPM': 44,
    'qcNUMBER': 45,
    'cqNUMBER': 46,
    'qcTIME': 47,
    'cqTIME': 48,
    'qcCELLMATCH': 49,
    'cqCELLMATCH': 50,
}

num_heads = 8
hidden_size = 1024
ff_size = 4096
dropout = 0.1
num_layers = 8
tie_layers = False

encoder = Encoder(
    lambda: EncoderLayer(
        hidden_size,
        MultiHeadedAttentionWithRelations(
            num_heads,
            hidden_size,
            dropout),
        PositionwiseFeedForward(
            hidden_size,
            ff_size,
            dropout),
        len(relation_ids),
        dropout),
    hidden_size,
    num_layers,
    tie_layers)
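

# Illustrative sketch of how the module-level `encoder` above is driven: a
# relation-id matrix (one id per token pair, drawn from `relation_ids`) is
# built and passed alongside the token encodings. Here a toy layout of three
# question tokens followed by one column token is assumed; real matrices are
# produced by RelationalTransformerUpdate.compute_relations below.
def _demo_module_level_encoder():
    q_len, c_len = 3, 1
    total = q_len + c_len
    rel = torch.zeros(1, total, total, dtype=torch.long)
    for i in range(total):
        for j in range(total):
            if i < q_len and j < q_len:
                # question-question: clamped distance bucket
                rel[0, i, j] = relation_ids[('qq_dist', clamp(j - i, 2))]
            elif i < q_len:
                # question-column
                rel[0, i, j] = relation_ids['qc_default']
            elif j < q_len:
                # column-question
                rel[0, i, j] = relation_ids['cq_default']
            else:
                # column-column: clamped distance bucket
                rel[0, i, j] = relation_ids[('cc_dist', clamp(j - i, 2))]

    x = torch.randn(1, total, hidden_size)
    out = encoder(x, rel, mask=None)
    assert out.shape == (1, total, hidden_size)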


class RelationalTransformerUpdate(torch.nn.Module):

    def __init__(self, num_layers, num_heads, hidden_size,
                 ff_size=None,
                 dropout=0.1,
                 merge_types=False,
                 tie_layers=False,
                 qq_max_dist=2,
                 # qc_token_match=True,
                 # qt_token_match=True,
                 # cq_token_match=True,
                 cc_foreign_key=True,
                 cc_table_match=True,
                 cc_max_dist=2,
                 ct_foreign_key=True,
                 ct_table_match=True,
                 # tq_token_match=True,
                 tc_table_match=True,
                 tc_foreign_key=True,
                 tt_max_dist=2,
                 tt_foreign_key=True,
                 sc_link=False,
                 cv_link=False,
                 ):
        super().__init__()
        self.num_heads = num_heads

        self.qq_max_dist = qq_max_dist
        # self.qc_token_match = qc_token_match
        # self.qt_token_match = qt_token_match
        # self.cq_token_match = cq_token_match
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        # self.tq_token_match = tq_token_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        self.relation_ids = {}

        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)

        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))

        add_rel_dist('qq_dist', qq_max_dist)

        add_relation('qc_default')
        # if qc_token_match:
        #     add_relation('qc_token_match')

        add_relation('qt_default')
        # if qt_token_match:
        #     add_relation('qt_token_match')

        add_relation('cq_default')
        # if cq_token_match:
        #     add_relation('cq_token_match')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')
        # if cq_token_match:
        #     add_relation('tq_token_match')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # schema linking relations
        # forward_backward
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')
        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")

        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key

            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            add_relation('xx_default')
            self.relation_ids['qc_default'] = self.relation_ids['xx_default']
            self.relation_ids['qt_default'] = self.relation_ids['xx_default']
            self.relation_ids['cq_default'] = self.relation_ids['xx_default']
            self.relation_ids['cc_default'] = self.relation_ids['xx_default']
            self.relation_ids['ct_default'] = self.relation_ids['xx_default']
            self.relation_ids['tq_default'] = self.relation_ids['xx_default']
            self.relation_ids['tc_default'] = self.relation_ids['xx_default']
            self.relation_ids['tt_default'] = self.relation_ids['xx_default']

            if sc_link:
                self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
                self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
            if cv_link:
                self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
                self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
                self.relation_ids["qcCELLMATCH"] = self.relation_ids['xx_default']
                self.relation_ids["cqCELLMATCH"] = self.relation_ids['xx_default']

            # Share the distance buckets across question-question, column-column
            # and table-table pairs.
            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

        if ff_size is None:
            ff_size = hidden_size * 4
        self.encoder = Encoder(
            lambda: EncoderLayer(
                hidden_size,
                MultiHeadedAttentionWithRelations(
                    num_heads,
                    hidden_size,
                    dropout),
                PositionwiseFeedForward(
                    hidden_size,
                    ff_size,
                    dropout),
                len(self.relation_ids),
                dropout),
            hidden_size,
            num_layers,
            tie_layers)

        self.align_attn = PointerWithRelations(hidden_size,
                                               len(self.relation_ids), dropout)

    def create_align_mask(self, num_head, q_length, c_length, t_length):
        # mask with size num_heads * all_len * all_len
        all_length = q_length + c_length + t_length
        mask_1 = torch.ones(num_head - 1, all_length, all_length,
                            device=next(self.parameters()).device)
        mask_2 = torch.zeros(1, all_length, all_length,
                             device=next(self.parameters()).device)
        for i in range(q_length):
            for j in range(q_length, q_length + c_length):
                mask_2[0, i, j] = 1
                mask_2[0, j, i] = 1
        mask = torch.cat([mask_1, mask_2], 0)
        return mask

    def forward_unbatched(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        # enc shape: total len x batch (=1) x recurrent size
        enc = torch.cat((q_enc, c_enc, t_enc), dim=0)

        # enc shape: batch (=1) x total len x recurrent size
        enc = enc.transpose(0, 1)

        # Catalogue which things are where
        relations = self.compute_relations(
            desc,
            enc_length=enc.shape[1],
            q_enc_length=q_enc.shape[0],
            c_enc_length=c_enc.shape[0],
            c_boundaries=c_boundaries,
            t_boundaries=t_boundaries)

        relations_t = torch.LongTensor(relations).to(next(self.parameters()).device)
        enc_new = self.encoder(enc, relations_t, mask=None)

        # Split updated_enc again
        c_base = q_enc.shape[0]
        t_base = q_enc.shape[0] + c_enc.shape[0]
        q_enc_new = enc_new[:, :c_base]
        c_enc_new = enc_new[:, c_base:t_base]
        t_enc_new = enc_new[:, t_base:]

        m2c_align_mat = self.align_attn(enc_new, enc_new[:, c_base:t_base],
                                        enc_new[:, c_base:t_base], relations_t[:, c_base:t_base])
        m2t_align_mat = self.align_attn(enc_new, enc_new[:, t_base:],
                                        enc_new[:, t_base:], relations_t[:, t_base:])

        return q_enc_new, c_enc_new, t_enc_new, (m2c_align_mat, m2t_align_mat)

    def compute_relations(self, desc, enc_length, q_enc_length, c_enc_length, c_boundaries, t_boundaries):
        sc_link = desc.get('sc_link', {'q_col_match': {}, 'q_tab_match': {}})
        cv_link = desc.get('cv_link', {'num_date_match': {}, 'cell_match': {}})

        # Catalogue which things are where
        loc_types = {}
        for i in range(q_enc_length):
            loc_types[i] = ('question',)

        c_base = q_enc_length
        for c_id, (c_start, c_end) in enumerate(zip(c_boundaries, c_boundaries[1:])):
            for i in range(c_start + c_base, c_end + c_base):
                loc_types[i] = ('column', c_id)
        t_base = q_enc_length + c_enc_length
        for t_id, (t_start, t_end) in enumerate(zip(t_boundaries, t_boundaries[1:])):
            for i in range(t_start + t_base, t_end + t_base):
                loc_types[i] = ('table', t_id)

        relations = np.empty((enc_length, enc_length), dtype=np.int64)

        for i, j in itertools.product(range(enc_length), repeat=2):
            def set_relation(name):
                relations[i, j] = self.relation_ids[name]

            i_type, j_type = loc_types[i], loc_types[j]
            if i_type[0] == 'question':
                if j_type[0] == 'question':
                    set_relation(('qq_dist', clamp(j - i, self.qq_max_dist)))
                elif j_type[0] == 'column':
                    # set_relation('qc_default')
                    j_real = j - c_base
                    if f"{i},{j_real}" in sc_link["q_col_match"]:
                        set_relation("qc" + sc_link["q_col_match"][f"{i},{j_real}"])
                    elif f"{i},{j_real}" in cv_link["cell_match"]:
                        set_relation("qc" + cv_link["cell_match"][f"{i},{j_real}"])
                    elif f"{i},{j_real}" in cv_link["num_date_match"]:
                        set_relation("qc" + cv_link["num_date_match"][f"{i},{j_real}"])
                    else:
                        set_relation('qc_default')
                elif j_type[0] == 'table':
                    # set_relation('qt_default')
                    j_real = j - t_base
                    if f"{i},{j_real}" in sc_link["q_tab_match"]:
                        set_relation("qt" + sc_link["q_tab_match"][f"{i},{j_real}"])
                    else:
                        set_relation('qt_default')

            elif i_type[0] == 'column':
                if j_type[0] == 'question':
                    # set_relation('cq_default')
                    i_real = i - c_base
                    if f"{j},{i_real}" in sc_link["q_col_match"]:
                        set_relation("cq" + sc_link["q_col_match"][f"{j},{i_real}"])
                    elif f"{j},{i_real}" in cv_link["cell_match"]:
                        set_relation("cq" + cv_link["cell_match"][f"{j},{i_real}"])
                    elif f"{j},{i_real}" in cv_link["num_date_match"]:
                        set_relation("cq" + cv_link["num_date_match"][f"{j},{i_real}"])
                    else:
                        set_relation('cq_default')
                elif j_type[0] == 'column':
                    col1, col2 = i_type[1], j_type[1]
                    if col1 == col2:
                        set_relation(('cc_dist', clamp(j - i, self.cc_max_dist)))
                    else:
                        set_relation('cc_default')
                        if self.cc_foreign_key:
                            if desc['foreign_keys'].get(str(col1)) == col2:
                                set_relation('cc_foreign_key_forward')
                            if desc['foreign_keys'].get(str(col2)) == col1:
                                set_relation('cc_foreign_key_backward')
                        if (self.cc_table_match and
                                desc['column_to_table'][str(col1)] == desc['column_to_table'][str(col2)]):
                            set_relation('cc_table_match')

                elif j_type[0] == 'table':
                    col, table = i_type[1], j_type[1]
                    set_relation('ct_default')
                    if self.ct_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('ct_foreign_key')
                    if self.ct_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('ct_primary_key')
                            else:
                                set_relation('ct_table_match')
                        elif col_table is None:
                            set_relation('ct_any_table')

            elif i_type[0] == 'table':
                if j_type[0] == 'question':
                    # set_relation('tq_default')
                    i_real = i - t_base
                    if f"{j},{i_real}" in sc_link["q_tab_match"]:
                        set_relation("tq" + sc_link["q_tab_match"][f"{j},{i_real}"])
                    else:
                        set_relation('tq_default')
                elif j_type[0] == 'column':
                    table, col = i_type[1], j_type[1]
                    set_relation('tc_default')

                    if self.tc_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('tc_foreign_key')

                    if self.tc_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('tc_primary_key')
                            else:
                                set_relation('tc_table_match')
                        elif col_table is None:
                            set_relation('tc_any_table')
                elif j_type[0] == 'table':
                    table1, table2 = i_type[1], j_type[1]
                    if table1 == table2:
                        set_relation(('tt_dist', clamp(j - i, self.tt_max_dist)))
                    else:
                        set_relation('tt_default')
                        if self.tt_foreign_key:
                            forward = table2 in desc['foreign_keys_tables'].get(str(table1), ())
                            backward = table1 in desc['foreign_keys_tables'].get(str(table2), ())
                            if forward and backward:
                                set_relation('tt_foreign_key_both')
                            elif forward:
                                set_relation('tt_foreign_key_forward')
                            elif backward:
                                set_relation('tt_foreign_key_backward')
        return relations

    @classmethod
    def match_foreign_key(cls, desc, col, table):
        foreign_key_for = desc['foreign_keys'].get(str(col))
        if foreign_key_for is None:
            return False

        foreign_table = desc['column_to_table'][str(foreign_key_for)]
        return desc['column_to_table'][str(col)] == foreign_table
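

# Illustrative sketch of compute_relations on a hand-built `desc` (the field
# names follow the ones this class reads: foreign_keys, foreign_keys_tables,
# column_to_table, primary_keys, sc_link, cv_link; the toy values are invented
# for the example). The layout is two question tokens, two single-token columns
# of one table, and one single-token table: [q0, q1, col0, col1, tab0].
def _demo_compute_relations():
    updater = RelationalTransformerUpdate(
        num_layers=1, num_heads=2, hidden_size=16, sc_link=True)

    desc = {
        'sc_link': {'q_col_match': {'1,0': 'CEM'}, 'q_tab_match': {}},
        'cv_link': {'num_date_match': {}, 'cell_match': {}},
        'foreign_keys': {},
        'foreign_keys_tables': {},
        'column_to_table': {'0': 0, '1': 0},
        'primary_keys': [0],
    }

    relations = updater.compute_relations(
        desc,
        enc_length=5, q_enc_length=2, c_enc_length=2,
        c_boundaries=[0, 1, 2], t_boundaries=[0, 1])

    assert relations.shape == (5, 5)
    # Question token 1 exactly matches column 0, so cell (1, 2) carries the
    # schema-linking relation qcCEM rather than qc_default.
    assert relations[1, 2] == updater.relation_ids['qcCEM']
    assert relations[0, 2] == updater.relation_ids['qc_default']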