Spaces:
Runtime error
Runtime error
File size: 5,253 Bytes
f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import re
import json
import records
from typing import List, Dict
from sqlalchemy.exc import SQLAlchemyError
from utils.sql.all_keywords import ALL_KEY_WORDS
class WTQDBEngine:
    """Thin wrapper around a ``records`` connection to a SQLite table file."""

    def __init__(self, fdb):
        """Open the SQLite database stored at path ``fdb`` and keep a connection to it."""
        self.db = records.Database('sqlite:///{}'.format(fdb))
        self.conn = self.db.get_connection()

    def execute_wtq_query(self, sql_query: str):
        """Run ``sql_query`` and return every value of every result row in one flat list.

        Rows are visited in the order returned by the connection; the values of
        each row are appended in the order given by the row's ``values()``.
        """
        merged_results = []
        for record in self.conn.query(sql_query).all():
            merged_results.extend(record.values())
        return merged_results

    def delete_rows(self, row_indices: List[int]):
        """Issue one ``delete from w`` statement per entry of ``row_indices``.

        NOTE(review): each index is interpolated directly into the SQL text, so
        this assumes callers only pass trusted integers -- confirm at call sites.
        """
        for row in row_indices:
            self.conn.query("delete from w where id == {}".format(row))
def process_table_structure(_wtq_table_content: Dict, _add_all_column: bool = False):
    """Convert a raw WTQ table record into a header / rows / types / alias dict.

    The first two entries of ``headers``, ``types`` and ``contents`` (the id and
    agg columns) are dropped. Header names and cell values have newlines
    replaced by spaces and are lower-cased.

    :param _wtq_table_content: dict with the keys ``headers``, ``types``,
        ``contents`` and ``is_list``. Each entry of ``contents`` is a list of
        column variants, each a dict with ``col``, ``type`` and ``data``.
    :param _add_all_column: when False, only the first variant of each column
        is used and the header/types come straight from the input. When True,
        every variant whose alias does not contain ``_number`` becomes its own
        column, named from the base header plus the alias suffix, and typed
        ``text`` if the variant type is ``TEXT`` and ``number`` otherwise.
    :return: dict with ``header``, ``rows`` (row-major lists), ``types`` and
        ``alias`` (the keys of ``is_list``).
    """

    def _normalize_cells(cells):
        # Cells may be non-strings (numbers, None), hence the str() call.
        return [str(cell).replace("\n", " ").lower() for cell in cells]

    # remove id and agg column
    headers = [name.replace("\n", " ").lower() for name in _wtq_table_content["headers"][2:]]
    header_types = _wtq_table_content["types"][2:]
    # Column aliases are 1-based: c1 is the first remaining header.
    header_map = {"c{}".format(pos): name for pos, name in enumerate(headers, start=1)}

    all_headers = []
    all_header_types = []
    vertical_content = []
    for column_content in _wtq_table_content["contents"][2:]:
        if not _add_all_column:
            # only take the first one
            vertical_content.append(_normalize_cells(column_content[0]["data"]))
            continue
        for variant in column_content:
            column_alias = variant["col"]
            # do not add the numbered column
            if "_number" in column_alias:
                continue
            vertical_content.append(_normalize_cells(variant["data"]))
            # Split "c3_first_part" into base alias "c3" and suffix "first_part".
            base_alias, separator, suffix = column_alias.partition("_")
            if separator:
                column_name = header_map[base_alias] + " " + suffix.replace("_", " ")
            else:
                column_name = header_map[column_alias]
            all_headers.append(column_name)
            all_header_types.append("text" if variant["type"] == "TEXT" else "number")

    # Transpose the column-major data into row-major lists.
    row_content = [list(row) for row in zip(*vertical_content)]

    return {
        "header": all_headers if _add_all_column else headers,
        "rows": row_content,
        "types": all_header_types if _add_all_column else header_types,
        "alias": list(_wtq_table_content["is_list"]),
    }
def retrieve_wtq_query_answer(_engine, _table_content, _sql_struct: List):
    """Flatten a structured SQL query, execute it, and return its answers.

    Each item of ``_sql_struct`` is indexed at position 1 to obtain its token,
    e.g. ``["Keyword", "select", []]`` or ``["Column", "c4", []]``. Two
    space-joined strings are built from the tokens:

    * the encoded SQL, where column tokens (``cN`` or ``cN_suffix``) are
      replaced by the backtick-quoted entry ``N - 1`` of
      ``_table_content["header"]``;
    * the executable SQL, where tokens containing ``_address`` or ``_list``
      are reduced to their first ``cN`` match.

    Tokens found in ``ALL_KEY_WORDS`` are upper-cased in both strings.

    :return: tuple ``(encoded_sql, answers, executable_sql)``. ``answers`` is
        the list of non-None results as strings with newlines replaced by
        spaces; it is empty if execution raised ``SQLAlchemyError`` or if any
        normalized answer equals ``"none"``.
    """
    # do not append id / agg
    column_names = _table_content["header"]

    readable_tokens = []
    runnable_tokens = []
    for sql_item in _sql_struct:
        token = str(sql_item[1])
        # upper the keywords.
        if token in ALL_KEY_WORDS:
            token = str(token).upper()

        if re.fullmatch(r"c\d+(_.+)?", token):
            # only the leading "cN" part selects the header (1-based index)
            header_pos = int(token.partition("_")[0][1:]) - 1
            # wrap it with `` to make it executable
            readable_tokens.append("`{}`".format(column_names[header_pos]))
        else:
            # everything else (including 'from' / 'w') is kept verbatim
            readable_tokens.append(token)

        # c4_list, replace it with the original one
        if "_address" in token or "_list" in token:
            token = re.findall(r"c\d+", token)[0]
        runnable_tokens.append(token)

    _exec_sql_str = " ".join(runnable_tokens)
    _encode_sql_str = " ".join(readable_tokens)

    try:
        raw_answers = _engine.execute_wtq_query(_exec_sql_str)
    except SQLAlchemyError:
        # a query that fails to execute simply yields no answers
        raw_answers = []

    normalized_answers = []
    for answer in raw_answers:
        if answer is not None:
            normalized_answers.append(str(answer).replace("\n", " "))
    if "none" in normalized_answers:
        normalized_answers = []
    return _encode_sql_str, normalized_answers, _exec_sql_str
def _load_table_w_page(table_path, page_title_path=None) -> dict:
    """
    Load a WikiTableQuestions table and attach its page title.

    attention: the table_path must be the .tsv path.

    The table itself is loaded by ``utils.utils._load_table``; the result is
    expected to be a dict format like:
    {"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}
    The ``title`` field of the page JSON file is stored under ``page_title``.

    :param table_path: path of the table file passed to ``_load_table``.
    :param page_title_path: path of the page JSON file. When falsy it is
        derived from ``table_path`` by replacing every ``csv`` with ``page``
        and ``.tsv`` with ``.json``.
    :return: the loaded table dict with the extra ``page_title`` key.
    """
    from utils.utils import _load_table

    table_item = _load_table(table_path)

    # Load page title
    if not page_title_path:
        page_title_path = table_path.replace("csv", "page").replace(".tsv", ".json")
    # Read as UTF-8 explicitly so non-ASCII titles do not depend on the
    # platform's default locale encoding.
    with open(page_title_path, "r", encoding="utf-8") as f:
        table_item['page_title'] = json.load(f)['title']
    return table_item
|