|
import collections |
|
import copy |
|
import itertools |
|
import os |
|
|
|
import asdl |
|
import attr |
|
import networkx as nx |
|
|
|
from seq2struct import ast_util |
|
from seq2struct.utils import registry |
|
|
|
|
|
def bimap(first, second):
    """Build forward and backward lookup dicts between two aligned sequences.

    Returns a pair ``(forward, backward)`` where ``forward[f] == s`` and
    ``backward[s] == f`` for each aligned pair ``(f, s)``.
    """
    forward = {f: s for f, s in zip(first, second)}
    backward = {s: f for f, s in zip(first, second)}
    return forward, backward
|
|
|
|
|
def filter_nones(d):
    """Return a copy of *d* without entries whose value is None or an empty list."""
    def keep(value):
        return value is not None and value != []
    return {key: value for key, value in d.items() if keep(value)}
|
|
|
|
|
def join(iterable, delimiter):
    """Yield the items of *iterable* with *delimiter* between consecutive items.

    An empty iterable yields nothing.  (The original unconditionally called
    ``next(it)``, whose StopIteration PEP 479 turns into a RuntimeError on
    empty input.)
    """
    it = iter(iterable)
    try:
        yield next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for x in it:
        yield delimiter
        yield x
|
|
|
|
|
def intersperse(delimiter, seq):
    """Lazily yield the elements of *seq* separated by *delimiter*.

    Equivalent to ``join(seq, delimiter)`` but built from itertools: pair
    every element with a leading delimiter, flatten, then drop the first
    (spurious) delimiter.
    """
    pairs = zip(itertools.repeat(delimiter), seq)
    flattened = itertools.chain.from_iterable(pairs)
    return itertools.islice(flattened, 1, None)
