Spaces:
Runtime error
Runtime error
File size: 10,066 Bytes
f6f97d8 9611943 f6f97d8 9611943 f6f97d8 9611943 f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import json
from typing import List, Dict
from nsql.qa_module.openai_qa import OpenAIQAModel
from nsql.qa_module.vqa import vqa_call
from nsql.database import NeuralDB
from nsql.parser import get_cfg_tree, get_steps, remove_duplicate, TreeNode, parse_question_paras, nsql_role_recognize, \
extract_answers
class NSQLExecutor(object):
    """Execute an NSQL program step by step against a ``NeuralDB``.

    The NSQL string is parsed into an execution tree and flattened into
    steps.  A step whose text starts with ``QA(`` is answered by the QA
    model; any other step is executed as plain SQL and ends the run.
    The inputs and outputs of every step are dumped as JSON files under
    ``vis_dir``.
    """

    # Directory that receives the per-step JSON dumps (presumably consumed by
    # a visualisation front end, judging by the name -- confirm with callers).
    # A class attribute so it can be overridden per instance or subclass.
    vis_dir = "tmp_for_vis"

    def __init__(self, args, keys=None):
        """Create the executor and its QA model.

        Args:
            args: Forwarded unchanged to ``OpenAIQAModel``.
            keys: Forwarded unchanged to ``OpenAIQAModel``.
        """
        # Start index for the names handed out by generate_new_col_names().
        self.new_col_name_id = 0
        self.qa_model = OpenAIQAModel(args, keys)

    def generate_new_col_names(self, number):
        """Return ``number`` fresh column names ``col_<k>`` and advance the counter.

        Args:
            number: How many names to produce.

        Returns:
            A list of ``number`` strings, numbered consecutively from the
            current value of ``self.new_col_name_id``.
        """
        first = self.new_col_name_id
        col_names = ["col_{}".format(i) for i in range(first, first + number)]
        self.new_col_name_id += number
        return col_names

    def sql_exec(self, sql: str, db: "NeuralDB", verbose=True):
        """Run ``sql`` through ``db.execute_query`` and return its result.

        Args:
            sql: The SQL text to execute.
            db: Database object providing ``execute_query``.
            verbose: When true, print the statement before running it.

        Returns:
            Whatever ``db.execute_query(sql)`` returns.
        """
        if verbose:
            print("Exec SQL '{}' with additional row_id on {}".format(sql, db))
        return db.execute_query(sql)

    def _dump_for_vis(self, file_name, payload):
        """Write ``payload`` as JSON to ``<vis_dir>/<file_name>``.

        The directory is created on demand so that a missing dump folder does
        not abort an otherwise successful execution.
        """
        import os  # Local import: the only place in this class that needs it.

        if self.vis_dir:
            os.makedirs(self.vis_dir, exist_ok=True)
        with open(os.path.join(self.vis_dir, file_name), "w") as f:
            json.dump(payload, f)

    def _collect_sub_tables(self, sql_s, db, verbose):
        """Turn every argument of a ``QA(...)`` call into a sub-table dict.

        Args:
            sql_s: The argument strings returned by ``parse_question_paras``.
            db: Database used to classify and resolve each argument.
            verbose: Forwarded to ``sql_exec``.

        Returns:
            A list with one entry per recognised argument.  Arguments whose
            role is not handled below contribute nothing.
        """
        sub_tables = []
        for raw_item in sql_s:
            role, sql_item = nsql_role_recognize(raw_item,
                                                 db.get_header(),
                                                 db.get_passages_titles(),
                                                 db.get_images_titles())
            if role in ('col', 'complete_sql'):
                sub_tables.append(self.sql_exec(sql_item, db, verbose=verbose))
            elif role == 'val':
                # SECURITY(review): eval() executes arbitrary Python.  The text
                # comes from the NSQL program; if that program can be produced
                # by an untrusted party (e.g. model output), replace this with
                # a restricted parser such as ast.literal_eval.
                val = eval(sql_item)
                sub_tables.append({
                    "header": ["row_id", "val"],
                    "rows": [["0", val]]
                })
            elif role == 'passage_title_and_image_title':
                # The title names both a passage and an image: concatenate the
                # passage text and the image caption into a single cell.
                sub_tables.append({
                    "header": ["row_id", "{}".format(sql_item)],
                    "rows": [["0", db.get_passage_by_title(sql_item) +
                              db.get_image_caption_by_title(sql_item)]]
                })
            elif role == 'passage_title':
                sub_tables.append({
                    "header": ["row_id", "{}".format(sql_item)],
                    "rows": [["0", db.get_passage_by_title(sql_item)]]
                })
            elif role == 'image_title':
                # NOTE: images are represented by their caption.  An earlier
                # variant that answered the question via vqa_call() on the
                # image itself was disabled.
                sub_tables.append({
                    "header": ["row_id", "{}".format(sql_item)],
                    "rows": [["0", db.get_image_caption_by_title(sql_item)]],
                })
        return sub_tables

    def _append_linked_context(self, sub_tables, db):
        """Append linked passage text / image captions to matching cells, in place.

        A cell that is a key of the passage linker gets `` (<passage text>)``
        appended; a cell that is a key of the image linker gets
        `` (<image caption>)`` appended.  Both lookups use the cell's original
        value, so a cell present in both linkers receives both suffixes.
        """
        passage_linker = db.get_passage_linker()
        image_linker = db.get_image_linker()
        for sub_table in sub_tables:
            for row in sub_table['rows']:
                for j, cell in enumerate(row):
                    if cell in passage_linker.keys():
                        # Add passage text as backup info
                        row[j] += " ({})".format(
                            db.get_passage_by_title(passage_linker[cell]))
                    if cell in image_linker.keys():
                        # Add image caption as backup info
                        row[j] += " ({})".format(
                            db.get_image_caption_by_title(image_linker[cell]))

    def nsql_exec(self, stamp, nsql: str, db: "NeuralDB", verbose=True):
        """Execute an NSQL program and return its final answer.

        Args:
            stamp: Prefix used for the names of all dump files of this run.
            nsql: The NSQL program text.
            db: Database the program runs against; ``map@`` steps that have a
                father add their result to it via ``add_sub_table``.
            verbose: When true, print the steps and forward verbosity to the
                SQL executor, the QA model and ``add_sub_table``.

        Returns:
            ``extract_answers(sub_table)`` when the final step is plain SQL or
            a ``map@`` question, the QA model's answer when the final step is
            an ``ans@`` question, and ``None`` if no step turns out to be
            final.

        Raises:
            ValueError: If a ``QA(`` step's question starts with neither
                ``map@`` nor ``ans@`` (case-insensitive).
            AssertionError: If a step is not a ``TreeNode``.
        """
        steps = []
        root_node = get_cfg_tree(nsql)  # Parse execution tree from nsql.
        get_steps(root_node, steps)  # Flatten the execution tree and get the steps.
        steps = remove_duplicate(steps)  # Remove the duplicate steps.

        step_names = [s.rename for s in steps]
        if verbose:
            print("Steps:", step_names)
        self._dump_for_vis("{}_tmp_for_vis_steps.txt".format(stamp), step_names)

        col_idx = 0
        for step in steps:
            # All steps should be formatted as 'QA()' except for last step which could also be normal SQL.
            assert isinstance(step, TreeNode), "step must be treenode"

            # Looked up once per step; the dump file names below all share it.
            step_idx = steps.index(step)
            input_file = "{}_result_step_{}_input.txt".format(stamp, step_idx)
            result_file = "{}_result_step_{}.txt".format(stamp, step_idx)

            nsql = step.rename
            if not nsql.startswith('QA('):
                # Plain SQL: this is the final step.
                sub_table = self.sql_exec(nsql, db, verbose=verbose)
                self._dump_for_vis(result_file, sub_table)
                return extract_answers(sub_table)

            question, sql_s = parse_question_paras(nsql, self.qa_model)
            # Execute all SQLs and get the results as parameters
            sql_executed_sub_tables = self._collect_sub_tables(sql_s, db, verbose)
            # If the sub_tables to execute with link, append it to the cell.
            self._append_linked_context(sql_executed_sub_tables, db)

            question_kind = question.lower()
            if question_kind.startswith("map@"):
                # When the question is a type of mapping, we return the mapped column.
                question = question[len("map@"):]
                has_father = bool(step.father)
                if has_father:
                    step.rename_father_col(col_idx=col_idx)
                    new_col_names = step.produced_col_name_s
                else:  # This step is the final step
                    new_col_names = ["col_{}".format(col_idx)]
                sub_table = self.qa_model.qa(question,
                                             sql_executed_sub_tables,
                                             table_title=db.table_title,
                                             qa_type="map",
                                             new_col_name_s=new_col_names,
                                             verbose=verbose)
                self._dump_for_vis(input_file, sql_executed_sub_tables)
                self._dump_for_vis(result_file, sub_table)
                if not has_father:
                    return extract_answers(sub_table)
                db.add_sub_table(sub_table, verbose=verbose)
                col_idx += 1
            elif question_kind.startswith("ans@"):
                # When the question is a type of answering, we return an answer list.
                question = question[len("ans@"):]
                answer = self.qa_model.qa(question,
                                          sql_executed_sub_tables,
                                          table_title=db.table_title,
                                          qa_type="ans",
                                          verbose=verbose)
                self._dump_for_vis(input_file, sql_executed_sub_tables)
                self._dump_for_vis(result_file, answer)
                if step.father:
                    step.rename_father_val(answer)
                else:  # This step is the final step
                    return answer
            else:
                raise ValueError(
                    "Except for operators or NL question must start with 'map@' or 'ans@'!, check '{}'".format(
                        question))
|