#!/usr/bin/python3
"""Parse word/tag sequences with a pretrained delexicalized chart parser.

The parser (a ``parse_chart.ChartParser`` checkpoint) predicts trees for the
input; the original words are then written back into the leaves of each
predicted tree and every preterminal is relabelled with the input tag.

Because this module uses package-relative imports, run it as a module
(``python -m <package>.<module> ...``) rather than as a bare script.
"""
import argparse
import os
import pickle
import sys

import torch
from huggingface_hub import hf_hub_download

from .treebanks import load_single_text
from .benepar import parse_chart

# Hugging Face Hub repo and checkpoint used when the caller does not override them.
DEFAULT_MODEL_REPO = 'nielklug/mhg_parser'
DEFAULT_MODEL_FILENAME = 'german-delex-parser_dev=83.10.pt'


def _resolve_model_file(model_path, model_filename):
    """Return a local filesystem path to the parser checkpoint.

    If ``model_path`` names an existing file it is used directly; otherwise it
    is treated as a Hugging Face Hub repo id and ``model_filename`` is fetched
    from that repo via ``hf_hub_download``.
    """
    if os.path.isfile(model_path):
        return model_path
    return hf_hub_download(repo_id=model_path, filename=model_filename)


def run_parse(words, tags, model_path=DEFAULT_MODEL_REPO, subbatch_max_tokens=500,
              model_filename=DEFAULT_MODEL_FILENAME):
    """Parse the given words/tags and return one bracketed tree string per example.

    Args:
        words: Words of the text, passed unchanged to ``load_single_text``.
        tags: Tags aligned with ``words``, passed unchanged to ``load_single_text``.
        model_path: Hugging Face Hub repo id holding the checkpoint, or a path
            to a local checkpoint file.
        subbatch_max_tokens: Forwarded to ``ChartParser.parse``; caps the number
            of tokens processed per sub-batch.
        model_filename: Checkpoint filename inside the Hub repo (ignored when
            ``model_path`` is a local file).

    Returns:
        A list of strings, one per parsed example, each rendered on one line.
    """
    test_treebank = load_single_text(words, tags)

    parser = parse_chart.ChartParser.from_trained(
        _resolve_model_file(model_path, model_filename)
    )
    if torch.cuda.is_available():
        parser.cuda()

    test_predicted = parser.parse(
        test_treebank.without_gold_annotations(),
        subbatch_max_tokens=subbatch_max_tokens,
    )

    # The parses are delexicalized: put the original words back into the
    # leaves and set each leaf's parent (preterminal) label to the input tag.
    for example, prediction in zip(test_treebank, test_predicted):
        leaf_positions = prediction.treepositions('leaves')
        for word_tag_pair, leaf_pos in zip(example.word_tag_pairs, leaf_positions):
            word, tag = word_tag_pair[0], word_tag_pair[1]
            prediction[leaf_pos] = word
            prediction[leaf_pos[:-1]].set_label(tag)

    # An effectively infinite margin makes pformat emit each tree on one line.
    return [tree.pformat(margin=1e100) for tree in test_predicted]


def _read_word_tag_file(path):
    """Read a UTF-8, one-token-per-line file into parallel word and tag lists.

    Each non-blank line must contain exactly two whitespace-separated fields,
    ``WORD TAG``. Blank lines are skipped.

    NOTE(review): ``load_single_text``'s expected input is not visible here;
    this assumes it takes flat, parallel lists of words and tags — confirm.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    words, tags = [], []
    with open(path, encoding='utf-8') as infile:
        for lineno, line in enumerate(infile, start=1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected 'WORD TAG', got {line.rstrip()!r}"
                )
            words.append(fields[0])
            tags.append(fields[1])
    return words, tags


def main():
    """Command-line entry point: parse ``--text_path`` and emit one tree per line."""
    argparser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    argparser.add_argument(
        "--model_path", default=DEFAULT_MODEL_REPO, type=str,
        help='Hugging Face Hub repo id, or path to a local trained parser checkpoint',
    )
    argparser.add_argument(
        "--model_filename", default=DEFAULT_MODEL_FILENAME, type=str,
        help='checkpoint filename inside the Hub repo (ignored for local paths)',
    )
    argparser.add_argument(
        "--text_path", required=True, type=str,
        help="input file: one 'WORD TAG' pair per line",
    )
    # type=int: the value is forwarded to the parser as a token count.
    argparser.add_argument('--subbatch_max_tokens', default=500, type=int)
    argparser.add_argument(
        '--output_path', default="", type=str,
        help='file to write trees to; stdout if empty',
    )
    args = argparser.parse_args()

    words, tags = _read_word_tag_file(args.text_path)
    trees = run_parse(
        words,
        tags,
        model_path=args.model_path,
        subbatch_max_tokens=args.subbatch_max_tokens,
        model_filename=args.model_filename,
    )

    lines = [f"{tree}\n" for tree in trees]
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as outfile:
            outfile.writelines(lines)
    else:
        sys.stdout.writelines(lines)


if __name__ == "__main__":
    main()