nielklug's picture
init
6ed21b9
raw
history blame
2.14 kB
#!/usr/bin/python3
import argparse
import functools
import os
import pickle

import torch
from huggingface_hub import hf_hub_download

from .treebanks import load_single_text
from .benepar import parse_chart
@functools.lru_cache(maxsize=1)
def _load_parser(model_path, model_filename):
    """Resolve *model_path* to a checkpoint and load a ChartParser from it.

    Cached (most recent model only) so repeated ``run_parse`` calls with the
    same model reuse one loaded instance instead of re-resolving and
    re-reading the checkpoint each time. Failed loads are not cached.
    NOTE(review): assumes ``ChartParser.parse`` keeps no per-call state on
    the parser object, so sharing one instance across calls is safe.
    """
    if os.path.isfile(model_path):
        # A checkpoint file on local disk.
        model_file = model_path
    else:
        # Otherwise treat model_path as a Hugging Face Hub repo id.
        model_file = hf_hub_download(repo_id=model_path, filename=model_filename)
    parser = parse_chart.ChartParser.from_trained(model_file)
    if torch.cuda.is_available():
        parser.cuda()
    return parser


def run_parse(words, tags, model_path='nielklug/mhg_parser', subbatch_max_tokens=500,
              model_filename='german-delex-parser_dev=83.10.pt'):
    """Parse tagged input with the delexicalized chart parser.

    After parsing, the leaves of each predicted tree are overwritten with the
    original input words. Each leaf's parent node is relabelled with the
    corresponding input tag.

    Args:
        words: Tokens, passed unchanged to ``load_single_text`` together with
            ``tags`` (see that function for the expected shape).
        tags: POS tags aligned with ``words``.
        model_path: Either a path to a local checkpoint file, or a Hugging
            Face Hub repo id from which ``model_filename`` is downloaded.
        subbatch_max_tokens: Forwarded to ``parser.parse``.
        model_filename: Checkpoint file name inside the Hub repo; ignored
            when ``model_path`` is a local file.

    Returns:
        A list with one single-line bracketed tree string per example built
        by ``load_single_text``.
    """
    test_treebank = load_single_text(words, tags)
    parser = _load_parser(model_path, model_filename)
    test_predicted = parser.parse(
        test_treebank.without_gold_annotations(),
        subbatch_max_tokens=subbatch_max_tokens,
    )
    # The parses are delexicalized: insert the original tokens into the leaves
    # and put the input tag back on each leaf's parent node.
    for example, prediction in zip(test_treebank, test_predicted):
        leaf_positions = prediction.treepositions('leaves')
        for word_tag_pair, leaf_pos in zip(example.word_tag_pairs, leaf_positions):
            prediction[leaf_pos] = word_tag_pair[0]
            prediction[leaf_pos[:-1]].set_label(word_tag_pair[1])
    # A huge margin makes pformat keep each tree on one line.
    return [tree.pformat(margin=1e100) for tree in test_predicted]
def _read_tagged_sentences(path):
    """Yield a ``(words, tags)`` pair of lists for each sentence in *path*.

    NOTE(review): this input format is an assumption, not taken from the
    original code. It expects one ``word<TAB>tag`` token per line, with
    sentences separated by blank lines. Confirm against the data the parser
    is used with.

    Raises:
        ValueError: if a non-blank line has no tab separator.
    """
    words, tags = [], []
    with open(path, encoding="utf-8") as infile:
        for line_no, line in enumerate(infile, 1):
            line = line.rstrip("\r\n")
            if not line.strip():
                # Blank line ends the current sentence (if any).
                if words:
                    yield words, tags
                    words, tags = [], []
                continue
            word, sep, tag = line.partition("\t")
            if not sep:
                raise ValueError(
                    f"{path}:{line_no}: expected 'word<TAB>tag', got {line!r}"
                )
            words.append(word)
            tags.append(tag)
    # The last sentence may not be followed by a blank line.
    if words:
        yield words, tags


def main():
    """Command-line entry point: parse a tagged text file, one tree per line.

    Trees are written to ``--output_path`` if given, otherwise to stdout.
    Because this module uses relative imports, run it as a module
    (``python -m <package>.<module>``) rather than as a plain script.
    """
    parser = argparse.ArgumentParser()
    # The default must be something run_parse can resolve; the previous
    # default was a machine-specific absolute path.
    parser.add_argument("--model_path", default="nielklug/mhg_parser", type=str,
                        help='model location passed to run_parse '
                             '(Hugging Face Hub repo id by default)')
    parser.add_argument("--text_path", required=True, type=str)
    parser.add_argument('--subbatch_max_tokens', default=500, type=int)
    parser.add_argument('--output_path', default="", type=str)
    args = parser.parse_args()

    trees = []
    # run_parse takes one words/tags pair, so it is called once per sentence.
    for words, tags in _read_tagged_sentences(args.text_path):
        trees.extend(run_parse(
            words,
            tags,
            model_path=args.model_path,
            subbatch_max_tokens=args.subbatch_max_tokens,
        ))

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as outfile:
            outfile.writelines(f"{tree}\n" for tree in trees)
    else:
        for tree in trees:
            print(tree)
if __name__ == "__main__":
main()