File size: 5,253 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import re
import json
import records
from typing import List, Dict
from sqlalchemy.exc import SQLAlchemyError
from utils.sql.all_keywords import ALL_KEY_WORDS


class WTQDBEngine:
    """Small wrapper around a ``records`` connection to a SQLite database file.

    The delete helper operates on a table named ``w`` keyed by an ``id`` column.
    """

    def __init__(self, fdb):
        """Open the SQLite database located at ``fdb`` and keep one connection to it."""
        sqlite_url = 'sqlite:///{}'.format(fdb)
        self.db = records.Database(sqlite_url)
        self.conn = self.db.get_connection()

    def execute_wtq_query(self, sql_query: str):
        """Run ``sql_query`` and return the values of all result rows as one flat list.

        Values are concatenated row by row, in the order the rows are returned.
        """
        flat_values = []
        for record in self.conn.query(sql_query).all():
            flat_values.extend(record.values())
        return flat_values

    def delete_rows(self, row_indices: List[int]):
        """Issue one ``delete`` statement against table ``w`` per id in ``row_indices``."""
        # Render every statement first, then execute them in order.
        statements = list(map("delete from w where id == {}".format, row_indices))
        for statement in statements:
            self.conn.query(statement)


def process_table_structure(_wtq_table_content: Dict, _add_all_column: bool = False):
    """
    Convert a column-oriented WTQ table dict into a row-oriented table dict.

    The first two entries of "headers", "types" and "contents" (the id and agg
    columns) are dropped. Header names and cell values are lower-cased and have
    newlines replaced by spaces.

    :param _wtq_table_content: dict with the keys "headers", "types", "contents"
        and "is_list". Each entry of "contents" is a list of column variants, and
        each variant is a dict with the keys "col", "type" and "data".
    :param _add_all_column: when False only the first variant of each column is
        used; when True every variant whose alias does not contain "_number"
        becomes a column of its own.
    :return: dict with the keys "header", "rows", "types" and "alias".
    """

    def _normalize(cell):
        # Cells may be non-strings, so stringify before cleaning.
        return str(cell).replace("\n", " ").lower()

    # remove id and agg column
    headers = [name.replace("\n", " ").lower() for name in _wtq_table_content["headers"][2:]]
    # alias "c1" -> first real header, "c2" -> second, ...
    header_map = {"c" + str(pos): name for pos, name in enumerate(headers, start=1)}
    header_types = _wtq_table_content["types"][2:]

    expanded_headers = []
    expanded_types = []
    columns = []
    for variants in _wtq_table_content["contents"][2:]:
        if not _add_all_column:
            # only take the first one
            columns.append([_normalize(cell) for cell in variants[0]["data"]])
            continue

        for variant in variants:
            alias = variant["col"]
            # do not add the numbered column
            if "_number" in alias:
                continue
            columns.append([_normalize(cell) for cell in variant["data"]])

            # "c2_first_list" -> "<header of c2> first list"
            base_alias, separator, suffix = alias.partition("_")
            if separator:
                display_name = header_map[base_alias] + " " + suffix.replace("_", " ")
            else:
                display_name = header_map[alias]
            expanded_headers.append(display_name)
            expanded_types.append("text" if variant["type"] == "TEXT" else "number")

    # transpose column-major data into rows
    rows = [list(row) for row in zip(*columns)]

    return {
        "header": expanded_headers if _add_all_column else headers,
        "rows": rows,
        "types": expanded_types if _add_all_column else header_types,
        "alias": list(_wtq_table_content["is_list"].keys())
    }


def retrieve_wtq_query_answer(_engine, _table_content, _sql_struct: List):
    """
    Flatten a tokenized SQL structure, execute it and return the normalized answers.

    :param _engine: object exposing ``execute_wtq_query(sql)`` which returns an
        iterable of answer values.
    :param _table_content: dict whose "header" entry lists the real column names
        (without the id / agg columns).
    :param _sql_struct: list of token entries such as ["Keyword", "select", []]
        or ["Column", "c4", []]; only the element at index 1 is read.
    :return: tuple ``(encoded_sql, answers, executable_sql)``. ``encoded_sql``
        has column aliases replaced by back-quoted header names, ``answers`` is
        the list of stringified non-None results (emptied if any answer equals
        "none", or if the query raised SQLAlchemyError), and ``executable_sql``
        is the statement that was sent to the engine.
    """
    # do not append id / agg
    headers = _table_content["header"]

    def flatten_sql(_ex_sql_struct: List):
        # Returns (executable SQL, human-readable SQL) built from the token list.
        readable_tokens = []
        runnable_tokens = []
        for _ex_tuple in _ex_sql_struct:
            token = str(_ex_tuple[1])
            # upper the keywords.
            if token in ALL_KEY_WORDS:
                token = token.upper()

            if re.fullmatch(r"c\d+(_.+)?", token):
                # Column alias such as "c4" or "c4_list": only the leading
                # "c<N>" part selects the header; wrap it with `` so the
                # readable SQL stays well-formed.
                header_pos = int(token.partition("_")[0][1:]) - 1
                readable_tokens.append("`{}`".format(headers[header_pos]))
            else:
                # Everything else (including the table name "w") is kept verbatim.
                readable_tokens.append(token)

            # c4_list / c4_address: execute against the original column instead.
            # NOTE: raises IndexError if the token contains no "c<N>" substring.
            if "_address" in token or "_list" in token:
                token = re.findall(r"c\d+", token)[0]
            runnable_tokens.append(token)

        return " ".join(runnable_tokens), " ".join(readable_tokens)

    _exec_sql_str, _encode_sql_str = flatten_sql(_sql_struct)

    try:
        raw_answers = _engine.execute_wtq_query(_exec_sql_str)
    except SQLAlchemyError:
        # A query the database rejects simply yields no answers.
        raw_answers = []

    _norm_sql_answers = [
        str(answer).replace("\n", " ") for answer in raw_answers if answer is not None
    ]
    # A literal "none" among the answers invalidates the whole result.
    if "none" in _norm_sql_answers:
        _norm_sql_answers = []
    return _encode_sql_str, _norm_sql_answers, _exec_sql_str


def _load_table_w_page(table_path, page_title_path=None) -> dict:
    """
    Load a WikiTableQuestions table and attach the title of its source page.

    attention: the table_path must be the .tsv path.
    The table itself is loaded by ``utils.utils._load_table``; the result is
    expected to be a dict like:
    {"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}
    to which a "page_title" entry is added.

    :param table_path: path of the table's .tsv file.
    :param page_title_path: path of the JSON file holding the page metadata
        (it must contain a "title" key). When omitted it is derived from
        ``table_path`` by replacing every "csv" with "page" and ".tsv" with
        ".json".
    :return: the loaded table dict with the extra "page_title" key.
    """

    from utils.utils import _load_table

    table_item = _load_table(table_path)

    # Load page title
    if not page_title_path:
        page_title_path = table_path.replace("csv", "page").replace(".tsv", ".json")
    # Read as UTF-8 explicitly: titles may contain non-ASCII characters and the
    # platform default encoding is not UTF-8 everywhere.
    with open(page_title_path, "r", encoding="utf-8") as f:
        page_title = json.load(f)['title']
    table_item['page_title'] = page_title

    return table_item