import json

import attr
import networkx as nx
import torch

from seq2struct.datasets.spider_lib import evaluation
from seq2struct.utils import registry

@attr.s
class SpiderItem:
    # One example: a tokenized question, its SQL, and the schema it runs against.
    text = attr.ib()
    code = attr.ib()
    schema = attr.ib()
    orig = attr.ib()
    orig_schema = attr.ib()


@attr.s
class Column:
    id = attr.ib()
    table = attr.ib()
    name = attr.ib()  # tokenized name, e.g. ['singer', 'id']
    unsplit_name = attr.ib()
    orig_name = attr.ib()
    type = attr.ib()
    foreign_key_for = attr.ib(default=None)  # the Column this one references, if any


@attr.s
class Table:
    id = attr.ib()
    name = attr.ib()  # tokenized name
    unsplit_name = attr.ib()
    orig_name = attr.ib()
    columns = attr.ib(factory=list)
    primary_keys = attr.ib(factory=list)


@attr.s
class Schema:
    db_id = attr.ib()
    tables = attr.ib()
    columns = attr.ib()
    foreign_key_graph = attr.ib()
    orig = attr.ib()  # the raw schema dict this Schema was built from
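
# For reference, load_tables and _build_schema below expect each schema dict
# to use the Spider tables.json layout. A minimal sketch of that shape
# (illustrative values, not a real entry from the dataset):
#
#     {
#         'db_id': 'my_db',
#         'table_names': ['singer'],
#         'table_names_original': ['singer'],
#         'column_names': [[-1, '*'], [0, 'singer id'], [0, 'name']],
#         'column_names_original': [[-1, '*'], [0, 'Singer_ID'], [0, 'Name']],
#         'column_types': ['text', 'number', 'text'],
#         'primary_keys': [1],
#         'foreign_keys': [],
#     }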

def _build_schema(schema_dict):
    """Construct a Schema (tables, columns, foreign-key graph) from a single
    Spider-format schema dict. Shared by load_tables and
    load_tables_from_schema_dict."""
    tables = tuple(
        Table(
            id=i,
            name=name.split(),
            unsplit_name=name,
            orig_name=orig_name,
        )
        for i, (name, orig_name) in enumerate(zip(
            schema_dict['table_names'], schema_dict['table_names_original']))
    )
    columns = tuple(
        Column(
            id=i,
            # table_id == -1 marks the special '*' column, which belongs to no table.
            table=tables[table_id] if table_id >= 0 else None,
            name=col_name.split(),
            unsplit_name=col_name,
            orig_name=orig_col_name,
            type=col_type,
        )
        for i, ((table_id, col_name), (_, orig_col_name), col_type) in enumerate(zip(
            schema_dict['column_names'],
            schema_dict['column_names_original'],
            schema_dict['column_types']))
    )

    # Link columns to their tables.
    for column in columns:
        if column.table:
            column.table.columns.append(column)

    for column_id in schema_dict['primary_keys']:
        column = columns[column_id]
        column.table.primary_keys.append(column)

    # Record each foreign key as edges in both directions, keyed by the
    # participating column ids.
    foreign_key_graph = nx.DiGraph()
    for source_column_id, dest_column_id in schema_dict['foreign_keys']:
        source_column = columns[source_column_id]
        dest_column = columns[dest_column_id]
        source_column.foreign_key_for = dest_column
        foreign_key_graph.add_edge(
            source_column.table.id,
            dest_column.table.id,
            columns=(source_column_id, dest_column_id))
        foreign_key_graph.add_edge(
            dest_column.table.id,
            source_column.table.id,
            columns=(dest_column_id, source_column_id))

    return Schema(schema_dict['db_id'], tables, columns, foreign_key_graph, schema_dict)


def load_tables(paths):
    schemas = {}
    eval_foreign_key_maps = {}

    for path in paths:
        with open(path, encoding='utf8') as f:
            schema_dicts = json.load(f)
        for schema_dict in schema_dicts:
            db_id = schema_dict['db_id']
            assert db_id not in schemas, 'Duplicate db_id: {}'.format(db_id)
            schemas[db_id] = _build_schema(schema_dict)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(schema_dict)

    return schemas, eval_foreign_key_maps
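
# Example usage (a minimal sketch; the path and db_id are hypothetical
# placeholders for real Spider files):
#
#     schemas, fk_maps = load_tables(['data/spider/tables.json'])
#     schema = schemas['my_db']
#     print([t.unsplit_name for t in schema.tables])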

def load_tables_from_schema_dict(schema_dict):
    # Same as load_tables, but for a single schema dict that is already in
    # memory rather than stored in JSON files.
    db_id = schema_dict['db_id']
    schemas = {db_id: _build_schema(schema_dict)}
    eval_foreign_key_maps = {db_id: evaluation.build_foreign_key_map(schema_dict)}
    return schemas, eval_foreign_key_maps
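
# Example usage (sketch; 'my_schema_dict' stands for any dict with the shape
# sketched above):
#
#     schemas, fk_maps = load_tables_from_schema_dict(my_schema_dict)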

@registry.register('dataset_infer', 'spider')
class SpiderDatasetInfer(torch.utils.data.Dataset):
    # Inference-time variant: schemas and foreign-key maps are passed in
    # directly, and the example list starts empty (nothing here fills it).
    def __init__(self, schemas, eval_foreign_key_maps, db_path):
        self.db_path = db_path
        self.examples = []
        self.schemas, self.eval_foreign_key_maps = schemas, eval_foreign_key_maps

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

@registry.register('dataset', 'spider')
class SpiderDataset(torch.utils.data.Dataset):
    def __init__(self, paths, tables_paths, db_path, limit=None):
        # Note: 'limit' is accepted but not currently applied.
        self.paths = paths
        self.db_path = db_path
        self.examples = []

        self.schemas, self.eval_foreign_key_maps = load_tables(tables_paths)

        for path in paths:
            with open(path, encoding='utf8') as f:
                raw_data = json.load(f)
            for entry in raw_data:
                item = SpiderItem(
                    text=entry['question_toks'],
                    code=entry['sql'],
                    schema=self.schemas[entry['db_id']],
                    orig=entry,
                    orig_schema=self.schemas[entry['db_id']].orig)
                self.examples.append(item)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    class Metrics:
        def __init__(self, dataset):
            self.dataset = dataset
            self.foreign_key_maps = {
                db_id: evaluation.build_foreign_key_map(schema.orig)
                for db_id, schema in self.dataset.schemas.items()
            }
            self.evaluator = evaluation.Evaluator(
                self.dataset.db_path,
                self.foreign_key_maps,
                'match')
            self.results = []

        def add(self, item, inferred_code, orig_question=None):
            ret_dict = self.evaluator.evaluate_one(
                item.schema.db_id, item.orig['query'], inferred_code)
            if orig_question:
                ret_dict['orig_question'] = orig_question
            self.results.append(ret_dict)

        def add_beams(self, item, inferred_codes, orig_question=None):
            # Evaluate beam candidates in order, stopping at the first exact match.
            beam_dict = {}
            if orig_question:
                beam_dict['orig_question'] = orig_question
            for i, code in enumerate(inferred_codes):
                ret_dict = self.evaluator.evaluate_one(
                    item.schema.db_id, item.orig['query'], code)
                beam_dict[i] = ret_dict
                if ret_dict['exact'] is True:
                    break
            self.results.append(beam_dict)

        def finalize(self):
            self.evaluator.finalize()
            return {
                'per_item': self.results,
                'total_scores': self.evaluator.scores,
            }
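
# Minimal usage sketch (the paths below are hypothetical placeholders for
# Spider data files):
#
#     dataset = SpiderDataset(
#         paths=['data/spider/train_spider.json'],
#         tables_paths=['data/spider/tables.json'],
#         db_path='data/spider/database')
#     metrics = SpiderDataset.Metrics(dataset)
#     metrics.add(dataset[0], inferred_code='SELECT count(*) FROM singer')
#     report = metrics.finalize()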

@registry.register('dataset', 'spider_idiom_ast')
class SpiderIdiomAstDataset(torch.utils.data.Dataset):
    def __init__(self, paths, tables_paths, db_path, limit=None):
        self.paths = paths
        self.db_path = db_path
        self.examples = []

        self.schemas, self.eval_foreign_key_maps = load_tables(tables_paths)

        # Each line of these files is a JSON object holding the original
        # Spider example under 'orig' plus a rewritten ('idiom') AST.
        for path in paths:
            with open(path, encoding='utf8') as f:
                for line in f:
                    entry = json.loads(line)
                    item = SpiderItem(
                        text=entry['orig']['question_toks'],
                        code=entry['rewritten_ast'],
                        schema=self.schemas[entry['orig']['db_id']],
                        orig=entry['orig'],
                        orig_schema=self.schemas[entry['orig']['db_id']].orig)
                    self.examples.append(item)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    # Reuse the evaluation metrics defined on SpiderDataset.
    Metrics = SpiderDataset.Metrics