File size: 4,468 Bytes
d916065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Natural Language Toolkit: evaluation of dependency parser
#
# Author: Long Duong <[email protected]>
#
# Copyright (C) 2001-2023 NLTK Project
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

import unicodedata


class DependencyEvaluator:
    """
    Class for measuring labelled and unlabelled attachment score for
    dependency parsing. Note that the evaluation ignores punctuation.

    >>> from nltk.parse import DependencyGraph, DependencyEvaluator

    >>> gold_sent = DependencyGraph(\"""
    ... Pierre  NNP     2       NMOD
    ... Vinken  NNP     8       SUB
    ... ,       ,       2       P
    ... 61      CD      5       NMOD
    ... years   NNS     6       AMOD
    ... old     JJ      2       NMOD
    ... ,       ,       2       P
    ... will    MD      0       ROOT
    ... join    VB      8       VC
    ... the     DT      11      NMOD
    ... board   NN      9       OBJ
    ... as      IN      9       VMOD
    ... a       DT      15      NMOD
    ... nonexecutive    JJ      15      NMOD
    ... director        NN      12      PMOD
    ... Nov.    NNP     9       VMOD
    ... 29      CD      16      NMOD
    ... .       .       9       VMOD
    ... \""")

    >>> parsed_sent = DependencyGraph(\"""
    ... Pierre  NNP     8       NMOD
    ... Vinken  NNP     1       SUB
    ... ,       ,       3       P
    ... 61      CD      6       NMOD
    ... years   NNS     6       AMOD
    ... old     JJ      2       NMOD
    ... ,       ,       3       AMOD
    ... will    MD      0       ROOT
    ... join    VB      8       VC
    ... the     DT      11      AMOD
    ... board   NN      9       OBJECT
    ... as      IN      9       NMOD
    ... a       DT      15      NMOD
    ... nonexecutive    JJ      15      NMOD
    ... director        NN      12      PMOD
    ... Nov.    NNP     9       VMOD
    ... 29      CD      16      NMOD
    ... .       .       9       VMOD
    ... \""")

    >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
    >>> las, uas = de.eval()
    >>> las
    0.6
    >>> uas
    0.8
    >>> abs(uas - 0.8) < 0.00001
    True
    """

    # Unicode general categories treated as punctuation when deciding
    # whether a token is excluded from scoring.
    _PUNCT_CATEGORIES = frozenset({"Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"})

    def __init__(self, parsed_sents, gold_sents):
        """
        :param parsed_sents: the list of parsed_sents as the output of parser
        :type parsed_sents: list(DependencyGraph)
        :param gold_sents: the list of gold-standard sentences, in the same
            order as ``parsed_sents`` (sentences are paired by position)
        :type gold_sents: list(DependencyGraph)
        """
        self._parsed_sents = parsed_sents
        self._gold_sents = gold_sents

    def _remove_punct(self, inStr):
        """
        Function to remove punctuation from Unicode string.

        :param inStr: the input string
        :return: Unicode string after remove all punctuation
        """
        punct = self._PUNCT_CATEGORIES
        return "".join(ch for ch in inStr if unicodedata.category(ch) not in punct)

    def eval(self):
        """
        Return the Labeled Attachment Score (LAS) and Unlabeled Attachment Score (UAS)

        Nodes without a word and tokens made up solely of punctuation are
        not counted. A token counts towards UAS when its head matches the
        gold head, and towards LAS when its relation matches as well.

        :return: ``(las, uas)``; ``(0.0, 0.0)`` when there is no scorable
            token (no sentences, or nothing but punctuation)
        :rtype: tuple(float, float)
        :raises ValueError: if the number of parsed and gold sentences
            differs, if a sentence pair has a different number of nodes, or
            if the words at the same address do not match
        """
        if len(self._parsed_sents) != len(self._gold_sents):
            raise ValueError(
                " Number of parsed sentence is different with number of gold sentence."
            )

        corr = 0
        corrL = 0
        total = 0

        # Lengths were validated above, so zip pairs every sentence.
        for parsed_sent, gold_sent in zip(self._parsed_sents, self._gold_sents):
            parsed_sent_nodes = parsed_sent.nodes
            gold_sent_nodes = gold_sent.nodes

            if len(parsed_sent_nodes) != len(gold_sent_nodes):
                raise ValueError("Sentences must have equal length.")

            for parsed_node_address, parsed_node in parsed_sent_nodes.items():
                gold_node = gold_sent_nodes[parsed_node_address]

                if parsed_node["word"] is None:
                    continue
                if parsed_node["word"] != gold_node["word"]:
                    raise ValueError("Sentence sequence is not matched.")

                # Punctuation is ignored: a token that is empty once all
                # punctuation characters are stripped is not scored.
                if self._remove_punct(parsed_node["word"]) == "":
                    continue

                total += 1
                if parsed_node["head"] == gold_node["head"]:
                    corr += 1
                    if parsed_node["rel"] == gold_node["rel"]:
                        corrL += 1

        # Without any scorable token the ratios are undefined; report zero
        # scores instead of failing with ZeroDivisionError.
        if total == 0:
            return 0.0, 0.0

        return corrL / total, corr / total