import operator

import attr
import networkx as nx
import torch

from seq2struct.beam_search import Hypothesis
from seq2struct.models.nl2code.decoder import TreeState, get_field_presence_info
from seq2struct.models.nl2code.tree_traversal import TreeTraversal


@attr.s
class Hypothesis4Filtering(Hypothesis):
    # Per-hypothesis caches of pointer choices made so far; consumed by the
    # FROM-clause filtering heuristic in beam_search_with_heuristics.
    column_history = attr.ib(factory=list)
    table_history = attr.ib(factory=list)
    key_column_history = attr.ib(factory=list)


def beam_search_with_heuristics(model, orig_item, preproc_item, beam_size, max_steps, from_cond=True):
    """Beam search that only keeps hypotheses with a valid FROM clause.

    Alternates between three phases until the beam is exhausted:
      1. beam-search program prefixes until each hypothesis is about to
         expand a ``from`` node;
      2. for each prefix, enumerate candidate FROM clauses (a small local
         beam of size 6), recording the table / key-column pointer choices;
      3. discard finished hypotheses whose FROM clause is inconsistent with
         the schema (duplicate tables, wrong foreign-key columns, or columns
         from tables that are not in FROM).

    Args:
        model: model exposing ``begin_inference``.
        orig_item: original example; ``orig_item.schema`` provides the
            foreign-key graph and column/table metadata.
        preproc_item: preprocessed example passed to ``begin_inference``.
        beam_size: number of hypotheses kept between phases and returned.
        max_steps: cap on decoding steps per phase.
        from_cond: if True, also validate the foreign-key columns used in
            the FROM clause.  NOTE: the new grammar version does not predict
            conditions in the FROM clause anymore.

    Returns:
        Up to ``beam_size`` finished hypotheses, best-first.  Falls back to
        the unfiltered finished list if no hypothesis survives filtering.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam_prefix = [Hypothesis4Filtering(inference_state, next_choices)]
    cached_finished_seqs = []  # accumulates filtered, complete trajectories

    while True:
        # Phase 1: search prefixes with beam search, stopping each
        # hypothesis when it is about to expand a "from" node.
        prefixes2fill_from = []
        for step in range(max_steps):
            if len(prefixes2fill_from) >= beam_size:
                break

            candidates = []
            for hyp in beam_prefix:
                if hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
                        and hyp.inference_state.cur_item.node_type == "from":
                    prefixes2fill_from.append(hyp)
                else:
                    candidates += [(hyp, choice, choice_score.item(),
                                    hyp.score + choice_score.item())
                                   for choice, choice_score in hyp.next_choices]
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:beam_size - len(prefixes2fill_from)]

            # Create the new hypotheses from the expansions.
            beam_prefix = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # Cache column pointer choices made along the prefix.
                column_history = hyp.column_history[:]
                if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY and \
                        hyp.inference_state.cur_item.node_type == "column":
                    column_history = column_history + [choice]

                next_choices = inference_state.step(choice)
                # A prefix cannot finish before reaching its FROM clause.
                assert next_choices is not None
                beam_prefix.append(
                    Hypothesis4Filtering(
                        inference_state, next_choices, cum_score,
                        hyp.choice_history + [choice],
                        hyp.score_history + [choice_score],
                        column_history))

        prefixes2fill_from.sort(key=operator.attrgetter('score'), reverse=True)

        # Phase 2: enumerate FROM clauses for each prefix with a small
        # local beam, recording table / key-column pointer choices.
        beam_from = prefixes2fill_from
        max_size = 6
        unfiltered_finished = []
        prefixes_unfinished = []
        for step in range(max_steps):
            if len(unfiltered_finished) + len(prefixes_unfinished) > max_size:
                break

            candidates = []
            for hyp in beam_from:
                # step > 0: skip the initial "from" state we started from.
                if step > 0 and hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
                        and hyp.inference_state.cur_item.node_type == "from":
                    prefixes_unfinished.append(hyp)
                else:
                    candidates += [(hyp, choice, choice_score.item(),
                                    hyp.score + choice_score.item())
                                   for choice, choice_score in hyp.next_choices]
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:max_size - len(prefixes_unfinished)]

            beam_from = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # Cache table and key-column choices made inside FROM.
                table_history = hyp.table_history[:]
                key_column_history = hyp.key_column_history[:]
                if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
                    if hyp.inference_state.cur_item.node_type == "table":
                        table_history = table_history + [choice]
                    elif hyp.inference_state.cur_item.node_type == "column":
                        key_column_history = key_column_history + [choice]

                next_choices = inference_state.step(choice)
                if next_choices is None:
                    unfiltered_finished.append(Hypothesis4Filtering(
                        inference_state, None, cum_score,
                        hyp.choice_history + [choice],
                        hyp.score_history + [choice_score],
                        hyp.column_history, table_history, key_column_history))
                else:
                    beam_from.append(
                        Hypothesis4Filtering(
                            inference_state, next_choices, cum_score,
                            hyp.choice_history + [choice],
                            hyp.score_history + [choice_score],
                            hyp.column_history, table_history, key_column_history))

        unfiltered_finished.sort(key=operator.attrgetter('score'), reverse=True)

        # Phase 3: filter out hypotheses with schema-inconsistent FROM clauses.
        filtered_finished = [
            hyp for hyp in unfiltered_finished
            if _passes_from_filter(hyp, orig_item, from_cond)]
        filtered_finished.sort(key=operator.attrgetter('score'), reverse=True)
        prefixes_unfinished.sort(key=operator.attrgetter('score'), reverse=True)

        prefixes_, filtered_ = merge_beams(
            prefixes_unfinished, filtered_finished, beam_size)

        if filtered_:
            cached_finished_seqs = cached_finished_seqs + filtered_
            cached_finished_seqs.sort(key=operator.attrgetter('score'), reverse=True)

        # 200 is a hard cap on trajectory length to guarantee termination.
        if prefixes_ and len(prefixes_[0].choice_history) < 200:
            beam_prefix = prefixes_
            for hyp in beam_prefix:
                # Reset histories for the next round of FROM enumeration.
                hyp.table_history = []
                hyp.column_history = []
                hyp.key_column_history = []
        elif cached_finished_seqs:
            return cached_finished_seqs[:beam_size]
        else:
            return unfiltered_finished[:beam_size]


def _passes_from_filter(hyp, orig_item, from_cond):
    """Return True iff ``hyp``'s FROM clause is consistent with the schema.

    Checks: no duplicated tables; (optionally) the key columns used match
    exactly those implied by shortest foreign-key-graph paths between the
    mentioned tables; and every table whose column is mentioned appears in
    the FROM clause.
    """
    mentioned_column_ids = set(hyp.column_history)
    mentioned_key_column_ids = set(hyp.key_column_history)
    mentioned_table_ids = set(hyp.table_history)

    # Reject duplicate tables.
    if len(mentioned_table_ids) != len(hyp.table_history):
        return False

    # The foreign keys should be correctly used.
    # NOTE: the new version does not predict conditions in FROM clause anymore.
    if from_cond:
        covered_tables = set()
        must_include_key_columns = set()
        candidate_table_ids = sorted(mentioned_table_ids)
        # Guard against an empty table list (original code would raise
        # IndexError on candidate_table_ids[0]).
        if candidate_table_ids:
            start_table_id = candidate_table_ids[0]
            for table_id in candidate_table_ids[1:]:
                if table_id in covered_tables:
                    continue
                try:
                    path = nx.shortest_path(
                        orig_item.schema.foreign_key_graph,
                        source=start_table_id,
                        target=table_id)
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # Unreachable table: nothing to require for it.
                    covered_tables.add(table_id)
                    continue

                for source_table_id, target_table_id in zip(path, path[1:]):
                    if target_table_id in covered_tables:
                        continue
                    if target_table_id not in mentioned_table_ids:
                        continue
                    col1, col2 = orig_item.schema.foreign_key_graph[
                        source_table_id][target_table_id]['columns']
                    must_include_key_columns.add(col1)
                    must_include_key_columns.add(col2)
        if must_include_key_columns != mentioned_key_column_ids:
            return False

    # Tables whose columns are mentioned should also exist in FROM.
    must_table_ids = set()
    for col in mentioned_column_ids:
        tab_ = orig_item.schema.columns[col].table
        if tab_ is not None:
            must_table_ids.add(tab_.id)
    if not must_table_ids.issubset(mentioned_table_ids):
        return False

    return True


def merge_beams(beam_1, beam_2, beam_size):
    """Merge two score-sorted beams, keeping the overall top ``beam_size``.

    Returns the surviving members of each input beam, in merged-score order,
    still separated by origin.  If either beam is empty there is nothing to
    prune and both are returned unchanged.
    """
    if len(beam_1) == 0 or len(beam_2) == 0:
        return beam_1, beam_2

    annotated_beam_1 = [("beam_1", b) for b in beam_1]
    annotated_beam_2 = [("beam_2", b) for b in beam_2]
    merged_beams = annotated_beam_1 + annotated_beam_2
    merged_beams.sort(key=lambda x: x[1].score, reverse=True)

    ret_beam_1 = []
    ret_beam_2 = []
    for label, hyp in merged_beams[:beam_size]:
        if label == "beam_1":
            ret_beam_1.append(hyp)
        else:
            assert label == "beam_2"
            ret_beam_2.append(hyp)
    return ret_beam_1, ret_beam_2


def beam_search_with_oracle_column(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Greedy decode (beam_size must be 1) that forces every column/table
    pointer choice to the gold value taken from the preprocessed gold tree.

    Non-pointer choices are still taken greedily by model score.  Raises
    AssertionError if a gold column/table is not among the model's choices.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis(inference_state, next_choices)]
    finished = []
    assert beam_size == 1

    # Identify all the columns/tables mentioned in the gold SQL, in the
    # order the decoder will predict them (reversed DFS descendants).
    root_node = preproc_item[1].tree
    col_queue = list(reversed(
        [val for val in model.decoder.ast_wrapper.find_all_descendants_of_type(root_node, "column")]))
    tab_queue = list(reversed(
        [val for val in model.decoder.ast_wrapper.find_all_descendants_of_type(root_node, "table")]))
    col_queue_copy = col_queue[:]
    tab_queue_copy = tab_queue[:]
    predict_counter = 0

    for step in range(max_steps):
        if visualize_flag:
            print('step:')
            print(step)

        # Check if all beams are finished.
        if len(finished) == beam_size:
            break

        # Hijack the next choice using the gold column/table.
        assert len(beam) == 1
        hyp = beam[0]
        if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
            if hyp.inference_state.cur_item.node_type == "column" \
                    and len(col_queue) > 0:
                gold_col = col_queue[0]
                flag = False
                for _choice in hyp.next_choices:
                    if _choice[0] == gold_col:
                        flag = True
                        hyp.next_choices = [_choice]
                        col_queue = col_queue[1:]
                        break
                assert flag
            elif hyp.inference_state.cur_item.node_type == "table" \
                    and len(tab_queue) > 0:
                gold_tab = tab_queue[0]
                flag = False
                for _choice in hyp.next_choices:
                    if _choice[0] == gold_tab:
                        flag = True
                        hyp.next_choices = [_choice]
                        tab_queue = tab_queue[1:]
                        break
                assert flag

        # For debugging: count how many pointer steps were taken.
        if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
            predict_counter += 1

        # For each hypothesis, get and score possible expansions.
        candidates = []
        for hyp in beam:
            candidates += [(hyp, choice, choice_score.item(),
                            hyp.score + choice_score.item())
                           for choice, choice_score in hyp.next_choices]

        # Keep the top K expansions.
        candidates.sort(key=operator.itemgetter(3), reverse=True)
        candidates = candidates[:beam_size - len(finished)]

        # Create the new hypotheses from the expansions.
        beam = []
        for hyp, choice, choice_score, cum_score in candidates:
            inference_state = hyp.inference_state.clone()
            next_choices = inference_state.step(choice)
            if next_choices is None:
                finished.append(Hypothesis(
                    inference_state, None, cum_score,
                    hyp.choice_history + [choice],
                    hyp.score_history + [choice_score]))
            else:
                beam.append(Hypothesis(
                    inference_state, next_choices, cum_score,
                    hyp.choice_history + [choice],
                    hyp.score_history + [choice_score]))

    if (len(col_queue_copy) + len(tab_queue_copy)) != predict_counter:
        # The number of columns/tables did not match the number of pointer
        # predictions; kept silent deliberately (best-effort oracle).
        pass

    finished.sort(key=operator.attrgetter('score'), reverse=True)
    return finished


def beam_search_with_oracle_sketch(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Teacher-force the gold tree through the decoder (oracle sketch).

    Walks the preprocessed gold tree depth-first and feeds the corresponding
    rule/pointer/token choices to the inference state, producing a single
    zero-score hypothesis.

    Returns:
        ``[hyp]`` with the forced choice history, or ``[]`` if the gold code
        cannot be parsed or a required rule is missing from the rule index.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    hyp = Hypothesis(inference_state, next_choices)

    parsed = model.decoder.preproc.grammar.parse(orig_item.code, "val")
    if not parsed:
        return []

    queue = [
        TreeState(
            node=preproc_item[1].tree,
            parent_field_type=model.decoder.preproc.grammar.root_type,
        )
    ]
    while queue:
        item = queue.pop()
        node = item.node
        parent_field_type = item.parent_field_type

        if isinstance(node, (list, tuple)):
            # Sequence field: apply the list-length rule, then enqueue
            # elements (reversed for left-to-right DFS).
            node_type = parent_field_type + '*'
            rule = (node_type, len(node))
            if rule not in model.decoder.rules_index:
                return []
            rule_idx = model.decoder.rules_index[rule]
            assert inference_state.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            next_choices = inference_state.step(rule_idx)

            if model.decoder.preproc.use_seq_elem_rules and \
                    parent_field_type in model.decoder.ast_wrapper.sum_types:
                parent_field_type += '_seq_elem'

            for i, elem in reversed(list(enumerate(node))):
                queue.append(
                    TreeState(
                        node=elem,
                        parent_field_type=parent_field_type,
                    ))
            hyp = Hypothesis(
                inference_state, None, 0,
                hyp.choice_history + [rule_idx],
                hyp.score_history + [0])
            continue

        if parent_field_type in model.decoder.preproc.grammar.pointers:
            # Pointer field: force the gold index.
            assert inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY
            assert isinstance(node, int)
            next_choices = inference_state.step(node)
            hyp = Hypothesis(
                inference_state, None, 0,
                hyp.choice_history + [node],
                hyp.score_history + [0])
            continue

        if parent_field_type in model.decoder.ast_wrapper.primitive_types:
            # GenToken: emit the value token-by-token, terminated by ''.
            field_value_split = model.decoder.preproc.grammar.tokenize_field_value(node) + ['']
            for token in field_value_split:
                next_choices = inference_state.step(token)
            hyp = Hypothesis(
                inference_state, None, 0,
                hyp.choice_history + field_value_split,
                hyp.score_history + [0])
            continue

        type_info = model.decoder.ast_wrapper.singular_types[node['_type']]

        if parent_field_type in model.decoder.preproc.sum_type_constructors:
            # ApplyRule, like expr -> Call.
            rule = (parent_field_type, type_info.name)
            rule_idx = model.decoder.rules_index[rule]
            # BUG FIX: the original wrote this comparison as a bare (no-op)
            # expression statement; restore the assertion, matching the
            # LIST_LENGTH_APPLY / POINTER_APPLY checks above.
            assert inference_state.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                model.decoder.rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            next_choices = inference_state.step(rule_idx, extra_rules)
            hyp = Hypothesis(
                inference_state, None, 0,
                hyp.choice_history + [rule_idx],
                hyp.score_history + [0])

        if type_info.fields:
            # ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords].
            # Figure out which rule needs to be applied from field presence.
            present = get_field_presence_info(
                model.decoder.ast_wrapper, node, type_info.fields)
            rule = (node['_type'], tuple(present))
            rule_idx = model.decoder.rules_index[rule]
            next_choices = inference_state.step(rule_idx)
            hyp = Hypothesis(
                inference_state, None, 0,
                hyp.choice_history + [rule_idx],
                hyp.score_history + [0])

        # Reversed so that we perform a DFS in left-to-right order.
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue
            queue.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type,
                ))

    return [hyp]