|
|
|
|
|
@registry.register('grammar', 'spider')
class SpiderLanguage:
    """Grammar for the Spider text-to-SQL dataset.

    Converts between Spider's JSON encoding of a SQL query and an AST dict
    conforming to one of the Spider ASDL grammars.  ``factorize_sketch``
    selects the grammar variant: 0 = flat ``sql`` product, 1 = fully nested
    clause factorization, 2 = flat per-clause products.
    """

    root_type = 'sql'

    def __init__(
            self,
            output_from=False,
            use_table_pointer=False,
            include_literals=True,
            include_columns=True,
            end_with_from=False,
            clause_order=None,
            infer_from_conditions=False,
            factorize_sketch=0):
        """Configure the grammar and load/patch the matching ASDL definition.

        Args:
            output_from: keep the FROM clause in the AST; if False, the
                'from' field is removed from the grammar and re-inferred
                from column mentions at unparse time.
            use_table_pointer: represent table ids as pointer primitives.
            include_literals: keep literal values; if False, strings become
                Terminal nodes and 'limit' degrades to a presence flag.
            include_columns: keep column ids inside col_unit.
            end_with_from: move the 'from' field to the end of the 'sql'
                product (only relevant when output_from is True).
            clause_order: permutation of "SFWGOI" reordering the 'sql'
                fields; requires factorize_sketch == 2.
            infer_from_conditions: drop FROM join conditions when parsing;
                they are re-derived from foreign keys during unparsing.
            factorize_sketch: which Spider ASDL file to load (0, 1 or 2).
        """
        # Pointer fields are plain ints in the tree, so they need custom
        # primitive checkers rather than ASDL-declared types.
        custom_primitive_type_checkers = {}
        self.pointers = set()
        if use_table_pointer:
            custom_primitive_type_checkers['table'] = lambda x: isinstance(x, int)
            self.pointers.add('table')
        self.include_columns = include_columns
        if include_columns:
            custom_primitive_type_checkers['column'] = lambda x: isinstance(x, int)
            self.pointers.add('column')

        self.factorize_sketch = factorize_sketch
        if self.factorize_sketch == 0:
            asdl_file = "Spider.asdl"
        elif self.factorize_sketch == 1:
            asdl_file = "Spider_f1.asdl"
        elif self.factorize_sketch == 2:
            asdl_file = "Spider_f2.asdl"
        else:
            raise NotImplementedError
        self.ast_wrapper = ast_util.ASTWrapper(
            asdl.parse(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    asdl_file)),
            custom_primitive_type_checkers=custom_primitive_type_checkers)

        # Patch the loaded grammar to match the configuration flags.
        if not use_table_pointer:
            self.ast_wrapper.singular_types['Table'].fields[0].type = 'int'
        if not include_columns:
            col_unit_fields = self.ast_wrapper.singular_types['col_unit'].fields
            assert col_unit_fields[1].name == 'col_id'
            del col_unit_fields[1]

        self.include_literals = include_literals
        if not self.include_literals:
            # Without literals, 'limit' only records presence/absence.
            if self.factorize_sketch == 0:
                limit_field = self.ast_wrapper.singular_types['sql'].fields[6]
            else:
                limit_field = self.ast_wrapper.singular_types['sql_orderby'].fields[1]
            assert limit_field.name == 'limit'
            limit_field.opt = False
            limit_field.type = 'singleton'

        self.output_from = output_from
        self.end_with_from = end_with_from
        self.clause_order = clause_order
        self.infer_from_conditions = infer_from_conditions
        if self.clause_order:
            # Reorder the 'sql' product fields per the permutation string
            # (S=select, F=from, W=where, G=group, O=order, I=ieu).
            assert factorize_sketch == 2
            sql_fields = self.ast_wrapper.product_types['sql'].fields
            letter2field = {k: v for k, v in zip("SFWGOI", sql_fields)}
            new_sql_fields = [letter2field[k] for k in self.clause_order]
            self.ast_wrapper.product_types['sql'].fields = new_sql_fields
        else:
            if not self.output_from:
                sql_fields = self.ast_wrapper.product_types['sql'].fields
                assert sql_fields[1].name == 'from'
                del sql_fields[1]
            else:
                sql_fields = self.ast_wrapper.product_types['sql'].fields
                assert sql_fields[1].name == "from"
                if self.end_with_from:
                    # Move 'from' to the end of the product type.
                    sql_fields.append(sql_fields[1])
                    del sql_fields[1]

    def parse(self, code, section):
        """Parse a Spider JSON query dict into an AST dict (section unused)."""
        return self.parse_sql(code)

    def unparse(self, tree, item):
        """Render an AST back into a SQL string using ``item.schema``."""
        unparser = SpiderUnparser(self.ast_wrapper, item.schema, self.factorize_sketch)
        return unparser.unparse_sql(tree)

    @classmethod
    def tokenize_field_value(cls, field_value):
        """Return a single-token list for a primitive field value.

        Bytes are decoded as latin-1; other non-strings are stringified.
        A surrounding pair of double quotes is stripped.
        """
        if isinstance(field_value, bytes):
            # BUGFIX: bytes has no .encode() in Python 3; decode to str.
            field_value_str = field_value.decode('latin1')
        elif isinstance(field_value, str):
            field_value_str = field_value
        else:
            field_value_str = str(field_value)
        # Guard against the empty string before indexing (was IndexError).
        if field_value_str and field_value_str[0] == '"' and field_value_str[-1] == '"':
            field_value_str = field_value_str[1:-1]
        return [field_value_str]

    def parse_val(self, val):
        """Parse a Spider value: string, col_unit list, number, or sub-SQL."""
        if isinstance(val, str):
            if not self.include_literals:
                return {'_type': 'Terminal'}
            return {
                '_type': 'String',
                's': val,
            }
        elif isinstance(val, list):
            return {
                '_type': 'ColUnit',
                'c': self.parse_col_unit(val),
            }
        elif isinstance(val, float):
            if not self.include_literals:
                return {'_type': 'Terminal'}
            return {
                '_type': 'Number',
                'f': val,
            }
        elif isinstance(val, dict):
            # A dict value is a nested SELECT.
            return {
                '_type': 'ValSql',
                's': self.parse_sql(val),
            }
        else:
            raise ValueError(val)

    def parse_col_unit(self, col_unit):
        """Parse a Spider ``(agg_id, col_id, is_distinct)`` triple."""
        agg_id, col_id, is_distinct = col_unit
        result = {
            '_type': 'col_unit',
            'agg_id': {'_type': self.AGG_TYPES_F[agg_id]},
            'is_distinct': is_distinct,
        }
        if self.include_columns:
            result['col_id'] = col_id
        return result

    def parse_val_unit(self, val_unit):
        """Parse a ``(unit_op, col_unit1, col_unit2)`` triple."""
        unit_op, col_unit1, col_unit2 = val_unit
        result = {
            '_type': self.UNIT_TYPES_F[unit_op],
            'col_unit1': self.parse_col_unit(col_unit1),
        }
        if unit_op != 0:
            # Binary ops (-, +, *, /) carry a second column operand.
            result['col_unit2'] = self.parse_col_unit(col_unit2)
        return result

    def parse_table_unit(self, table_unit):
        """Parse a FROM entry: either a plain table or a nested SELECT."""
        table_type, value = table_unit
        if table_type == 'sql':
            return {
                '_type': 'TableUnitSql',
                's': self.parse_sql(value),
            }
        elif table_type == 'table_unit':
            return {
                '_type': 'Table',
                'table_id': value,
            }
        else:
            raise ValueError(table_type)

    def parse_cond(self, cond, optional=False):
        """Parse a Spider condition list ``[cond, 'and'/'or', cond, ...]``.

        The list alternates atomic conditions with conjunction keywords;
        the result is a right-nested And/Or tree.
        """
        if optional and not cond:
            return None

        if len(cond) > 1:
            return {
                '_type': self.LOGIC_OPERATORS_F[cond[1]],
                'left': self.parse_cond(cond[:1]),
                'right': self.parse_cond(cond[2:]),
            }

        (not_op, op_id, val_unit, val1, val2), = cond
        result = {
            '_type': self.COND_TYPES_F[op_id],
            'val_unit': self.parse_val_unit(val_unit),
            'val1': self.parse_val(val1),
        }
        if op_id == 1:
            # BETWEEN carries a second value.
            result['val2'] = self.parse_val(val2)
        if not_op:
            result = {
                '_type': 'Not',
                'c': result,
            }
        return result

    def parse_sql(self, sql, optional=False):
        """Parse a full Spider query dict into the configured AST shape."""
        if optional and sql is None:
            return None
        if self.factorize_sketch == 0:
            # Flat 'sql' product: every clause is a direct field.
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                'where': self.parse_cond(sql['where'], optional=True),
                'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                'order_by': self.parse_order_by(sql['orderBy']),
                'having': self.parse_cond(sql['having'], optional=True),
                'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                'intersect': self.parse_sql(sql['intersect'], optional=True),
                'except': self.parse_sql(sql['except'], optional=True),
                'union': self.parse_sql(sql['union'], optional=True),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {})
            })
        elif self.factorize_sketch == 1:
            # Fully nested factorization: each clause wraps the next.
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {}),
                'sql_where': filter_nones({
                    '_type': 'sql_where',
                    'where': self.parse_cond(sql['where'], optional=True),
                    'sql_groupby': filter_nones({
                        '_type': 'sql_groupby',
                        'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                        'having': filter_nones({
                            '_type': 'having',
                            'having': self.parse_cond(sql['having'], optional=True),
                        }),
                        'sql_orderby': filter_nones({
                            '_type': 'sql_orderby',
                            'order_by': self.parse_order_by(sql['orderBy']),
                            'limit': filter_nones({
                                '_type': 'limit',
                                'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                            }),
                            'sql_ieu': filter_nones({
                                '_type': 'sql_ieu',
                                'intersect': self.parse_sql(sql['intersect'], optional=True),
                                'except': self.parse_sql(sql['except'], optional=True),
                                'union': self.parse_sql(sql['union'], optional=True),
                            })
                        })
                    })
                })
            })
        elif self.factorize_sketch == 2:
            # Flat clause products: one sibling product per clause group.
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {}),
                "sql_where": filter_nones({
                    '_type': 'sql_where',
                    'where': self.parse_cond(sql['where'], optional=True),
                }),
                "sql_groupby": filter_nones({
                    '_type': 'sql_groupby',
                    'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                    'having': self.parse_cond(sql['having'], optional=True),
                }),
                "sql_orderby": filter_nones({
                    '_type': 'sql_orderby',
                    'order_by': self.parse_order_by(sql['orderBy']),
                    'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                }),
                'sql_ieu': filter_nones({
                    '_type': 'sql_ieu',
                    'intersect': self.parse_sql(sql['intersect'], optional=True),
                    'except': self.parse_sql(sql['except'], optional=True),
                    'union': self.parse_sql(sql['union'], optional=True),
                })
            })

    def parse_select(self, select):
        """Parse a SELECT clause ``(is_distinct, aggs)``."""
        is_distinct, aggs = select
        return {
            '_type': 'select',
            'is_distinct': is_distinct,
            'aggs': [self.parse_agg(agg) for agg in aggs],
        }

    def parse_agg(self, agg):
        """Parse one SELECT item ``(agg_id, val_unit)``."""
        agg_id, val_unit = agg
        return {
            '_type': 'agg',
            'agg_id': {'_type': self.AGG_TYPES_F[agg_id]},
            'val_unit': self.parse_val_unit(val_unit),
        }

    def parse_from(self, from_, infer_from_conditions=False):
        """Parse a FROM clause; join conditions are dropped when they will
        be re-inferred from the schema at unparse time."""
        return filter_nones({
            '_type': 'from',
            'table_units': [
                self.parse_table_unit(u) for u in from_['table_units']],
            'conds': self.parse_cond(from_['conds'], optional=True) \
                if not infer_from_conditions else None,
        })

    def parse_order_by(self, order_by):
        """Parse an ORDER BY clause ``(order, val_units)``; None if absent."""
        if not order_by:
            return None

        order, val_units = order_by
        return {
            '_type': 'order_by',
            'order': {'_type': self.ORDERS_F[order]},
            'val_units': [self.parse_val_unit(v) for v in val_units]
        }

    # Lookup tables between Spider's integer/keyword encodings and AST node
    # type names (forward: encoding -> name; backward: name -> encoding).
    COND_TYPES_F, COND_TYPES_B = bimap(
        # op_id 0 is 'not', which parse_cond handles separately.
        range(1, 10),
        ('Between', 'Eq', 'Gt', 'Lt', 'Ge', 'Le', 'Ne', 'In', 'Like'))

    UNIT_TYPES_F, UNIT_TYPES_B = bimap(
        range(5),
        ('Column', 'Minus', 'Plus', 'Times', 'Divide'))

    AGG_TYPES_F, AGG_TYPES_B = bimap(
        range(6),
        ('NoneAggOp', 'Max', 'Min', 'Count', 'Sum', 'Avg'))

    ORDERS_F, ORDERS_B = bimap(
        ('asc', 'desc'),
        ('Asc', 'Desc'))

    LOGIC_OPERATORS_F, LOGIC_OPERATORS_B = bimap(
        ('and', 'or'),
        ('And', 'Or'))
