import collections
import copy
import itertools
import os

import asdl
import attr
import networkx as nx

from seq2struct import ast_util
from seq2struct.utils import registry


def bimap(first, second):
    """Build forward and backward lookup dicts between two parallel sequences."""
    return {f: s for f, s in zip(first, second)}, {s: f for f, s in zip(first, second)}


def filter_nones(d):
    """Drop entries whose value is None or an empty list (empty dicts are kept)."""
    return {k: v for k, v in d.items() if v is not None and v != []}


def join(iterable, delimiter):
    """Yield the items of `iterable` with `delimiter` between consecutive items.

    Fix: an unguarded `next(it)` leaked StopIteration out of the generator for
    an empty iterable, which is a RuntimeError under PEP 479 (Python 3.7+);
    an empty input now simply yields nothing.
    """
    it = iter(iterable)
    try:
        yield next(it)
    except StopIteration:
        return
    for x in it:
        yield delimiter
        yield x


def intersperse(delimiter, seq):
    """Iterator version of `join`: delimiter between consecutive items of `seq`."""
    return itertools.islice(
        itertools.chain.from_iterable(
            zip(itertools.repeat(delimiter), seq)),
        1, None)


@registry.register('grammar', 'spider')
class SpiderLanguage:
    """Grammar for the Spider text-to-SQL dataset.

    Converts between Spider's JSON-encoded SQL trees and ASDL-typed dict trees
    (`parse_*`), and renders dict trees back to SQL text (`unparse` via
    SpiderUnparser). `factorize_sketch` selects one of three ASDL factorizations
    of the `sql` production (0: flat, 1: fully nested, 2: one level of clause
    subtrees).
    """

    root_type = 'sql'

    def __init__(
            self,
            output_from=False,
            use_table_pointer=False,
            include_literals=True,
            include_columns=True,
            end_with_from=False,
            clause_order=None,
            infer_from_conditions=False,
            factorize_sketch=0):
        # Collect pointer fields and their primitive-type checkers.
        custom_primitive_type_checkers = {}
        self.pointers = set()
        if use_table_pointer:
            custom_primitive_type_checkers['table'] = lambda x: isinstance(x, int)
            self.pointers.add('table')
        self.include_columns = include_columns
        if include_columns:
            custom_primitive_type_checkers['column'] = lambda x: isinstance(x, int)
            self.pointers.add('column')

        # Create the AST wrapper from the grammar file matching the chosen
        # factorization. The .asdl files live next to this module.
        self.factorize_sketch = factorize_sketch
        if self.factorize_sketch == 0:
            asdl_file = "Spider.asdl"
        elif self.factorize_sketch == 1:
            asdl_file = "Spider_f1.asdl"
        elif self.factorize_sketch == 2:
            asdl_file = "Spider_f2.asdl"
        else:
            raise NotImplementedError
        self.ast_wrapper = ast_util.ASTWrapper(
            asdl.parse(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    asdl_file)),
            custom_primitive_type_checkers=custom_primitive_type_checkers)

        # Without a table pointer, table ids are plain ints in the grammar.
        if not use_table_pointer:
            self.ast_wrapper.singular_types['Table'].fields[0].type = 'int'
        # Without columns, drop the col_id field from col_unit entirely.
        if not include_columns:
            col_unit_fields = self.ast_wrapper.singular_types['col_unit'].fields
            assert col_unit_fields[1].name == 'col_id'
            del col_unit_fields[1]

        # Literals of the limit field: when literals are excluded, `limit`
        # becomes a mandatory boolean-like singleton instead of an optional int.
        self.include_literals = include_literals
        if not self.include_literals:
            if self.factorize_sketch == 0:
                limit_field = self.ast_wrapper.singular_types['sql'].fields[6]
            else:
                limit_field = self.ast_wrapper.singular_types['sql_orderby'].fields[1]
            assert limit_field.name == 'limit'
            limit_field.opt = False
            limit_field.type = 'singleton'

        # FROM-clause configuration.
        self.output_from = output_from
        self.end_with_from = end_with_from
        self.clause_order = clause_order
        self.infer_from_conditions = infer_from_conditions
        if self.clause_order:
            # clause_order is prioritized over configurations like end_with_from.
            assert factorize_sketch == 2  # TODO support other grammars
            sql_fields = self.ast_wrapper.product_types['sql'].fields
            # S=select, F=from, W=where, G=groupby, O=orderby, I=ieu.
            letter2field = {k: v for k, v in zip("SFWGOI", sql_fields)}
            new_sql_fields = [letter2field[k] for k in self.clause_order]
            self.ast_wrapper.product_types['sql'].fields = new_sql_fields
        else:
            if not self.output_from:
                # The model does not predict FROM; remove it from the grammar.
                sql_fields = self.ast_wrapper.product_types['sql'].fields
                assert sql_fields[1].name == 'from'
                del sql_fields[1]
            else:
                sql_fields = self.ast_wrapper.product_types['sql'].fields
                assert sql_fields[1].name == "from"
                if self.end_with_from:
                    # Move the FROM field to the end of the production.
                    sql_fields.append(sql_fields[1])
                    del sql_fields[1]

    def parse(self, code, section):
        """Parse a Spider JSON SQL tree into a grammar-typed dict tree."""
        return self.parse_sql(code)

    def unparse(self, tree, item):
        """Render a grammar-typed dict tree back to SQL text for `item.schema`."""
        unparser = SpiderUnparser(self.ast_wrapper, item.schema, self.factorize_sketch)
        return unparser.unparse_sql(tree)

    @classmethod
    def tokenize_field_value(cls, field_value):
        """Normalize a primitive field value into a one-token list.

        Fix: bytes were previously passed to `.encode('latin1')`, but `bytes`
        has no `encode` method — converting bytes to str requires `.decode`.
        """
        if isinstance(field_value, bytes):
            field_value_str = field_value.decode('latin1')
        elif isinstance(field_value, str):
            field_value_str = field_value
        else:
            field_value_str = str(field_value)
            if field_value_str[0] == '"' and field_value_str[-1] == '"':
                field_value_str = field_value_str[1:-1]
        # TODO: Get rid of surrounding quotes
        return [field_value_str]

    #
    # Parsing: Spider JSON -> grammar dict trees
    #

    def parse_val(self, val):
        """Parse a Spider value: string, col_unit list, number, or nested sql."""
        if isinstance(val, str):
            if not self.include_literals:
                return {'_type': 'Terminal'}
            return {
                '_type': 'String',
                's': val,
            }
        elif isinstance(val, list):
            return {
                '_type': 'ColUnit',
                'c': self.parse_col_unit(val),
            }
        elif isinstance(val, float):
            if not self.include_literals:
                return {'_type': 'Terminal'}
            return {
                '_type': 'Number',
                'f': val,
            }
        elif isinstance(val, dict):
            return {
                '_type': 'ValSql',
                's': self.parse_sql(val),
            }
        else:
            raise ValueError(val)

    def parse_col_unit(self, col_unit):
        """Parse a Spider (agg_id, col_id, is_distinct) triple."""
        agg_id, col_id, is_distinct = col_unit
        result = {
            '_type': 'col_unit',
            'agg_id': {'_type': self.AGG_TYPES_F[agg_id]},
            'is_distinct': is_distinct,
        }
        if self.include_columns:
            result['col_id'] = col_id
        return result

    def parse_val_unit(self, val_unit):
        """Parse a Spider (unit_op, col_unit1, col_unit2) triple; col_unit2 only
        exists for binary ops (unit_op != 0)."""
        unit_op, col_unit1, col_unit2 = val_unit
        result = {
            '_type': self.UNIT_TYPES_F[unit_op],
            'col_unit1': self.parse_col_unit(col_unit1),
        }
        if unit_op != 0:
            result['col_unit2'] = self.parse_col_unit(col_unit2)
        return result

    def parse_table_unit(self, table_unit):
        """Parse a FROM-clause table unit: either a subquery or a table id."""
        table_type, value = table_unit
        if table_type == 'sql':
            return {
                '_type': 'TableUnitSql',
                's': self.parse_sql(value),
            }
        elif table_type == 'table_unit':
            return {
                '_type': 'Table',
                'table_id': value,
            }
        else:
            raise ValueError(table_type)

    def parse_cond(self, cond, optional=False):
        """Parse a Spider condition list.

        The list alternates clauses and 'and'/'or' keywords; it is folded
        right-associatively into binary And/Or nodes. A single clause is a
        (not_op, op_id, val_unit, val1, val2) tuple; val2 only for BETWEEN.
        """
        if optional and not cond:
            return None
        if len(cond) > 1:
            return {
                '_type': self.LOGIC_OPERATORS_F[cond[1]],
                'left': self.parse_cond(cond[:1]),
                'right': self.parse_cond(cond[2:]),
            }

        (not_op, op_id, val_unit, val1, val2), = cond
        result = {
            '_type': self.COND_TYPES_F[op_id],
            'val_unit': self.parse_val_unit(val_unit),
            'val1': self.parse_val(val1),
        }
        if op_id == 1:  # between
            result['val2'] = self.parse_val(val2)
        if not_op:
            result = {
                '_type': 'Not',
                'c': result,
            }
        return result

    def parse_sql(self, sql, optional=False):
        """Parse a full Spider sql dict into the tree shape selected by
        `factorize_sketch` (0 flat, 1 fully nested, 2 per-clause subtrees)."""
        if optional and sql is None:
            return None
        if self.factorize_sketch == 0:
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                'where': self.parse_cond(sql['where'], optional=True),
                'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                'order_by': self.parse_order_by(sql['orderBy']),
                'having': self.parse_cond(sql['having'], optional=True),
                'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                'intersect': self.parse_sql(sql['intersect'], optional=True),
                'except': self.parse_sql(sql['except'], optional=True),
                'union': self.parse_sql(sql['union'], optional=True),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {})
            })
        elif self.factorize_sketch == 1:
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {}),
                'sql_where': filter_nones({
                    '_type': 'sql_where',
                    'where': self.parse_cond(sql['where'], optional=True),
                    'sql_groupby': filter_nones({
                        '_type': 'sql_groupby',
                        'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                        'having': filter_nones({
                            '_type': 'having',
                            'having': self.parse_cond(sql['having'], optional=True),
                        }),
                        'sql_orderby': filter_nones({
                            '_type': 'sql_orderby',
                            'order_by': self.parse_order_by(sql['orderBy']),
                            'limit': filter_nones({
                                '_type': 'limit',
                                'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                            }),
                            'sql_ieu': filter_nones({
                                '_type': 'sql_ieu',
                                'intersect': self.parse_sql(sql['intersect'], optional=True),
                                'except': self.parse_sql(sql['except'], optional=True),
                                'union': self.parse_sql(sql['union'], optional=True),
                            })
                        })
                    })
                })
            })
        elif self.factorize_sketch == 2:
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {}),
                "sql_where": filter_nones({
                    '_type': 'sql_where',
                    'where': self.parse_cond(sql['where'], optional=True),
                }),
                "sql_groupby": filter_nones({
                    '_type': 'sql_groupby',
                    'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                    'having': self.parse_cond(sql['having'], optional=True),
                }),
                "sql_orderby": filter_nones({
                    '_type': 'sql_orderby',
                    'order_by': self.parse_order_by(sql['orderBy']),
                    'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                }),
                'sql_ieu': filter_nones({
                    '_type': 'sql_ieu',
                    'intersect': self.parse_sql(sql['intersect'], optional=True),
                    'except': self.parse_sql(sql['except'], optional=True),
                    'union': self.parse_sql(sql['union'], optional=True),
                })
            })

    def parse_select(self, select):
        """Parse a Spider (is_distinct, aggs) select pair."""
        is_distinct, aggs = select
        return {
            '_type': 'select',
            'is_distinct': is_distinct,
            'aggs': [self.parse_agg(agg) for agg in aggs],
        }

    def parse_agg(self, agg):
        """Parse a Spider (agg_id, val_unit) pair from the select list."""
        agg_id, val_unit = agg
        return {
            '_type': 'agg',
            'agg_id': {'_type': self.AGG_TYPES_F[agg_id]},
            'val_unit': self.parse_val_unit(val_unit),
        }

    def parse_from(self, from_, infer_from_conditions=False):
        """Parse the FROM clause; join conditions are dropped when they will be
        re-inferred from the schema's foreign-key graph at unparse time."""
        return filter_nones({
            '_type': 'from',
            'table_units': [
                self.parse_table_unit(u) for u in from_['table_units']],
            'conds': self.parse_cond(from_['conds'], optional=True)
            if not infer_from_conditions else None,
        })

    def parse_order_by(self, order_by):
        """Parse a Spider (order, val_units) pair; None when absent."""
        if not order_by:
            return None
        order, val_units = order_by
        return {
            '_type': 'order_by',
            'order': {'_type': self.ORDERS_F[order]},
            'val_units': [self.parse_val_unit(v) for v in val_units]
        }

    #
    # Mappings between Spider's integer/keyword codes and grammar node names.
    #

    COND_TYPES_F, COND_TYPES_B = bimap(
        # Spider op ids: ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists');
        # 'not' (0) is handled separately and 'is'/'exists' are unused here.
        range(1, 10),
        ('Between', 'Eq', 'Gt', 'Lt', 'Ge', 'Le', 'Ne', 'In', 'Like'))

    UNIT_TYPES_F, UNIT_TYPES_B = bimap(
        # Spider unit ops: ('none', '-', '+', '*', '/')
        range(5),
        ('Column', 'Minus', 'Plus', 'Times', 'Divide'))

    AGG_TYPES_F, AGG_TYPES_B = bimap(
        range(6),
        ('NoneAggOp', 'Max', 'Min', 'Count', 'Sum', 'Avg'))

    ORDERS_F, ORDERS_B = bimap(
        ('asc', 'desc'),
        ('Asc', 'Desc'))

    LOGIC_OPERATORS_F, LOGIC_OPERATORS_B = bimap(
        ('and', 'or'),
        ('And', 'Or'))


