|
import operator |
|
import networkx as nx |
|
import attr |
|
import torch |
|
|
|
from seq2struct.beam_search import Hypothesis |
|
from seq2struct.models.nl2code.decoder import TreeState, get_field_presence_info |
|
from seq2struct.models.nl2code.tree_traversal import TreeTraversal |
|
|
|
@attr.s
class Hypothesis4Filtering(Hypothesis):
    """Beam-search hypothesis that also remembers its schema pointer choices.

    The three extra lists are filled by ``beam_search_with_heuristics`` and
    read back when finished hypotheses are filtered. Each defaults to a fresh
    empty list per instance.
    """

    # Pointer choices made at "column" nodes while growing the prefix that
    # precedes a "from" node.
    column_history = attr.ib(default=attr.Factory(list))
    # Pointer choices made at "table" nodes while expanding a "from" node.
    table_history = attr.ib(default=attr.Factory(list))
    # Pointer choices made at "column" nodes while expanding a "from" node.
    key_column_history = attr.ib(default=attr.Factory(list))
|
|
|
|
|
def beam_search_with_heuristics(model, orig_item, preproc_item, beam_size, max_steps, from_cond=True):
    """
    Find the valid FROM clause with beam search.

    The search repeats two phases until there is no prefix left to extend:

    1. Prefix phase: ordinary beam search (width ``beam_size``) until
       hypotheses reach a "from" node in ``CHILDREN_APPLY`` state. Pointer
       choices made at "column" nodes are recorded in ``column_history``.
    2. FROM phase: every such prefix is expanded with a fixed width of 6,
       recording "table" / "column" pointer choices, until the hypothesis
       either finishes decoding or reaches another "from" node (after at
       least one step), in which case it becomes a prefix for the next round.

    Finished hypotheses are kept only if they pass
    ``_satisfies_from_constraints``; survivors are accumulated across rounds.

    Args:
        model: object exposing ``begin_inference(orig_item, preproc_item)``
            that returns ``(inference_state, next_choices)``.
        orig_item: item whose ``schema`` provides ``foreign_key_graph`` and
            ``columns`` used by the filter.
        preproc_item: forwarded unchanged to ``model.begin_inference``.
        beam_size: width of the prefix phase and maximum number of returned
            hypotheses.
        max_steps: upper bound on the number of steps of each phase in each
            round.
        from_cond: when true, the key columns chosen inside the FROM clause
            must equal the ones implied by the foreign-key graph.

    Returns:
        Up to ``beam_size`` hypotheses, best score first: the accumulated
        filtered hypotheses if any exist, otherwise the unfiltered finished
        hypotheses of the last round.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis4Filtering(inference_state, next_choices)]

    # Finished hypotheses that passed the filter, accumulated over all rounds.
    cached_finished_seqs = []
    beam_prefix = beam
    while True:
        # ---- Phase 1: grow prefixes until they reach a "from" node. ----
        prefixes2fill_from = []
        for step in range(max_steps):
            if len(prefixes2fill_from) >= beam_size:
                break
            # Nothing left to expand: further steps would be no-ops.
            if not beam_prefix:
                break

            candidates = []
            for hyp in beam_prefix:
                if _is_at_from_node(hyp):
                    prefixes2fill_from.append(hyp)
                else:
                    candidates += _scored_choices(hyp)
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:beam_size - len(prefixes2fill_from)]

            beam_prefix = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # Remember which columns were pointed at before the FROM clause.
                column_history = hyp.column_history[:]
                cur_item = hyp.inference_state.cur_item
                if cur_item.state == TreeTraversal.State.POINTER_APPLY and \
                        cur_item.node_type == "column":
                    column_history = column_history + [choice]

                next_choices = inference_state.step(choice)
                assert next_choices is not None
                beam_prefix.append(
                    Hypothesis4Filtering(inference_state, next_choices, cum_score,
                                         hyp.choice_history + [choice],
                                         hyp.score_history + [choice_score],
                                         column_history))

        prefixes2fill_from.sort(key=operator.attrgetter('score'), reverse=True)

        # ---- Phase 2: expand the "from" node of every prefix. ----
        beam_from = prefixes2fill_from
        max_size = 6
        unfiltered_finished = []
        prefixes_unfinished = []
        for step in range(max_steps):
            if len(unfiltered_finished) + len(prefixes_unfinished) > max_size:
                break
            # Nothing left to expand: further steps would be no-ops.
            if not beam_from:
                break

            candidates = []
            for hyp in beam_from:
                # step > 0 so that the "from" node each prefix starts on is
                # expanded rather than immediately re-queued.
                if step > 0 and _is_at_from_node(hyp):
                    prefixes_unfinished.append(hyp)
                else:
                    candidates += _scored_choices(hyp)
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:max_size - len(prefixes_unfinished)]

            beam_from = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # Remember the tables / key columns pointed at inside FROM.
                table_history = hyp.table_history[:]
                key_column_history = hyp.key_column_history[:]
                cur_item = hyp.inference_state.cur_item
                if cur_item.state == TreeTraversal.State.POINTER_APPLY:
                    if cur_item.node_type == "table":
                        table_history = table_history + [choice]
                    elif cur_item.node_type == "column":
                        key_column_history = key_column_history + [choice]

                next_choices = inference_state.step(choice)
                extended = Hypothesis4Filtering(
                    inference_state,
                    next_choices,
                    cum_score,
                    hyp.choice_history + [choice],
                    hyp.score_history + [choice_score],
                    hyp.column_history, table_history,
                    key_column_history)
                if next_choices is None:
                    # ``step`` returned None: this hypothesis is complete.
                    unfiltered_finished.append(extended)
                else:
                    beam_from.append(extended)

        unfiltered_finished.sort(key=operator.attrgetter('score'), reverse=True)

        # ---- Filter the finished hypotheses and decide how to continue. ----
        filtered_finished = [
            hyp for hyp in unfiltered_finished
            if _satisfies_from_constraints(hyp, orig_item.schema, from_cond)
        ]
        filtered_finished.sort(key=operator.attrgetter('score'), reverse=True)
        prefixes_unfinished.sort(key=operator.attrgetter('score'), reverse=True)

        prefixes_, filtered_ = merge_beams(prefixes_unfinished, filtered_finished, beam_size)

        if filtered_:
            cached_finished_seqs = cached_finished_seqs + filtered_
            cached_finished_seqs.sort(key=operator.attrgetter('score'), reverse=True)

        if prefixes_ and len(prefixes_[0].choice_history) < 200:
            # Another "from" node is pending: start a new round from these
            # prefixes with their recorded pointer histories cleared.
            beam_prefix = prefixes_
            for hyp in beam_prefix:
                hyp.table_history = []
                hyp.column_history = []
                hyp.key_column_history = []
        elif cached_finished_seqs:
            return cached_finished_seqs[:beam_size]
        else:
            return unfiltered_finished[:beam_size]


def _is_at_from_node(hyp):
    """Return True when ``hyp`` is about to apply children of a "from" node."""
    cur_item = hyp.inference_state.cur_item
    return (cur_item.state == TreeTraversal.State.CHILDREN_APPLY
            and cur_item.node_type == "from")


def _scored_choices(hyp):
    """Return ``(hyp, choice, choice_score, cumulative_score)`` per next choice."""
    scored = []
    for choice, choice_score in hyp.next_choices:
        score = choice_score.item()
        scored.append((hyp, choice, score, hyp.score + score))
    return scored


def _satisfies_from_constraints(hyp, schema, from_cond):
    """Return True when the pointer choices recorded on ``hyp`` are consistent.

    Checks, in order:
      * no table was chosen twice;
      * if ``from_cond``: the key columns chosen equal the columns labelling
        the foreign-key edges on the shortest paths from the smallest chosen
        table id to every other chosen table (edges leading to tables that
        were not chosen are ignored);
      * every table owning a column in ``hyp.column_history`` was chosen.
    """
    mentioned_column_ids = set(hyp.column_history)
    mentioned_key_column_ids = set(hyp.key_column_history)
    mentioned_table_ids = set(hyp.table_history)

    # Duplicate tables are not allowed.
    if len(mentioned_table_ids) != len(hyp.table_history):
        return False

    if from_cond:
        covered_tables = set()
        must_include_key_columns = set()
        candidate_table_ids = sorted(mentioned_table_ids)
        # With no table chosen there is nothing to join, hence no required
        # key columns (indexing [0] unconditionally would raise IndexError).
        if candidate_table_ids:
            start_table_id = candidate_table_ids[0]
            for table_id in candidate_table_ids[1:]:
                if table_id in covered_tables:
                    continue
                try:
                    path = nx.shortest_path(
                        schema.foreign_key_graph, source=start_table_id, target=table_id)
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # Unreachable table: it imposes no key-column requirement.
                    covered_tables.add(table_id)
                    continue

                for source_table_id, target_table_id in zip(path, path[1:]):
                    if target_table_id in covered_tables:
                        continue
                    if target_table_id not in mentioned_table_ids:
                        continue
                    col1, col2 = schema.foreign_key_graph[source_table_id][target_table_id]['columns']
                    must_include_key_columns.add(col1)
                    must_include_key_columns.add(col2)
        if must_include_key_columns != mentioned_key_column_ids:
            return False

    # Tables whose columns are mentioned must themselves be mentioned.
    must_table_ids = set()
    for col in mentioned_column_ids:
        tab_ = schema.columns[col].table
        if tab_ is not None:
            must_table_ids.add(tab_.id)
    return must_table_ids.issubset(mentioned_table_ids)
|
|
|
|
|
|
|
def merge_beams(beam_1, beam_2, beam_size):
    """Keep the ``beam_size`` best-scoring hypotheses across two beams.

    Both beams are pooled, ranked by ``score`` (highest first, ties keep
    ``beam_1`` entries ahead of ``beam_2`` entries), cut to ``beam_size`` and
    split back by origin.

    If either beam is empty, both inputs are returned untouched (no
    truncation is applied in that case).

    Returns:
        ``(kept_from_beam_1, kept_from_beam_2)``, each ordered by score.
    """
    if not (len(beam_1) and len(beam_2)):
        return beam_1, beam_2

    # Tag every hypothesis with its origin so the pool can be split again.
    pooled = [(True, hyp) for hyp in beam_1]
    pooled.extend((False, hyp) for hyp in beam_2)
    survivors = sorted(pooled, key=lambda tagged: tagged[1].score, reverse=True)[:beam_size]

    kept_1 = [hyp for from_first, hyp in survivors if from_first]
    kept_2 = [hyp for from_first, hyp in survivors if not from_first]
    return kept_1, kept_2
|
|
|
|
|
def beam_search_with_oracle_column(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Greedy decoding in which column and table pointers are forced to gold.

    All "column" and "table" descendants of the gold tree
    ``preproc_item[1].tree`` are collected up front. Whenever the single
    hypothesis is in ``POINTER_APPLY`` state at a "column" / "table" node and
    gold values of that kind remain, its choices are narrowed to the next gold
    value; every other decision follows the model's best-scoring choice.

    Args:
        model: object exposing ``begin_inference`` and
            ``decoder.ast_wrapper.find_all_descendants_of_type``.
        orig_item: forwarded to ``model.begin_inference``.
        preproc_item: forwarded to ``model.begin_inference``; element ``[1]``
            must carry the gold ``tree``.
        beam_size: must be 1.
        max_steps: maximum number of decoding steps.
        visualize_flag: when true, print the step index at every step.

    Returns:
        List of finished hypotheses sorted by score, best first (empty if
        decoding did not finish within ``max_steps``).

    Raises:
        AssertionError: if ``beam_size`` is not 1, if the beam does not hold
            exactly one hypothesis, or if the next gold value is not among
            the offered choices.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis(inference_state, next_choices)]
    finished = []
    assert beam_size == 1

    # Gold pointer values, kept as stacks: the value to use next is the last
    # element, i.e. values are consumed in reverse order of
    # find_all_descendants_of_type.
    root_node = preproc_item[1].tree
    find_descendants = model.decoder.ast_wrapper.find_all_descendants_of_type
    gold_stacks = {
        "column": list(find_descendants(root_node, "column")),
        "table": list(find_descendants(root_node, "table")),
    }

    for step in range(max_steps):
        if visualize_flag:
            print('step:')
            print(step)

        if len(finished) == beam_size:
            break

        # Narrow the choices to the gold column/table when one is due.
        assert len(beam) == 1
        hyp = beam[0]
        cur_item = hyp.inference_state.cur_item
        if cur_item.state == TreeTraversal.State.POINTER_APPLY:
            gold_stack = gold_stacks.get(cur_item.node_type)
            if gold_stack:
                gold_value = gold_stack[-1]
                for _choice in hyp.next_choices:
                    if _choice[0] == gold_value:
                        hyp.next_choices = [_choice]
                        gold_stack.pop()
                        break
                else:
                    # Raised explicitly so the check is not stripped by -O.
                    raise AssertionError(
                        "gold %s %r is not among the offered choices"
                        % (cur_item.node_type, gold_value))

        # Score every (hypothesis, choice) pair and keep the best ones.
        candidates = []
        for hyp in beam:
            for choice, choice_score in hyp.next_choices:
                score = choice_score.item()
                candidates.append((hyp, choice, score, hyp.score + score))

        candidates.sort(key=operator.itemgetter(3), reverse=True)
        candidates = candidates[:beam_size - len(finished)]

        # Apply the kept choices to build the next beam.
        beam = []
        for hyp, choice, choice_score, cum_score in candidates:
            inference_state = hyp.inference_state.clone()
            next_choices = inference_state.step(choice)
            extended = Hypothesis(
                inference_state,
                next_choices,
                cum_score,
                hyp.choice_history + [choice],
                hyp.score_history + [choice_score])
            if next_choices is None:
                # ``step`` returned None: decoding is complete.
                finished.append(extended)
            else:
                beam.append(extended)

    finished.sort(key=operator.attrgetter('score'), reverse=True)
    return finished
|
|
|
|
|
def beam_search_with_oracle_sketch(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Replay the gold tree through the decoder and return the resulting hypothesis.

    The gold tree ``preproc_item[1].tree`` is walked depth-first with an
    explicit stack; at every node the matching rule index, pointer value or
    token sequence is fed to ``inference_state.step``. No model scores are
    used: every recorded score is 0.

    Args:
        model: object exposing ``begin_inference`` and a ``decoder`` with
            ``preproc``, ``ast_wrapper`` and ``rules_index``.
        orig_item: its ``code`` must be parseable by the decoder grammar.
        preproc_item: element ``[1]`` must carry the gold ``tree``.
        beam_size: unused; kept for signature compatibility.
        max_steps: unused; kept for signature compatibility.
        visualize_flag: unused; kept for signature compatibility.

    Returns:
        ``[hyp]`` with the single replayed hypothesis, or ``[]`` when
        ``orig_item.code`` does not parse or a list-length rule is missing
        from ``rules_index``.

    Raises:
        AssertionError: if the traversal is not in the state expected for the
            node being replayed.
        KeyError: if a sum-type or field-presence rule is missing from
            ``rules_index``.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    hyp = Hypothesis(inference_state, next_choices)

    parsed = model.decoder.preproc.grammar.parse(orig_item.code, "val")
    if not parsed:
        return []

    decoder = model.decoder
    preproc = decoder.preproc
    ast_wrapper = decoder.ast_wrapper

    # Explicit DFS stack of (node, type of the field holding the node).
    queue = [
        TreeState(
            node=preproc_item[1].tree,
            parent_field_type=preproc.grammar.root_type,
        )
    ]

    while queue:
        item = queue.pop()
        node = item.node
        parent_field_type = item.parent_field_type

        # Sequence field: choose the list length, then visit the elements.
        if isinstance(node, (list, tuple)):
            node_type = parent_field_type + '*'
            rule = (node_type, len(node))
            if rule not in decoder.rules_index:
                return []
            rule_idx = decoder.rules_index[rule]
            assert inference_state.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            inference_state.step(rule_idx)

            if preproc.use_seq_elem_rules and \
                    parent_field_type in ast_wrapper.sum_types:
                parent_field_type += '_seq_elem'

            # Pushed in reverse so elements are popped left to right.
            for elem in reversed(node):
                queue.append(
                    TreeState(
                        node=elem,
                        parent_field_type=parent_field_type,
                    ))

            hyp = _forced_hypothesis(inference_state, hyp, [rule_idx])
            continue

        # Pointer field: the node itself is the index to point at.
        if parent_field_type in preproc.grammar.pointers:
            assert inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY
            assert isinstance(node, int)
            inference_state.step(node)
            hyp = _forced_hypothesis(inference_state, hyp, [node])
            continue

        # Primitive field: emit its tokens followed by the end marker.
        if parent_field_type in ast_wrapper.primitive_types:
            field_value_split = preproc.grammar.tokenize_field_value(node) + [
                '<EOS>']

            for token in field_value_split:
                inference_state.step(token)
            hyp = _forced_hypothesis(inference_state, hyp, field_value_split)
            continue

        type_info = ast_wrapper.singular_types[node['_type']]

        # Sum type: choose which constructor the node uses.
        if parent_field_type in preproc.sum_type_constructors:
            rule = (parent_field_type, type_info.name)
            rule_idx = decoder.rules_index[rule]
            # Previously a bare comparison with no effect; made an assertion
            # like the other state checks in this function.
            assert inference_state.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                decoder.rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            inference_state.step(rule_idx, extra_rules)

            hyp = _forced_hypothesis(inference_state, hyp, [rule_idx])

        # Constructor with fields: choose which fields are present.
        if type_info.fields:
            present = get_field_presence_info(ast_wrapper, node, type_info.fields)
            rule = (node['_type'], tuple(present))
            rule_idx = decoder.rules_index[rule]
            inference_state.step(rule_idx)

            hyp = _forced_hypothesis(inference_state, hyp, [rule_idx])

        # Pushed in reverse so fields are popped in declaration order.
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue

            queue.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type,
                ))

    return [hyp]


def _forced_hypothesis(inference_state, hyp, choices):
    """Return ``hyp`` extended by ``choices``, one zero score per choice."""
    return Hypothesis(
        inference_state,
        None,
        0,
        hyp.choice_history + list(choices),
        hyp.score_history + [0] * len(choices))