# Natural Language Toolkit: evaluation of dependency parser
#
# Author: Long Duong <[email protected]>
#
# Copyright (C) 2001-2023 NLTK Project
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

import unicodedata


class DependencyEvaluator:
"""
Class for measuring labelled and unlabelled attachment score for
dependency parsing. Note that the evaluation ignores punctuation.
>>> from nltk.parse import DependencyGraph, DependencyEvaluator
>>> gold_sent = DependencyGraph(\"""
... Pierre NNP 2 NMOD
... Vinken NNP 8 SUB
... , , 2 P
... 61 CD 5 NMOD
... years NNS 6 AMOD
... old JJ 2 NMOD
... , , 2 P
... will MD 0 ROOT
... join VB 8 VC
... the DT 11 NMOD
... board NN 9 OBJ
... as IN 9 VMOD
... a DT 15 NMOD
... nonexecutive JJ 15 NMOD
... director NN 12 PMOD
... Nov. NNP 9 VMOD
... 29 CD 16 NMOD
... . . 9 VMOD
... \""")
>>> parsed_sent = DependencyGraph(\"""
... Pierre NNP 8 NMOD
... Vinken NNP 1 SUB
... , , 3 P
... 61 CD 6 NMOD
... years NNS 6 AMOD
... old JJ 2 NMOD
... , , 3 AMOD
... will MD 0 ROOT
... join VB 8 VC
... the DT 11 AMOD
... board NN 9 OBJECT
... as IN 9 NMOD
... a DT 15 NMOD
... nonexecutive JJ 15 NMOD
... director NN 12 PMOD
... Nov. NNP 9 VMOD
... 29 CD 16 NMOD
... . . 9 VMOD
... \""")
>>> de = DependencyEvaluator([parsed_sent],[gold_sent])
>>> las, uas = de.eval()
>>> las
0.6
>>> uas
0.8
>>> abs(uas - 0.8) < 0.00001
True
"""

    def __init__(self, parsed_sents, gold_sents):
        """
        :param parsed_sents: the list of parsed sentences, as output by the parser
        :type parsed_sents: list(DependencyGraph)
        :param gold_sents: the list of gold-standard (reference) sentences
        :type gold_sents: list(DependencyGraph)
        """
self._parsed_sents = parsed_sents
self._gold_sents = gold_sents

    def _remove_punct(self, inStr):
        """
        Remove punctuation from a Unicode string.

        :param inStr: the input string
        :return: the string with all punctuation removed
        """
punc_cat = {"Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"}
return "".join(x for x in inStr if unicodedata.category(x) not in punc_cat)

    def eval(self):
        """
        Return the Labeled Attachment Score (LAS) and Unlabeled Attachment Score (UAS).

        :return: tuple(float, float)
        """
        if len(self._parsed_sents) != len(self._gold_sents):
            raise ValueError(
                "Number of parsed sentences differs from the number of gold sentences."
            )
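        # corr counts tokens whose head matches the gold head (UAS numerator);
        # corrL counts tokens whose head and relation both match (LAS numerator);
        # total counts all scored (non-punctuation) tokens.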
corr = 0
corrL = 0
total = 0
for i in range(len(self._parsed_sents)):
parsed_sent_nodes = self._parsed_sents[i].nodes
gold_sent_nodes = self._gold_sents[i].nodes
if len(parsed_sent_nodes) != len(gold_sent_nodes):
raise ValueError("Sentences must have equal length.")
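            # Compare each parsed node with the gold node at the same address.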
for parsed_node_address, parsed_node in parsed_sent_nodes.items():
gold_node = gold_sent_nodes[parsed_node_address]
if parsed_node["word"] is None:
continue
if parsed_node["word"] != gold_node["word"]:
raise ValueError("Sentence sequence is not matched.")
                # By default, ignore tokens that consist only of punctuation
# if (parsed_sent[j]["word"] in string.punctuation):
if self._remove_punct(parsed_node["word"]) == "":
continue
total += 1
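                # A head match counts toward UAS; a head and relation match counts toward LAS.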
if parsed_node["head"] == gold_node["head"]:
corr += 1
if parsed_node["rel"] == gold_node["rel"]:
corrL += 1
return corrL / total, corr / total