|
|
|
@attr.s
class SpiderUnparser:
    """Renders Spider AST dicts back into SQL strings.

    Also repairs the FROM clause (``refine_from``): tables are inferred from
    the columns mentioned in the tree and connected via foreign-key paths in
    ``schema.foreign_key_graph``.
    """

    ast_wrapper = attr.ib()     # ASTWrapper for the active grammar
    schema = attr.ib()          # Spider schema (tables, columns, FK graph)
    factorize_sketch = attr.ib(default=0)  # grammar variant (0, 1, or 2)

    # AST node type name -> SQL operator token.
    UNIT_TYPES_B = {
        'Minus': '-',
        'Plus': '+',
        'Times': '*',
        'Divide': '/',
    }
    COND_TYPES_B = {
        'Between': 'BETWEEN',
        'Eq': '=',
        'Gt': '>',
        'Lt': '<',
        'Ge': '>=',
        'Le': '<=',
        'Ne': '!=',
        'In': 'IN',
        'Like': 'LIKE'
    }

    @classmethod
    def conjoin_conds(cls, conds):
        """Fold a list of condition nodes into a right-nested And tree."""
        if not conds:
            return None
        if len(conds) == 1:
            return conds[0]
        return {'_type': 'And', 'left': conds[0], 'right': cls.conjoin_conds(conds[1:])}

    @classmethod
    def linearize_cond(cls, cond):
        """Flatten a right-nested And/Or tree into ``(conds, keywords)``."""
        if cond['_type'] in ('And', 'Or'):
            conds, keywords = cls.linearize_cond(cond['right'])
            return [cond['left']] + conds, [cond['_type']] + keywords
        else:
            return [cond], []

    def unparse_val(self, val):
        """Render a value node (terminal, literal, column, or subquery)."""
        if val['_type'] == 'Terminal':
            return "'terminal'"
        if val['_type'] == 'String':
            return val['s']
        if val['_type'] == 'ColUnit':
            return self.unparse_col_unit(val['c'])
        if val['_type'] == 'Number':
            return str(val['f'])
        if val['_type'] == 'ValSql':
            return '({})'.format(self.unparse_sql(val['s']))
        # BUGFIX: previously fell through and returned None, which crashed
        # later inside ' '.join; fail loudly instead.
        raise ValueError('Unknown val type: {}'.format(val['_type']))

    def unparse_col_unit(self, col_unit):
        """Render a column reference with optional DISTINCT and aggregate."""
        if 'col_id' in col_unit:
            column = self.schema.columns[col_unit['col_id']]
            if column.table is None:
                # e.g. the special '*' column has no owning table.
                column_name = column.orig_name
            else:
                column_name = '{}.{}'.format(column.table.orig_name, column.orig_name)
        else:
            # Grammar was built with include_columns=False: no id to render.
            column_name = 'some_col'

        if col_unit['is_distinct']:
            column_name = 'DISTINCT {}'.format(column_name)
        agg_type = col_unit['agg_id']['_type']
        if agg_type == 'NoneAggOp':
            return column_name
        else:
            return '{}({})'.format(agg_type, column_name)

    def unparse_val_unit(self, val_unit):
        """Render a value unit: a bare column or a binary column expression."""
        if val_unit['_type'] == 'Column':
            return self.unparse_col_unit(val_unit['col_unit1'])
        col1 = self.unparse_col_unit(val_unit['col_unit1'])
        col2 = self.unparse_col_unit(val_unit['col_unit2'])
        return '{} {} {}'.format(col1, self.UNIT_TYPES_B[val_unit['_type']], col2)

    def unparse_cond(self, cond, negated=False):
        """Render a condition tree; `negated` is set by an enclosing Not."""
        if cond['_type'] == 'And':
            assert not negated
            return '{} AND {}'.format(
                self.unparse_cond(cond['left']), self.unparse_cond(cond['right']))
        elif cond['_type'] == 'Or':
            assert not negated
            return '{} OR {}'.format(
                self.unparse_cond(cond['left']), self.unparse_cond(cond['right']))
        elif cond['_type'] == 'Not':
            return self.unparse_cond(cond['c'], negated=True)
        elif cond['_type'] == 'Between':
            tokens = [self.unparse_val_unit(cond['val_unit'])]
            if negated:
                tokens.append('NOT')
            tokens += [
                'BETWEEN',
                self.unparse_val(cond['val1']),
                'AND',
                self.unparse_val(cond['val2']),
            ]
            return ' '.join(tokens)
        # All remaining binary comparison operators share one shape.
        tokens = [self.unparse_val_unit(cond['val_unit'])]
        if negated:
            tokens.append('NOT')
        tokens += [self.COND_TYPES_B[cond['_type']], self.unparse_val(cond['val1'])]
        return ' '.join(tokens)

    def refine_from(self, tree):
        """
        1) Infer tables from the columns predicted
        2) Mix them with the predicted tables, if any
        3) Infer join conditions based on the tables

        Mutates *tree* in place: callers rely on tree['from'] being written.
        """
        # BUGFIX: the previous version rebound `tree = dict(tree)`, so the
        # computed FROM clause was written to a throwaway shallow copy and
        # never reached the caller (unparse_sql then failed on tree['from']
        # whenever the grammar omitted FROM).  Mutate the argument directly.

        # Nested query in FROM (`FROM (SELECT ...)`): refine each subquery
        # recursively and leave this level's FROM alone.
        if "from" in tree and tree["from"]["table_units"] \
                and tree["from"]["table_units"][0]["_type"] == 'TableUnitSql':
            for table_unit in tree["from"]["table_units"]:
                subquery_tree = table_unit["s"]
                self.refine_from(subquery_tree)
            return

        # 1) De-duplicate predicted tables, keeping first-occurrence order.
        predicted_from_table_ids = set()
        if "from" in tree:
            table_unit_set = []
            for table_unit in tree["from"]["table_units"]:
                if table_unit["table_id"] not in predicted_from_table_ids:
                    predicted_from_table_ids.add(table_unit["table_id"])
                    table_unit_set.append(table_unit)
            tree["from"]["table_units"] = table_unit_set

        # 2) Tables implied by column mentions anywhere outside subqueries.
        candidate_column_ids = set(self.ast_wrapper.find_all_descendants_of_type(
            tree, 'column', lambda field: field.type != 'sql'))
        candidate_columns = [self.schema.columns[i] for i in candidate_column_ids]
        must_in_from_table_ids = set(
            column.table.id for column in candidate_columns if column.table is not None)

        all_from_table_ids = must_in_from_table_ids.union(predicted_from_table_ids)
        if not all_from_table_ids:
            # No table is implied anywhere; fall back to table 0.
            all_from_table_ids = {0}

        # 3) Connect tables via shortest foreign-key paths, collecting the
        # join conditions (and any intermediate tables) along the way.
        covered_tables = set()
        candidate_table_ids = sorted(all_from_table_ids)
        start_table_id = candidate_table_ids[0]
        conds = []
        for table_id in candidate_table_ids[1:]:
            if table_id in covered_tables:
                continue
            try:
                path = nx.shortest_path(
                    self.schema.foreign_key_graph, source=start_table_id, target=table_id)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                # Unreachable table: keep it in FROM without a join cond.
                covered_tables.add(table_id)
                continue

            for source_table_id, target_table_id in zip(path, path[1:]):
                if target_table_id in covered_tables:
                    continue
                all_from_table_ids.add(target_table_id)
                col1, col2 = self.schema.foreign_key_graph[source_table_id][target_table_id]['columns']
                conds.append({
                    '_type': 'Eq',
                    'val_unit': {
                        '_type': 'Column',
                        'col_unit1': {
                            '_type': 'col_unit',
                            'agg_id': {'_type': 'NoneAggOp'},
                            'col_id': col1,
                            'is_distinct': False
                        },
                    },
                    'val1': {
                        '_type': 'ColUnit',
                        'c': {
                            '_type': 'col_unit',
                            'agg_id': {'_type': 'NoneAggOp'},
                            'col_id': col2,
                            'is_distinct': False
                        }
                    }
                })
        table_units = [{'_type': 'Table', 'table_id': i} for i in sorted(all_from_table_ids)]

        tree['from'] = {
            '_type': 'from',
            'table_units': table_units,
        }
        cond_node = self.conjoin_conds(conds)
        if cond_node is not None:
            tree['from']['conds'] = cond_node

    def unparse_sql(self, tree):
        """Render a full 'sql' AST node as a SQL string."""
        # Repair/complete the FROM clause first (mutates `tree`).
        self.refine_from(tree)

        result = [
            # select select,
            self.unparse_select(tree['select']),
            # from from,
            self.unparse_from(tree['from']),
        ]

        def find_subtree(_tree, name):
            """Locate the subtree holding clause group `name` for the active
            sketch.  Returns (next_tree, target_tree); target_tree is {}
            when the group is absent, so the clause is simply skipped."""
            if self.factorize_sketch == 0:
                return _tree, _tree
            elif name in _tree:
                if self.factorize_sketch == 1:
                    return _tree[name], _tree[name]
                elif self.factorize_sketch == 2:
                    return _tree, _tree[name]
                else:
                    raise NotImplementedError
            # BUGFIX: previously fell through and implicitly returned None,
            # crashing on tuple unpacking when a clause group was missing.
            elif self.factorize_sketch == 2:
                return _tree, {}
            else:
                return {}, {}

        tree, target_tree = find_subtree(tree, "sql_where")
        # cond? where
        if 'where' in target_tree:
            result += [
                'WHERE',
                self.unparse_cond(target_tree['where'])
            ]

        tree, target_tree = find_subtree(tree, "sql_groupby")
        # col_unit* group_by
        if 'group_by' in target_tree:
            result += [
                'GROUP BY',
                ', '.join(self.unparse_col_unit(c) for c in target_tree['group_by'])
            ]

        tree, target_tree = find_subtree(tree, "sql_orderby")
        # order_by? order_by
        if 'order_by' in target_tree:
            result.append(self.unparse_order_by(target_tree['order_by']))

        tree, target_tree = find_subtree(tree, "sql_groupby")
        # cond? having
        # NOTE(review): HAVING is emitted after ORDER BY, following the ASDL
        # field order; standard SQL expects HAVING before ORDER BY — confirm
        # with downstream consumers before reordering.
        if 'having' in target_tree:
            result += ['HAVING', self.unparse_cond(target_tree['having'])]

        tree, target_tree = find_subtree(tree, "sql_orderby")
        # int? limit
        if 'limit' in target_tree:
            if isinstance(target_tree['limit'], bool):
                # include_literals=False: only the presence of LIMIT was kept.
                if target_tree['limit']:
                    result += ['LIMIT', '1']
            else:
                result += ['LIMIT', str(target_tree['limit'])]

        tree, target_tree = find_subtree(tree, "sql_ieu")
        # sql? intersect / except / union
        if 'intersect' in target_tree:
            result += ['INTERSECT', self.unparse_sql(target_tree['intersect'])]
        if 'except' in target_tree:
            result += ['EXCEPT', self.unparse_sql(target_tree['except'])]
        if 'union' in target_tree:
            result += ['UNION', self.unparse_sql(target_tree['union'])]

        return ' '.join(result)

    def unparse_select(self, select):
        """Render the SELECT clause."""
        tokens = ['SELECT']
        if select['is_distinct']:
            tokens.append('DISTINCT')
        tokens.append(', '.join(self.unparse_agg(agg) for agg in select.get('aggs', [])))
        return ' '.join(tokens)

    def unparse_agg(self, agg):
        """Render one SELECT item, wrapping in its aggregate if present."""
        unparsed_val_unit = self.unparse_val_unit(agg['val_unit'])
        agg_type = agg['agg_id']['_type']
        if agg_type == 'NoneAggOp':
            return unparsed_val_unit
        else:
            return '{}({})'.format(agg_type, unparsed_val_unit)

    def unparse_from(self, from_):
        """Render the FROM clause, attaching each join condition via ON as
        soon as every table it references has been output."""
        if 'conds' in from_:
            all_conds, keywords = self.linearize_cond(from_['conds'])
        else:
            all_conds, keywords = [], []
        # FROM conditions are joins; they must all be conjunctive.
        assert all(keyword == 'And' for keyword in keywords)

        # Index conditions by the tables whose columns they mention.
        cond_indices_by_table = collections.defaultdict(set)
        tables_involved_by_cond_idx = collections.defaultdict(set)
        for i, cond in enumerate(all_conds):
            for column in self.ast_wrapper.find_all_descendants_of_type(cond, 'column'):
                table = self.schema.columns[column].table
                if table is None:
                    continue
                cond_indices_by_table[table.id].add(i)
                tables_involved_by_cond_idx[i].add(table.id)

        output_table_ids = set()
        output_cond_indices = set()
        tokens = ['FROM']
        for i, table_unit in enumerate(from_.get('table_units', [])):
            if i > 0:
                tokens += ['JOIN']

            if table_unit['_type'] == 'TableUnitSql':
                tokens.append('({})'.format(self.unparse_sql(table_unit['s'])))
            elif table_unit['_type'] == 'Table':
                table_id = table_unit['table_id']
                tokens += [self.schema.tables[table_id].orig_name]
                output_table_ids.add(table_id)

                # Emit each pending condition once all of its tables are out.
                conds_to_output = []
                for cond_idx in sorted(cond_indices_by_table[table_id]):
                    if cond_idx in output_cond_indices:
                        continue
                    if tables_involved_by_cond_idx[cond_idx] <= output_table_ids:
                        conds_to_output.append(all_conds[cond_idx])
                        output_cond_indices.add(cond_idx)
                if conds_to_output:
                    tokens += ['ON']
                    tokens += list(intersperse(
                        'AND',
                        (self.unparse_cond(cond) for cond in conds_to_output)))
        return ' '.join(tokens)

    def unparse_order_by(self, order_by):
        """Render the ORDER BY clause."""
        return 'ORDER BY {} {}'.format(
            ', '.join(self.unparse_val_unit(v) for v in order_by['val_units']),
            order_by['order']['_type'])