# antonlabate
# ver 1.3
# d758c99
import os
import string
import sqlite3
import nltk.corpus
import re
#from seq2struct.resources import corenlp
# English stopwords and punctuation marks are excluded from partial matching.
STOPWORDS = set(nltk.corpus.stopwords.words('english'))
PUNKS = set(string.punctuation)
# schema linking, similar to IRNet
def compute_schema_linking(question, column, table):
    """Tag question tokens that (partially) match schema items, IRNet-style.

    Args:
        question: list of question tokens.
        column: list of column-name token lists; index 0 is the ``*`` column
            and is never matched.
        table: list of table-name token lists.

    Returns:
        dict with:
            "q_col_match": {"<q_id>,<col_id>": "CEM" | "CPM"} (exact/partial),
            "q_tab_match": {"<q_id>,<tab_id>": "TEM" | "TPM"}.
    """

    def _partial(x_str, y_str):
        # Stopwords and bare punctuation never count as a partial match.
        if x_str in STOPWORDS or x_str in PUNKS:
            return False
        # NOTE(review): re.match anchors at the start of y_str, so this is a
        # word-bounded *prefix* test, matching the upstream implementation.
        matched = re.match(rf"\b{re.escape(x_str)}\b", y_str) is not None
        if matched:
            assert x_str in y_str
        return matched

    # Columns keyed by id, skipping the special "*" column at index 0.
    columns_by_id = {idx: toks for idx, toks in enumerate(column) if idx != 0}
    tables_by_id = dict(enumerate(table))

    q_col_match = {}
    q_tab_match = {}

    # Slide n-grams from longest (5) to shortest (1) over the question.
    for n in range(5, 0, -1):
        for start in range(len(question) - n + 1):
            gram = " ".join(question[start:start + n])
            if not gram.strip():
                continue
            span = range(start, start + n)

            # Exact matches always (re)tag the covered question tokens.
            for cid, toks in columns_by_id.items():
                if gram == " ".join(toks):
                    for q_id in span:
                        q_col_match[f"{q_id},{cid}"] = "CEM"
            for tid, toks in tables_by_id.items():
                if gram == " ".join(toks):
                    for q_id in span:
                        q_tab_match[f"{q_id},{tid}"] = "TEM"

            # Partial matches never overwrite a tag set earlier.
            for cid, toks in columns_by_id.items():
                if _partial(gram, " ".join(toks)):
                    for q_id in span:
                        q_col_match.setdefault(f"{q_id},{cid}", "CPM")
            for tid, toks in tables_by_id.items():
                if _partial(gram, " ".join(toks)):
                    for q_id in span:
                        q_tab_match.setdefault(f"{q_id},{tid}", "TPM")

    return {"q_col_match": q_col_match, "q_tab_match": q_tab_match}
def compute_cell_value_linking(tokens, schema, db_dir):
    """Link question tokens to database cell values and numeric columns.

    Args:
        tokens: normalized question tokens.
        schema: schema object exposing ``db_id`` and ``columns``; each column
            has ``orig_name``, ``type`` and ``table.orig_name``, and column 0
            is the special ``*`` column.
        db_dir: directory containing ``<db_id>/<db_id>.sqlite`` databases.

    Returns:
        dict with:
            "num_date_match": {"<q_id>,<col_id>": "NUMBER" | "TIME"},
            "cell_match": {"<q_id>,<col_id>": "CELLMATCH"},
            "normalized_token": the input ``tokens`` (passed through).
    """

    def isnumber(word):
        # float() on a str raises ValueError for non-numeric text; do not
        # swallow unrelated exceptions with a bare except.
        try:
            float(word)
            return True
        except ValueError:
            return False

    def db_word_match(word, column, table, db_path):
        """Return matching rows if `word` occurs as a (word-bounded) cell value,
        else False. Any sqlite failure (missing table, bad column, bad
        encoding) is treated as "no match"."""
        # Quote identifiers and bind the untrusted question token as a
        # parameter: the original f-string query was open to SQL injection
        # and broke on tokens containing quotes.
        col_q = '"%s"' % column.replace('"', '""')
        tab_q = '"%s"' % table.replace('"', '""')
        query = (
            f"select {col_q} from {tab_q} "
            f"where {col_q} like ? or {col_q} like ? "
            f"or {col_q} like ? or {col_q} like ?"
        )
        patterns = (f"{word} %", f"% {word}", f"% {word} %", word)
        conn = sqlite3.connect(db_path)
        try:
            rows = conn.execute(query, patterns).fetchall()
        except sqlite3.Error:
            return False
        finally:
            # The original leaked one connection per call.
            conn.close()
        return rows if rows else False

    db_name = schema.db_id
    db_path = os.path.join(db_dir, db_name, db_name + '.sqlite')
    num_date_match = {}
    cell_match = {}
    for q_id, word in enumerate(tokens):
        # Skip blank tokens, stopwords and punctuation.
        if len(word.strip()) == 0:
            continue
        if word in STOPWORDS or word in PUNKS:
            continue
        num_flag = isnumber(word)
        for col_id, column in enumerate(schema.columns):
            if col_id == 0:
                assert column.orig_name == "*"
                continue
            if num_flag:
                # Numeric tokens link to number/time columns by type only.
                if column.type in ["number", "time"]:  # TODO fine-grained date
                    num_date_match[f"{q_id},{col_id}"] = column.type.upper()
            else:
                # Text tokens are looked up against actual cell values.
                ret = db_word_match(word, column.orig_name, column.table.orig_name, db_path)
                if ret:
                    cell_match[f"{q_id},{col_id}"] = "CELLMATCH"
    cv_link = {"num_date_match": num_date_match, "cell_match": cell_match, "normalized_token": tokens}
    return cv_link