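"""Add parsed SQL trees to a Spider-format dataset.

Reads a Spider-style JSON file of examples (``--input``) and a
``tables.json`` schema file (``--tables``), parses each example's gold
``query`` string into the structured ``sql`` representation produced by
``process_sql.get_sql``, and writes the augmented examples to ``--output``.

Example invocation (file and path names are illustrative):

    python add_parses.py \
        --input train_spider.json \
        --tables tables.json \
        --output train_spider_parsed.json
"""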
import os, sys
import json
import sqlite3
import traceback
import argparse
import tqdm
from seq2struct.datasets.spider_lib.process_sql import get_sql
from seq2struct.datasets.spider_lib.preprocess.schema import Schema, get_schemas_from_json

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True)
    parser.add_argument('--tables', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    sql_path = args.input
    output_file = args.output
    table_file = args.tables

    # Load every database schema defined in tables.json, keyed by db_id.
    schemas, db_names, tables = get_schemas_from_json(table_file)

    with open(sql_path, encoding='utf8') as inf:
        sql_data = json.load(inf)

    # Parse each example's gold SQL string into the structured Spider
    # `sql` representation and attach it to the example.
    sql_data_new = []
    for data in tqdm.tqdm(sql_data):
        try:
            db_id = data["db_id"]
            schema = schemas[db_id]
            table = tables[db_id]
            schema = Schema(schema, table)
            sql = data["query"]
            sql_label = get_sql(schema, sql)
            data["sql"] = sql_label
            sql_data_new.append(data)
        except:
            # Report the offending example before re-raising so failures
            # are easy to track down.
            print("db_id: ", db_id)
            print("sql: ", sql)
            raise

    # Dump the augmented dataset as pretty-printed JSON.
    with open(output_file, 'wt', encoding='utf8') as out:
        json.dump(sql_data_new, out, sort_keys=True, indent=4,
                  separators=(',', ': '), ensure_ascii=False)