import argparse
import collections
import json

import _jsonnet
import attr

from seq2struct.commands import preprocess, train, infer, eval

import crash_on_ipy  # drop into an IPython debugger if the script crashes
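# Typical invocation (the file name run.py is illustrative; use whatever name this
# script has in your checkout):
#
#   python run.py preprocess experiments/my-exp.jsonnet
#   python run.py train experiments/my-exp.jsonnet
#   python run.py eval experiments/my-exp.jsonnet
#
# The two positional arguments (mode and experiment jsonnet) are parsed in main();
# a sketch of the experiment config appears after the config classes below.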
|
|
|
# Plain argument containers handed to the seq2struct command entry points
# (preprocess.main, train.main, infer.main, eval.main).

@attr.s
class PreprocessConfig:
    config = attr.ib()
    config_args = attr.ib()


@attr.s
class TrainConfig:
    config = attr.ib()
    config_args = attr.ib()
    logdir = attr.ib()


@attr.s
class InferConfig:
    config = attr.ib()
    config_args = attr.ib()
    logdir = attr.ib()
    section = attr.ib()
    beam_size = attr.ib()
    output = attr.ib()
    step = attr.ib()
    use_heuristic = attr.ib(default=False)
    mode = attr.ib(default="infer")
    limit = attr.ib(default=None)
    output_history = attr.ib(default=False)


@attr.s
class EvalConfig:
    config = attr.ib()
    config_args = attr.ib()
    logdir = attr.ib()
    section = attr.ib()
    inferred = attr.ib()
    output = attr.ib()
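# Sketch of the experiment jsonnet this script reads. The key names come from the
# code in main(); the values here are purely illustrative:
#
#   {
#       model_config: "configs/my-model.jsonnet",  # evaluated with TLA "args"
#       model_config_args: { bs: 10 },             # optional
#       logdir: "logdir/my-run",
#       # Used only by the "eval" mode:
#       eval_name: "my-run",
#       eval_output: "eval-output/my-run",
#       eval_section: "val",
#       eval_beam_size: 1,
#       eval_use_heuristic: true,
#       eval_steps: [10000, 20000, 30000],
#   }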
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', help="preprocess/train/eval")
    parser.add_argument('exp_config_file', help="jsonnet file for experiments")
    args = parser.parse_args()

    # Load the experiment config, then the model config it points to. Any
    # "model_config_args" are serialized to JSON and forwarded to the model
    # config as the jsonnet top-level argument "args".
    exp_config = json.loads(_jsonnet.evaluate_file(args.exp_config_file))
    model_config_file = exp_config["model_config"]
    if "model_config_args" in exp_config:
        model_config_args = json.dumps(exp_config["model_config_args"])
    else:
        model_config_args = None
    other_config = json.loads(_jsonnet.evaluate_file(
        model_config_file, tla_codes={'args': model_config_args}))
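    # Note: since the model config is evaluated with a top-level argument, it is
    # typically written as a jsonnet function, roughly (illustrative, not a file
    # from this repository):
    #
    #   function(args={}) {
    #       data: { val: { paths: ["data/dev.json"] } },
    #       // ... model hyperparameters, etc.
    #   }
    #
    # Only other_config['data']['val']['paths'][0] is used below, to locate the
    # validation examples when writing gold.txt in eval mode.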
|
    if args.mode == "preprocess":
        preprocess_config = PreprocessConfig(model_config_file, model_config_args)
        preprocess.main(preprocess_config)
    elif args.mode == "train":
        train_config = TrainConfig(model_config_file, model_config_args, exp_config["logdir"])
        train.main(train_config)
    elif args.mode == "eval":
        # Start the aggregate results CSV with a header row, then write the gold
        # queries from the validation set as "<query>\t<db_id>" lines.
        with open(f"{exp_config['eval_output']}/eval-results.csv", "a", encoding='utf8') as result:
            result.write("checkpoint;type;easy;medium;hard;extra;all\n")
        first_loop = True

        print(f"Open file {other_config['data']['val']['paths'][0]}")
        with open(f"{other_config['data']['val']['paths'][0]}", encoding='utf8') as json_data_file:
            data = json.load(json_data_file)
        with open(f"{exp_config['eval_output']}/gold.txt", "w", encoding='utf8') as gold:
            for item in data:
                gold.write(f"{item['query']}\t{item['db_id']}\n")
        # For each requested checkpoint step, run beam-search inference and then
        # score the predictions with the evaluation command.
        for step in exp_config["eval_steps"]:
            infer_output_path = "{}/{}-step{}".format(
                exp_config["eval_output"], exp_config["eval_name"], step)
            infer_config = InferConfig(
                model_config_file,
                model_config_args,
                exp_config["logdir"],
                exp_config["eval_section"],
                exp_config["eval_beam_size"],
                infer_output_path,
                step,
                use_heuristic=exp_config["eval_use_heuristic"])
            infer.main(infer_config)

            eval_output_path = "{}/{}-step{}.eval".format(
                exp_config["eval_output"], exp_config["eval_name"], step)
            eval_config = EvalConfig(
                model_config_file,
                model_config_args,
                exp_config["logdir"],
                exp_config["eval_section"],
                f"{infer_output_path}.infer",
                eval_output_path)
            eval.main(eval_config)
            # Print the per-hardness exact-match scores and append them to the
            # aggregate CSV; the "count" row is written only for the first step.
            with open(eval_output_path) as eval_file:
                res_json = json.load(eval_file)
            print(step, res_json['total_scores']['all']['exact'])
            print(f"*;count;{res_json['total_scores']['easy']['count']};{res_json['total_scores']['medium']['count']};{res_json['total_scores']['hard']['count']};{res_json['total_scores']['extra']['count']};{res_json['total_scores']['all']['count']}")
            print("checkpoint;type;easy;medium;hard;extra;all")
            print(f"{step};exact match;{res_json['total_scores']['easy']['exact']:.3f};{res_json['total_scores']['medium']['exact']:.3f};{res_json['total_scores']['hard']['exact']:.3f};{res_json['total_scores']['extra']['exact']:.3f};{res_json['total_scores']['all']['exact']:.3f}")

            with open(f"{exp_config['eval_output']}/eval-results.csv", "a", encoding='utf8') as result:
                if first_loop:
                    result.write(f"*;count;{res_json['total_scores']['easy']['count']};{res_json['total_scores']['medium']['count']};{res_json['total_scores']['hard']['count']};{res_json['total_scores']['extra']['count']};{res_json['total_scores']['all']['count']}\n")
                    first_loop = False
                result.write(f"{step};exact match;{res_json['total_scores']['easy']['exact']:.3f};{res_json['total_scores']['medium']['exact']:.3f};{res_json['total_scores']['hard']['exact']:.3f};{res_json['total_scores']['extra']['exact']:.3f};{res_json['total_scores']['all']['exact']:.3f}\n")
            # Per-example CSV for this checkpoint: exact-match flag, hardness
            # bucket, gold SQL, and predicted SQL.
            with open(f"{exp_config['eval_output']}/{exp_config['eval_name']}-step{step}.csv", "w", encoding='utf8') as eval_clean:
                for per_item in res_json['per_item']:
                    if per_item['exact'] == 0 or per_item['exact'] == "false":
                        exact = "false"
                    elif per_item['exact'] == 1 or per_item['exact'] == "true":
                        exact = "true"
                    eval_clean.write(f"{exact};{per_item['hardness']};{per_item['gold']};{per_item['predicted']}\n")
|
|
|
if __name__ == "__main__":
    main()