File size: 2,142 Bytes
6ed21b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/python3

import argparse
import pickle
import torch

from .treebanks import load_single_text
from .benepar import parse_chart
from huggingface_hub import hf_hub_download


def run_parse(words, tags, model_path='nielklug/mhg_parser', subbatch_max_tokens=500,
              model_filename='german-delex-parser_dev=83.10.pt'):
    """Parse a tagged text and return the predicted trees as bracketed strings.

    The input is turned into a treebank with ``load_single_text``, the parser
    checkpoint is fetched with ``hf_hub_download`` and loaded through
    ``ChartParser.from_trained``, and the parser is moved to the GPU when CUDA
    is available. After parsing, the original tokens and tags are written back
    into the predicted trees.

    Args:
        words: Tokens of the text; forwarded unchanged to ``load_single_text``.
        tags: Tags belonging to ``words``; forwarded unchanged to
            ``load_single_text``.
        model_path: Hugging Face Hub repository id hosting the checkpoint.
        subbatch_max_tokens: Forwarded to ``ChartParser.parse``.
        model_filename: Name of the checkpoint file inside ``model_path``.
            Defaults to the file that was previously hard-coded, so existing
            callers are unaffected.

    Returns:
        A list with one ``pformat`` string per predicted tree.
    """
    treebank = load_single_text(words, tags)

    checkpoint_file = hf_hub_download(repo_id=model_path, filename=model_filename)
    parser = parse_chart.ChartParser.from_trained(checkpoint_file)
    if torch.cuda.is_available():
        parser.cuda()

    predicted_trees = parser.parse(
        treebank.without_gold_annotations(),
        subbatch_max_tokens=subbatch_max_tokens,
    )

    # Insert the original tokens into the delexicalized parses: every leaf
    # gets its word back and the leaf's parent node is relabelled with the tag.
    for example, tree in zip(treebank, predicted_trees):
        leaf_positions = tree.treepositions('leaves')
        for word_tag_pair, leaf_pos in zip(example.word_tag_pairs, leaf_positions):
            tree[leaf_pos] = word_tag_pair[0]
            tree[leaf_pos[:-1]].set_label(word_tag_pair[1])

    # The huge margin presumably keeps each tree on a single line (the trees
    # look like nltk.Tree objects -- TODO confirm).
    return [tree.pformat(margin=1e100) for tree in predicted_trees]


def _read_tagged_tokens(text_path):
    """Read parallel word and tag lists from a whitespace-separated token file.

    NOTE(review): the on-disk format is an assumption -- one token per line
    with the word in the first column and its tag in the second; blank lines
    are skipped. Confirm against what ``load_single_text`` expects.
    """
    words = []
    tags = []
    with open(text_path, encoding="utf-8") as infile:
        for line_number, line in enumerate(infile, start=1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError(
                    "{}:{}: expected '<word> <tag>', got {!r}".format(
                        text_path, line_number, line.rstrip("\n")))
            words.append(fields[0])
            tags.append(fields[1])
    return words, tags


def main():
    """Command-line entry point: parse a tagged token file and emit the trees.

    Reads the tokens and tags from ``--text_path``, runs ``run_parse`` on
    them, and writes one tree per line to ``--output_path``, or to standard
    output when no output path is given.
    """
    parser = argparse.ArgumentParser()

    # run_parse hands this value to hf_hub_download as ``repo_id``, so it must
    # be a Hub repository id rather than a local checkpoint path.
    parser.add_argument("--model_path", default="nielklug/mhg_parser", type=str,
                        help='Hugging Face Hub repo id of the trained parser')
    parser.add_argument("--text_path", required=True, type=str,
                        help='file with one "<word> <tag>" pair per line')
    # Must be an int: a value given on the command line would otherwise reach
    # the parser as a string.
    parser.add_argument('--subbatch_max_tokens', default=500, type=int)
    parser.add_argument('--output_path', default="", type=str,
                        help='where to write the trees (default: stdout)')

    args = parser.parse_args()

    words, tags = _read_tagged_tokens(args.text_path)
    trees = run_parse(
        words,
        tags,
        model_path=args.model_path,
        subbatch_max_tokens=args.subbatch_max_tokens,
    )

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as outfile:
            for tree in trees:
                outfile.write("{}\n".format(tree))
    else:
        for tree in trees:
            print(tree)

# Script entry point. This module uses package-relative imports
# (``from .treebanks import ...``), so it has to be executed as part of its
# package (``python -m <package>.<module>``) rather than as a bare file.
if __name__ == "__main__":
    main()