@attr.s
class SpiderUnparser:
    """Renders grammar dict trees (as produced by SpiderLanguage) to SQL text."""

    ast_wrapper = attr.ib()
    schema = attr.ib()
    factorize_sketch = attr.ib(default=0)

    UNIT_TYPES_B = {
        'Minus': '-',
        'Plus': '+',
        'Times': '*',
        'Divide': '/',
    }
    COND_TYPES_B = {
        'Between': 'BETWEEN',
        'Eq': '=',
        'Gt': '>',
        'Lt': '<',
        'Ge': '>=',
        'Le': '<=',
        'Ne': '!=',
        'In': 'IN',
        'Like': 'LIKE',
    }

    @classmethod
    def conjoin_conds(cls, conds):
        """Fold a list of condition nodes into a right-nested And tree."""
        if not conds:
            return None
        if len(conds) == 1:
            return conds[0]
        return {'_type': 'And', 'left': conds[0], 'right': cls.conjoin_conds(conds[1:])}

    @classmethod
    def linearize_cond(cls, cond):
        """Flatten a right-nested And/Or tree into (clauses, keywords)."""
        if cond['_type'] in ('And', 'Or'):
            conds, keywords = cls.linearize_cond(cond['right'])
            return [cond['left']] + conds, [cond['_type']] + keywords
        else:
            return [cond], []

    def unparse_val(self, val):
        """Render a value node (Terminal / String / ColUnit / Number / ValSql)."""
        if val['_type'] == 'Terminal':
            return "'terminal'"
        if val['_type'] == 'String':
            return val['s']
        if val['_type'] == 'ColUnit':
            return self.unparse_col_unit(val['c'])
        if val['_type'] == 'Number':
            return str(val['f'])
        if val['_type'] == 'ValSql':
            return '({})'.format(self.unparse_sql(val['s']))

    def unparse_col_unit(self, col_unit):
        """Render a col_unit as [AGG(][DISTINCT ][table.]column[)]."""
        if 'col_id' in col_unit:
            column = self.schema.columns[col_unit['col_id']]
            if column.table is None:
                column_name = column.orig_name
            else:
                column_name = '{}.{}'.format(column.table.orig_name, column.orig_name)
        else:
            # Grammar was built with include_columns=False; no real column id.
            column_name = 'some_col'

        if col_unit['is_distinct']:
            column_name = 'DISTINCT {}'.format(column_name)
        agg_type = col_unit['agg_id']['_type']
        if agg_type == 'NoneAggOp':
            return column_name
        else:
            return '{}({})'.format(agg_type, column_name)

    def unparse_val_unit(self, val_unit):
        """Render a val_unit: a bare column or a binary arithmetic expression."""
        if val_unit['_type'] == 'Column':
            return self.unparse_col_unit(val_unit['col_unit1'])
        col1 = self.unparse_col_unit(val_unit['col_unit1'])
        col2 = self.unparse_col_unit(val_unit['col_unit2'])
        return '{} {} {}'.format(col1, self.UNIT_TYPES_B[val_unit['_type']], col2)

    # def unparse_table_unit(self, table_unit):
    #     raise NotImplementedError

    def unparse_cond(self, cond, negated=False):
        """Render a condition tree; `negated` inserts NOT before the operator."""
        if cond['_type'] == 'And':
            assert not negated
            return '{} AND {}'.format(
                self.unparse_cond(cond['left']), self.unparse_cond(cond['right']))
        elif cond['_type'] == 'Or':
            assert not negated
            return '{} OR {}'.format(
                self.unparse_cond(cond['left']), self.unparse_cond(cond['right']))
        elif cond['_type'] == 'Not':
            return self.unparse_cond(cond['c'], negated=True)
        elif cond['_type'] == 'Between':
            tokens = [self.unparse_val_unit(cond['val_unit'])]
            if negated:
                tokens.append('NOT')
            tokens += [
                'BETWEEN',
                self.unparse_val(cond['val1']),
                'AND',
                self.unparse_val(cond['val2']),
            ]
            return ' '.join(tokens)
        tokens = [self.unparse_val_unit(cond['val_unit'])]
        if negated:
            tokens.append('NOT')
        tokens += [self.COND_TYPES_B[cond['_type']], self.unparse_val(cond['val1'])]
        return ' '.join(tokens)

    def refine_from(self, tree):
        """
        1) Inferring tables from columns predicted
        2) Mix them with the predicted tables if any
        3) Inferring conditions based on tables

        Mutates `tree` in place: rewrites tree['from'] with the inferred table
        units and (when available) foreign-key join conditions.

        Fix: the method previously started with `tree = dict(tree)`, so the
        final `tree['from'] = ...` assignment landed on a local shallow copy
        and never reached the caller — unparse_sql then crashed with KeyError
        on trees that lack a 'from' clause (the very case this method repairs).
        """
        # Nested query in the FROM clause: recursively refine each subquery.
        if "from" in tree and tree["from"]["table_units"][0]["_type"] == 'TableUnitSql':
            for table_unit in tree["from"]["table_units"]:
                subquery_tree = table_unit["s"]
                self.refine_from(subquery_tree)
            return

        # Get predicted tables, dropping duplicates while preserving order.
        predicted_from_table_ids = set()
        if "from" in tree:
            table_unit_set = []
            for table_unit in tree["from"]["table_units"]:
                if table_unit["table_id"] not in predicted_from_table_ids:
                    predicted_from_table_ids.add(table_unit["table_id"])
                    table_unit_set.append(table_unit)
            tree["from"]["table_units"] = table_unit_set  # remove duplicate

        # Get all candidate columns referenced anywhere except nested queries.
        candidate_column_ids = set(self.ast_wrapper.find_all_descendants_of_type(
            tree, 'column', lambda field: field.type != 'sql'))
        candidate_columns = [self.schema.columns[i] for i in candidate_column_ids]
        must_in_from_table_ids = set(
            column.table.id for column in candidate_columns if column.table is not None)

        # Take the union of inferred and predicted tables.
        all_from_table_ids = must_in_from_table_ids.union(predicted_from_table_ids)
        if not all_from_table_ids:
            # TODO: better heuristic e.g., tables that have exact match
            all_from_table_ids = {0}

        # Connect the tables through the schema's foreign-key graph, adding an
        # equality join condition for every edge on each shortest path.
        covered_tables = set()
        candidate_table_ids = sorted(all_from_table_ids)
        start_table_id = candidate_table_ids[0]
        conds = []
        for table_id in candidate_table_ids[1:]:
            if table_id in covered_tables:
                continue
            try:
                path = nx.shortest_path(
                    self.schema.foreign_key_graph,
                    source=start_table_id,
                    target=table_id)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                covered_tables.add(table_id)
                continue

            for source_table_id, target_table_id in zip(path, path[1:]):
                if target_table_id in covered_tables:
                    continue
                all_from_table_ids.add(target_table_id)
                col1, col2 = self.schema.foreign_key_graph[source_table_id][target_table_id]['columns']
                conds.append({
                    '_type': 'Eq',
                    'val_unit': {
                        '_type': 'Column',
                        'col_unit1': {
                            '_type': 'col_unit',
                            'agg_id': {'_type': 'NoneAggOp'},
                            'col_id': col1,
                            'is_distinct': False
                        },
                    },
                    'val1': {
                        '_type': 'ColUnit',
                        'c': {
                            '_type': 'col_unit',
                            'agg_id': {'_type': 'NoneAggOp'},
                            'col_id': col2,
                            'is_distinct': False
                        }
                    }
                })

        table_units = [{'_type': 'Table', 'table_id': i} for i in sorted(all_from_table_ids)]
        tree['from'] = {
            '_type': 'from',
            'table_units': table_units,
        }
        cond_node = self.conjoin_conds(conds)
        if cond_node is not None:
            tree['from']['conds'] = cond_node

    def unparse_sql(self, tree):
        """Render a full sql tree (any factorization) to a SQL string."""
        self.refine_from(tree)
        result = [
            # select select,
            self.unparse_select(tree['select']),
            # from from,
            self.unparse_from(tree['from']),
        ]

        def find_subtree(_tree, name):
            # Locate the subtree holding a clause, per factorization:
            # sketch 0 is flat (everything on the root); sketch 1 descends into
            # the named subtree; sketch 2 keeps the root but reads the subtree.
            if self.factorize_sketch == 0:
                return _tree, _tree
            elif name in _tree:
                if self.factorize_sketch == 1:
                    return _tree[name], _tree[name]
                elif self.factorize_sketch == 2:
                    return _tree, _tree[name]
                else:
                    raise NotImplementedError

        tree, target_tree = find_subtree(tree, "sql_where")
        # cond? where,
        if 'where' in target_tree:
            result += ['WHERE', self.unparse_cond(target_tree['where'])]

        tree, target_tree = find_subtree(tree, "sql_groupby")
        # col_unit* group_by,
        if 'group_by' in target_tree:
            result += [
                'GROUP BY',
                ', '.join(self.unparse_col_unit(c) for c in target_tree['group_by'])]

        tree, target_tree = find_subtree(tree, "sql_orderby")
        # order_by? order_by,
        # NOTE(review): ORDER BY is emitted before HAVING below, which is not
        # standard SQL clause order — preserved as-is; confirm downstream
        # evaluation tolerates it before reordering.
        if 'order_by' in target_tree:
            result.append(self.unparse_order_by(target_tree['order_by']))

        tree, target_tree = find_subtree(tree, "sql_groupby")
        # cond? having,
        if 'having' in target_tree:
            result += ['HAVING', self.unparse_cond(target_tree['having'])]

        tree, target_tree = find_subtree(tree, "sql_orderby")
        # int? limit,
        if 'limit' in target_tree:
            if isinstance(target_tree['limit'], bool):
                # include_literals=False encodes limit as a presence flag.
                if target_tree['limit']:
                    result += ['LIMIT', '1']
            else:
                result += ['LIMIT', str(target_tree['limit'])]

        tree, target_tree = find_subtree(tree, "sql_ieu")
        # sql? intersect,
        if 'intersect' in target_tree:
            result += ['INTERSECT', self.unparse_sql(target_tree['intersect'])]
        # sql? except,
        if 'except' in target_tree:
            result += ['EXCEPT', self.unparse_sql(target_tree['except'])]
        # sql? union
        if 'union' in target_tree:
            result += ['UNION', self.unparse_sql(target_tree['union'])]

        return ' '.join(result)

    def unparse_select(self, select):
        """Render SELECT [DISTINCT] agg, agg, ..."""
        tokens = ['SELECT']
        if select['is_distinct']:
            tokens.append('DISTINCT')
        tokens.append(', '.join(self.unparse_agg(agg) for agg in select.get('aggs', [])))
        return ' '.join(tokens)

    def unparse_agg(self, agg):
        """Render one select-list item, wrapping in AGG(...) when aggregated."""
        unparsed_val_unit = self.unparse_val_unit(agg['val_unit'])
        agg_type = agg['agg_id']['_type']
        if agg_type == 'NoneAggOp':
            return unparsed_val_unit
        else:
            return '{}({})'.format(agg_type, unparsed_val_unit)

    def unparse_from(self, from_):
        """Render FROM t1 JOIN t2 ON ... , attaching each join condition to the
        first table unit after which all of its tables have been output."""
        if 'conds' in from_:
            all_conds, keywords = self.linearize_cond(from_['conds'])
        else:
            all_conds, keywords = [], []
        assert all(keyword == 'And' for keyword in keywords)

        cond_indices_by_table = collections.defaultdict(set)
        tables_involved_by_cond_idx = collections.defaultdict(set)
        for i, cond in enumerate(all_conds):
            for column in self.ast_wrapper.find_all_descendants_of_type(cond, 'column'):
                table = self.schema.columns[column].table
                if table is None:
                    continue
                cond_indices_by_table[table.id].add(i)
                tables_involved_by_cond_idx[i].add(table.id)

        output_table_ids = set()
        output_cond_indices = set()
        tokens = ['FROM']
        for i, table_unit in enumerate(from_.get('table_units', [])):
            if i > 0:
                tokens += ['JOIN']

            if table_unit['_type'] == 'TableUnitSql':
                tokens.append('({})'.format(self.unparse_sql(table_unit['s'])))
            elif table_unit['_type'] == 'Table':
                table_id = table_unit['table_id']
                tokens += [self.schema.tables[table_id].orig_name]
                output_table_ids.add(table_id)

                # Output "ON <cond>" if all tables involved in the condition
                # have been output.
                conds_to_output = []
                for cond_idx in sorted(cond_indices_by_table[table_id]):
                    if cond_idx in output_cond_indices:
                        continue
                    if tables_involved_by_cond_idx[cond_idx] <= output_table_ids:
                        conds_to_output.append(all_conds[cond_idx])
                        output_cond_indices.add(cond_idx)
                if conds_to_output:
                    tokens += ['ON']
                    tokens += list(intersperse(
                        'AND',
                        (self.unparse_cond(cond) for cond in conds_to_output)))

        return ' '.join(tokens)

    def unparse_order_by(self, order_by):
        """Render ORDER BY val_unit, ... ASC|DESC."""
        return 'ORDER BY {} {}'.format(
            ', '.join(self.unparse_val_unit(v) for v in order_by['val_units']),
            order_by['order']['_type'])