Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
import argparse | |
import pickle | |
import torch | |
from .treebanks import load_single_text | |
from .benepar import parse_chart | |
from huggingface_hub import hf_hub_download | |
def run_parse(words, tags, model_path='nielklug/mhg_parser', subbatch_max_tokens=500,
              model_filename='german-delex-parser_dev=83.10.pt'):
    """Parse pre-tagged input with the delexicalized chart parser.

    Args:
        words: Input tokens; forwarded unchanged to ``load_single_text``.
        tags: Tags belonging to ``words``; forwarded unchanged to
            ``load_single_text``.
        model_path: Hugging Face Hub repo id that hosts the checkpoint.
        subbatch_max_tokens: Forwarded to ``parser.parse``.
        model_filename: Name of the checkpoint file inside ``model_path``.
            Exposed as a parameter so that a different ``model_path`` can be
            paired with its own checkpoint name.

    Returns:
        A list with one bracketed-tree string per predicted parse.
    """
    test_treebank = load_single_text(words, tags)

    # Resolve the checkpoint on the Hub to a local file and load the parser.
    model_file = hf_hub_download(repo_id=model_path, filename=model_filename)
    parser = parse_chart.ChartParser.from_trained(model_file)
    if torch.cuda.is_available():
        parser.cuda()

    test_predicted = parser.parse(
        test_treebank.without_gold_annotations(),
        subbatch_max_tokens=subbatch_max_tokens,
    )

    # Insert the original tokens (and their tags) into the delexicalized parses.
    for example, prediction in zip(test_treebank, test_predicted):
        leaf_positions = prediction.treepositions('leaves')
        for word_tag_pair, leaf_pos in zip(example.word_tag_pairs, leaf_positions):
            # The leaf itself becomes the word ...
            prediction[leaf_pos] = word_tag_pair[0]
            # ... and its parent (the preterminal node) is relabelled with the tag.
            prediction[leaf_pos[:-1]].set_label(word_tag_pair[1])

    # The huge margin is there so pformat never wraps a tree across lines
    # (presumably these are nltk.Tree objects -- confirm).
    return [tree.pformat(margin=1e100) for tree in test_predicted]
def main():
    """Command-line entry point: read tagged tokens, parse them, emit the trees.

    Reads ``--text_path``, calls ``run_parse`` and writes one tree per line to
    ``--output_path`` (or to stdout when no output path is given).

    Raises:
        ValueError: If a non-blank input line does not have exactly two
            whitespace-separated fields.
    """
    parser = argparse.ArgumentParser()
    # run_parse hands this value to hf_hub_download as a repo id, so the default
    # mirrors run_parse's own default instead of a local filesystem path.
    parser.add_argument("--model_path", default="nielklug/mhg_parser", type=str,
                        help='Hugging Face Hub repo id of the trained parser')
    parser.add_argument("--text_path", required=True, type=str)
    # Token budgets are numeric; run_parse's own default for this is the int 500.
    parser.add_argument('--subbatch_max_tokens', default=500, type=int)
    parser.add_argument('--output_path', default="", type=str)
    args = parser.parse_args()

    # NOTE(review): the on-disk input format is not defined anywhere in this
    # file. One whitespace-separated "word tag" pair per line is assumed, and
    # the pairs are passed on as two flat lists -- confirm against
    # load_single_text before relying on this.
    words, tags = [], []
    with open(args.text_path, encoding="utf-8") as infile:
        for line_number, line in enumerate(infile, start=1):
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            if len(fields) != 2:
                raise ValueError(
                    "{}:{}: expected 'word tag', got {!r}".format(
                        args.text_path, line_number, line.rstrip("\n")
                    )
                )
            words.append(fields[0])
            tags.append(fields[1])

    trees = run_parse(
        words,
        tags,
        model_path=args.model_path,
        subbatch_max_tokens=args.subbatch_max_tokens,
    )

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as outfile:
            outfile.writelines("{}\n".format(tree) for tree in trees)
    else:
        for tree in trees:
            print(tree)
# Run the command-line interface only when executed directly, not on import.
# NOTE: this module uses package-relative imports (``from .treebanks import ...``),
# so it has to be started as a module of its package (``python -m ...``).
if __name__ == "__main__":
    main()