import argparse
import itertools
import json
import os
import sys

import _jsonnet
import torch
import tqdm

from seq2struct import beam_search
# Imported for side effects: these modules register their components with
# seq2struct.utils.registry.
from seq2struct import datasets
from seq2struct import models
from seq2struct import optimizers
from seq2struct.utils import registry
from seq2struct.utils import saver as saver_mod

from seq2struct.models.spider import spider_beam_search


class Inferer:
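    '''Loads a trained seq2struct model and runs inference over a dataset section.'''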
    def __init__(self, config):
        self.config = config
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            torch.set_num_threads(1)

        self.model_preproc = registry.instantiate(
            registry.lookup('model', config['model']).Preproc,
            config['model'])
        self.model_preproc.load()

    def load_model(self, logdir, step):
        '''Load a model (identified by the config used for construction) and return it.'''
        model = registry.construct(
            'model', self.config['model'], preproc=self.model_preproc, device=self.device)
        model.to(self.device)
        model.eval()
        model.visualize_flag = False

        saver = saver_mod.Saver({"model": model})
        last_step = saver.restore(logdir, step=step, map_location=self.device, item_keys=["model"])

        if not last_step:
            raise Exception('Attempting to infer on untrained model')
        return model

    def infer(self, model, output_path, args):
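        '''Write results for args.mode to output_path.

        infer: run beam search over the chosen dataset section.
        debug: dump the per-item loss computation history.
        visualize_attention: compare the three result files given by --res1/2/3.
        '''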
        with open(output_path, 'w') as output, torch.no_grad():
            if args.mode == 'infer':
                orig_data = registry.construct('dataset', self.config['data'][args.section])
                preproc_data = self.model_preproc.dataset(args.section)
                if args.limit:
                    sliced_orig_data = itertools.islice(orig_data, args.limit)
                    sliced_preproc_data = itertools.islice(preproc_data, args.limit)
                else:
                    sliced_orig_data = orig_data
                    sliced_preproc_data = preproc_data
                assert len(orig_data) == len(preproc_data)
                self._inner_infer(
                    model, args.beam_size, args.output_history,
                    sliced_orig_data, sliced_preproc_data, output, args.use_heuristic)
            elif args.mode == 'debug':
                data = self.model_preproc.dataset(args.section)
                sliced_data = itertools.islice(data, args.limit) if args.limit else data
                self._debug(model, sliced_data, output)
            elif args.mode == 'visualize_attention':
                model.visualize_flag = True
                model.decoder.visualize_flag = True
                data = registry.construct('dataset', self.config['data'][args.section])
                sliced_data = itertools.islice(data, args.limit) if args.limit else data
                self._visualize_attention(
                    model, args.beam_size, args.output_history, sliced_data,
                    args.res1, args.res2, args.res3, output)

    def _infer_one(self, model, data_item, preproc_item, beam_size, output_history=False, use_heuristic=True):
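        '''Run beam search on a single example and return its decoded beams.'''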
        if use_heuristic:
            # Spider-specific beam search with additional pruning heuristics.
            beams = spider_beam_search.beam_search_with_heuristics(
                model, data_item, preproc_item, beam_size=beam_size, max_steps=1000, from_cond=False)
        else:
            beams = beam_search.beam_search(
                model, data_item, preproc_item, beam_size=beam_size, max_steps=1000)
        decoded = []
        for beam in beams:
            model_output, inferred_code = beam.inference_state.finalize()
            decoded.append({
                'orig_question': data_item.orig["question"],
                'model_output': model_output,
                'inferred_code': inferred_code,
                'score': beam.score,
                **({
                    'choice_history': beam.choice_history,
                    'score_history': beam.score_history,
                } if output_history else {})})
        return decoded

    def _inner_infer(self, model, beam_size, output_history, sliced_orig_data, sliced_preproc_data, output, use_heuristic=False):
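        '''Decode each (orig, preproc) pair and write one JSON line per example,
        of the form {"index": ..., "beams": [...]}.
        '''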
        # itertools.islice objects have no len(), so only pass tqdm a total
        # when one is available.
        try:
            total = len(sliced_orig_data)
        except TypeError:
            total = None
        for i, (orig_item, preproc_item) in enumerate(
                tqdm.tqdm(zip(sliced_orig_data, sliced_preproc_data), total=total)):
            decoded = self._infer_one(
                model, orig_item, preproc_item, beam_size, output_history, use_heuristic)
            output.write(
                json.dumps({
                    'index': i,
                    'beams': decoded,
                }) + '\n')
            output.flush()

    def _debug(self, model, sliced_data, output):
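        '''Recompute the loss with debug=True and dump each item's history.'''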
        for i, item in enumerate(tqdm.tqdm(sliced_data)):
            (_, history), = model.compute_loss([item], debug=True)
            output.write(
                json.dumps({
                    'index': i,
                    'history': history,
                }) + '\n')
            output.flush()

    def _visualize_attention(self, model, beam_size, output_history, sliced_data, res1file, res2file, res3file, output):
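        '''Report the fraction of 'extra'-hardness items that at least one of the
        three result files solved exactly; the attention-dumping code below is
        currently disabled.
        '''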
        with open(res1file) as f:
            res1 = json.load(f)['per_item']
        with open(res2file) as f:
            res2 = json.load(f)['per_item']
        with open(res3file) as f:
            res3 = json.load(f)['per_item']
        interest_cnt = 0
        cnt = 0
        for i, item in enumerate(tqdm.tqdm(sliced_data)):
            if res1[i]['hardness'] != 'extra':
                continue
            cnt += 1
            # Skip items that none of the three runs solved exactly.
            if (res1[i]['exact'] == 0) and (res2[i]['exact'] == 0) and (res3[i]['exact'] == 0):
                continue
            interest_cnt += 1
            '''
            print('sample index: ')
            print(i)
            beams = beam_search.beam_search(
                model, item, beam_size=beam_size, max_steps=1000, visualize_flag=True)
            entry = item.orig
            print('ground truth SQL:')
            print(entry['query_toks'])
            print('prediction:')
            print(res2[i])
            decoded = []
            for beam in beams:
                model_output, inferred_code = beam.inference_state.finalize()

                decoded.append({
                    'model_output': model_output,
                    'inferred_code': inferred_code,
                    'score': beam.score,
                    **({
                        'choice_history': beam.choice_history,
                        'score_history': beam.score_history,
                    } if output_history else {})})

            output.write(
                json.dumps({
                    'index': i,
                    'beams': decoded,
                }) + '\n')
            output.flush()
            '''
        print(interest_cnt * 1.0 / cnt)


def add_parser():
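    '''Define the command-line interface and return the parsed arguments.'''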
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', required=True)
    parser.add_argument('--config', required=True)
    parser.add_argument('--config-args')

    parser.add_argument('--step', type=int)
    parser.add_argument('--section', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--beam-size', required=True, type=int)
    parser.add_argument('--output-history', action='store_true')
    parser.add_argument('--limit', type=int)
    parser.add_argument('--mode', default='infer', choices=['infer', 'debug', 'visualize_attention'])
    parser.add_argument('--use_heuristic', action='store_true')
    parser.add_argument('--res1', default='outputs/glove-sup-att-1h-0/outputs.json')
    parser.add_argument('--res2', default='outputs/glove-sup-att-1h-1/outputs.json')
    parser.add_argument('--res3', default='outputs/glove-sup-att-1h-2/outputs.json')
    args = parser.parse_args()
    return args


def main(args):
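    '''Evaluate the config, resolve the log directory, and run inference.'''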
    if args.config_args:
        config = json.loads(_jsonnet.evaluate_file(args.config, tla_codes={'args': args.config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(args.config))

    if 'model_name' in config:
        args.logdir = os.path.join(args.logdir, config['model_name'])

    # Let the output path refer to the resolved log directory.
    output_path = args.output.replace('__LOGDIR__', args.logdir)
    if os.path.exists(output_path):
        print('Output file {} already exists'.format(output_path))
        sys.exit(1)

    inferer = Inferer(config)
    model = inferer.load_model(args.logdir, args.step)
    inferer.infer(model, output_path, args)


if __name__ == '__main__':
    args = add_parser()
    main(args)
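
# Example invocation (illustrative paths; '__LOGDIR__' in --output is replaced
# by the resolved log directory):
#   python infer.py --logdir logdir/my-run --config configs/spider.jsonnet \
#       --section val --output __LOGDIR__/val-infer.jsonl --beam-size 1 --use_heuristic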