File size: 9,513 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
from typing import List, Dict
from nsql.qa_module.openai_qa import OpenAIQAModel
from nsql.qa_module.vqa import vqa_call
from nsql.database import NeuralDB
from nsql.parser import get_cfg_tree, get_steps, remove_duplicate, TreeNode, parse_question_paras, nsql_role_recognize, \
    extract_answers


class NSQLExecutor(object):
    """Execute an NSQL program step by step against a ``NeuralDB``.

    An NSQL string is parsed into an execution tree, flattened into steps and
    de-duplicated.  Steps whose text starts with ``QA(`` are answered by the
    QA model over sub-tables gathered from the database; any other step is
    run as plain SQL and ends the execution.

    Intermediate results are dumped as JSON files under ``vis_dir`` so they
    can be inspected afterwards.
    """

    # Directory receiving the JSON dumps written during execution.
    vis_dir = "tmp_for_vis"

    def __init__(self, args, keys=None):
        """Create the executor and its QA model.

        Args:
            args: Forwarded unchanged to ``OpenAIQAModel``.
            keys: Forwarded unchanged to ``OpenAIQAModel``.
        """
        self.new_col_name_id = 0
        self.qa_model = OpenAIQAModel(args, keys)

    def generate_new_col_names(self, number):
        """Return ``number`` fresh names ``col_<i>`` and advance the counter.

        The names continue from ``self.new_col_name_id``, which is then
        increased by ``number``.
        """
        start = self.new_col_name_id
        self.new_col_name_id = start + number
        return ["col_{}".format(i) for i in range(start, start + number)]

    def sql_exec(self, sql: str, db: "NeuralDB", verbose=True):
        """Run ``sql`` through ``db.execute_query`` and return its result.

        Args:
            sql: The SQL text to execute.
            db: Database object providing ``execute_query``.
            verbose: When true, print the statement before running it.
        """
        if verbose:
            print("Exec SQL '{}' with additional row_id on {}".format(sql, db))
        return db.execute_query(sql)

    def _dump_for_vis(self, file_name, payload):
        """Write ``payload`` as JSON to ``<vis_dir>/<file_name>``.

        The target directory is created on demand; previously a missing
        directory made every dump raise ``FileNotFoundError``.
        """
        import os  # Local import keeps this class independent of file-level imports.

        path = os.path.join(self.vis_dir, file_name)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f)

    @staticmethod
    def _single_cell_table(column_name, value):
        """Build a one-row sub-table holding ``value`` under ``column_name``."""
        return {
            "header": ["row_id", column_name],
            "rows": [["0", value]]
        }

    def _build_sub_tables(self, sql_s, db, verbose):
        """Turn every QA parameter in ``sql_s`` into a sub-table.

        Each item is classified by ``nsql_role_recognize``; items with a role
        not handled below contribute no sub-table.
        """
        sub_tables = []
        for sql_item in sql_s:
            role, sql_item = nsql_role_recognize(sql_item,
                                                 db.get_header(),
                                                 db.get_passages_titles(),
                                                 db.get_images_titles())
            if role in ('col', 'complete_sql'):
                sub_tables.append(self.sql_exec(sql_item, db, verbose=verbose))
            elif role == 'val':
                # NOTE(review): eval() runs text derived from the NSQL input.
                # If that text can come from an untrusted source (e.g. model
                # output), this executes arbitrary code -- consider
                # ast.literal_eval once the accepted value syntax is confirmed.
                sub_tables.append(self._single_cell_table("val", eval(sql_item)))
            elif role == 'passage_title_and_image_title':
                # The title names both a passage and an image: concatenate the
                # passage text and the image caption into a single cell.
                merged = db.get_passage_by_title(sql_item) + db.get_image_caption_by_title(sql_item)
                sub_tables.append(self._single_cell_table("{}".format(sql_item), merged))
            elif role == 'passage_title':
                sub_tables.append(self._single_cell_table("{}".format(sql_item),
                                                          db.get_passage_by_title(sql_item)))
            elif role == 'image_title':
                # The image is represented by its caption only.
                sub_tables.append(self._single_cell_table("{}".format(sql_item),
                                                          db.get_image_caption_by_title(sql_item)))
        return sub_tables

    @staticmethod
    def _append_linked_info(sub_tables, db):
        """Append linked passage text / image captions to matching cells.

        Cells are modified in place.  A cell found in both linkers gets the
        passage text first and the image caption second.
        """
        passage_linker = db.get_passage_linker()
        image_linker = db.get_image_linker()
        for table in sub_tables:
            for row in table['rows']:
                for j, cell in enumerate(row):
                    # ``cell`` keeps the original value, so both lookups use
                    # the text as it was before anything was appended.
                    if cell in passage_linker:
                        row[j] += " ({})".format(
                            db.get_passage_by_title(passage_linker[cell]))
                    if cell in image_linker:
                        row[j] += " ({})".format(
                            db.get_image_caption_by_title(image_linker[cell]))

    def nsql_exec(self, stamp, nsql: str, db: "NeuralDB", verbose=True):
        """Execute ``nsql`` on ``db`` and return the final answer.

        Args:
            stamp: Prefix used in the names of the dump files.
            nsql: The NSQL program text.
            db: Database the steps are executed against; sub-tables produced
                by intermediate ``map@`` steps are added to it.
            verbose: Print progress and additionally dump the list of steps.

        Returns:
            The result of the final step: ``extract_answers`` of the produced
            sub-table for SQL and ``map@`` steps, or the QA model's answer for
            an ``ans@`` step.  ``None`` when no step terminates the execution.

        Raises:
            ValueError: If a QA question starts with neither ``map@`` nor
                ``ans@``.
        """
        steps = []
        root_node = get_cfg_tree(nsql)  # Parse execution tree from nsql.
        get_steps(root_node, steps)  # Flatten the execution tree and get the steps.
        steps = remove_duplicate(steps)  # Remove the duplicate steps.
        if verbose:
            step_names = [s.rename for s in steps]
            print("Steps:", step_names)
            self._dump_for_vis("{}_tmp_for_vis_steps.txt".format(stamp), step_names)

        col_idx = 0
        for step_idx, step in enumerate(steps):
            # All steps should be formatted as 'QA()' except for last step which could also be normal SQL.
            assert isinstance(step, TreeNode), "step must be treenode"
            step_nsql = step.rename
            result_file = "{}_result_step_{}.txt".format(stamp, step_idx)

            if not step_nsql.startswith('QA('):
                # Plain SQL: execute it and finish.
                sub_table = self.sql_exec(step_nsql, db, verbose=verbose)
                self._dump_for_vis(result_file, sub_table)
                return extract_answers(sub_table)

            question, sql_s = parse_question_paras(step_nsql, self.qa_model)
            # Execute all SQLs and get the results as parameters.
            sub_tables = self._build_sub_tables(sql_s, db, verbose)
            # If a cell is linked to a passage or an image, append that info to the cell.
            self._append_linked_info(sub_tables, db)

            # A step without a father is the final step.
            is_final = not step.father
            question_lower = question.lower()

            if question_lower.startswith("map@"):
                # When the question is a type of mapping, we return the mapped column.
                question = question[len("map@"):]
                if is_final:
                    new_col_names = ["col_{}".format(col_idx)]
                else:
                    step.rename_father_col(col_idx=col_idx)
                    new_col_names = step.produced_col_name_s
                sub_table: Dict = self.qa_model.qa(question,
                                                   sub_tables,
                                                   table_title=db.table_title,
                                                   qa_type="map",
                                                   new_col_name_s=new_col_names,
                                                   verbose=verbose)
                self._dump_for_vis(result_file, sub_table)
                if is_final:
                    return extract_answers(sub_table)
                db.add_sub_table(sub_table, verbose=verbose)
                col_idx += 1

            elif question_lower.startswith("ans@"):
                # When the question is a type of answering, we return an answer list.
                question = question[len("ans@"):]
                answer: List = self.qa_model.qa(question,
                                                sub_tables,
                                                table_title=db.table_title,
                                                qa_type="ans",
                                                verbose=verbose)
                self._dump_for_vis(result_file, answer)
                if is_final:
                    return answer
                step.rename_father_val(answer)

            else:
                raise ValueError(
                    "Except for operators or NL question must start with 'map@' or 'ans@'!, check '{}'".format(
                        question))

        # No step produced a final result (e.g. there were no steps at all).
        return None