gap-text2sql/gap-text2sql-main/mrat-sql-gap/seq2struct/datasets/spider_lib/preprocess/get_tables-ORI.py
import os | |
import sys | |
import json | |
import sqlite3 | |
from os import listdir, makedirs | |
from os.path import isfile, isdir, join, split, exists, splitext | |
from nltk import word_tokenize, tokenize | |
import traceback | |
# Dataset names that already have hand-curated schemas.
# NOTE(review): not referenced anywhere in this file — presumably consumed by
# other preprocessing scripts in this package; confirm before removing.
EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"}
def convert_fk_index(data):
    """Convert name-based foreign keys into column-index pairs.

    Each entry of ``data["foreign_keys"]`` is expected to be
    ``[(table_name, column_name), (ref_table_name, ref_column_name)]``;
    it is mapped to ``[column_index, ref_column_index]`` where the indices
    are positions in ``data["column_names_original"]``.

    Exits the process (after printing diagnostics) if a table name cannot
    be found in ``data["table_names_original"]``.
    """
    fk_holder = []
    for fk in data["foreign_keys"]:
        tn, col = fk[0]
        ref_tn, ref_col = fk[1]
        ref_cid, cid = None, None
        try:
            tid = data['table_names_original'].index(tn)
            ref_tid = data['table_names_original'].index(ref_tn)
            for i, (tab_id, col_org) in enumerate(data['column_names_original']):
                if tab_id == ref_tid and ref_col == col_org:
                    ref_cid = i
                elif tid == tab_id and col == col_org:
                    cid = i
            # Explicit None checks: the original truthiness test would
            # silently drop a match at index 0.
            if ref_cid is not None and cid is not None:
                fk_holder.append([cid, ref_cid])
        except ValueError:
            # .index() raises ValueError when a table name is missing;
            # narrowed from the original bare `except:`.
            traceback.print_exc()
            print("table_names_original: ", data['table_names_original'])
            print("finding tab name: ", tn, ref_tn)
            sys.exit()
    return fk_holder
def dump_db_json_schema(db, f):
    """Read table and column info from the sqlite file *db*.

    Args:
        db: path to a ``.sqlite`` database file.
        f: database id used as ``db_id`` in the result.

    Returns:
        A Spider-style schema dict with keys ``db_id``,
        ``table_names_original`` / ``table_names``,
        ``column_names_original`` / ``column_names`` (both prefixed with the
        special ``(-1, '*')`` column), ``column_types``, ``primary_keys``
        (indices into ``column_names``) and ``foreign_keys``
        (``[col_idx, ref_col_idx]`` pairs).
    """
    conn = sqlite3.connect(db)
    conn.execute('pragma foreign_keys=ON')
    data = {'db_id': f,
            'table_names_original': [],
            'table_names': [],
            'column_names_original': [(-1, '*')],
            'column_names': [(-1, '*')],
            'column_types': ['text'],
            'primary_keys': [],
            'foreign_keys': []}
    fk_holder = []
    try:
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
        for i, item in enumerate(cursor.fetchall()):
            table_name = item[0]
            data['table_names_original'].append(table_name)
            # Human-readable name: lowercase, underscores become spaces.
            data['table_names'].append(table_name.lower().replace("_", ' '))
            fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall()
            # fk row layout: fk[2] = referenced table, fk[3] = local column,
            # fk[4] = referenced column.
            fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
            cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
            for j, col in enumerate(cur.fetchall()):
                data['column_names_original'].append((i, col[1]))
                data['column_names'].append((i, col[1].lower().replace("_", " ")))
                # Map raw sqlite declared types onto Spider's coarse buckets:
                # varchar, '' -> text; int, numeric -> number; etc.
                col_type = col[2].lower()
                if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type:
                    data['column_types'].append('text')
                elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type\
                        or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type:
                    # NOTE: the 'id' substring test is deliberately loose and
                    # also catches types like 'guid'; preserved from original.
                    data['column_types'].append('number')
                elif 'date' in col_type or 'time' in col_type or 'year' in col_type:
                    data['column_types'].append('time')
                elif 'boolean' in col_type:
                    data['column_types'].append('boolean')
                else:
                    data['column_types'].append('others')
                # col[5] is the pk flag from PRAGMA table_info; the index just
                # appended to column_names is this column's id.
                if col[5] == 1:
                    data['primary_keys'].append(len(data['column_names']) - 1)
    finally:
        # Fix: the original leaked the sqlite connection.
        conn.close()
    data["foreign_keys"] = fk_holder
    data['foreign_keys'] = convert_fk_index(data)
    return data
if __name__ == '__main__':
    # Fix: sys.argv[3] is read below, so four entries are required;
    # the original only checked `< 2` and crashed later with IndexError.
    if len(sys.argv) < 4:
        print("Usage: python get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]")
        sys.exit()
    input_dir = sys.argv[1]
    output_file = sys.argv[2]
    ex_tab_file = sys.argv[3]
    # Subdirectories that contain a <name>/<name>.sqlite file.
    all_fs = [df for df in listdir(input_dir) if exists(join(input_dir, df, df + '.sqlite'))]
    with open(ex_tab_file) as f:
        ex_tabs = json.load(f)
    # Keep only previously-processed schemas whose db still exists on disk.
    ex_tabs = {tab["db_id"]: tab for tab in ex_tabs if tab["db_id"] in all_fs}
    print("processed file num: ", len(ex_tabs))
    not_fs = [df for df in listdir(input_dir) if not exists(join(input_dir, df, df + '.sqlite'))]
    for d in not_fs:
        print("no sqlite file found in: ", d)
    db_files = [(df + '.sqlite', df) for df in all_fs]
    tables = []
    for f, df in db_files:
        db = join(input_dir, df, f)
        print('\nreading new db: ', df)
        table = dump_db_json_schema(db, df)
        # Fix: the original indexed ex_tabs[df] unconditionally before the
        # membership check, raising KeyError for any db not present in the
        # inherited tables.json. Only inherit the curated names when the
        # table/column counts match the freshly-read schema.
        prev = ex_tabs.get(df)
        if (prev is not None
                and len(prev["table_names"]) == len(table["table_names"])
                and len(prev["column_names"]) == len(table["column_names"])
                and len(prev["table_names"]) != 0
                and len(prev["column_names"]) > 1):
            table["table_names"] = prev["table_names"]
            table["column_names"] = prev["column_names"]
        else:
            print("\n----------------------------------problem db: ", df)
        tables.append(table)
    print("final db num: ", len(tables))
    with open(output_file, 'wt') as out:
        json.dump(tables, out, sort_keys=True, indent=2, separators=(',', ': '))