# MEIRa / coref_utils / conll.py
# Last commit 98e2ea5 ("upload_trial") by KawshikManikantan; file size 4.34 kB.
import re
import subprocess
import operator
import collections
# Matches a CoNLL header line of the form "#begin document (<doc_id>); part <part>".
# Group 1 is the document id and group 2 the part number exactly as written
# (e.g. "000"); the greedy "(.*)" lets ids that themselves contain ")" match up
# to the final "); part".
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")
# Extracts the summary scores from the scorer's stdout.
# Groups 1-3 are the recall, precision and F1 percentages of the
# "Coreference: Recall: (n / d) R%\tPrecision: (n / d) P%\tF1: F%" line.
# The leading greedy ".*" combined with re.DOTALL lets re.match skip all
# preceding output; because it is greedy, the match lands on the LAST such
# line in the text.
COREF_RESULTS_REGEX = re.compile(
    r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) "
    r"([0-9.]+)%\tF1: ([0-9.]+)%.*",
    re.DOTALL,
)
def get_doc_key(doc_id, part):
    """Build the key ``"<doc_id>_<part>"`` identifying one CoNLL document part.

    The part number is passed through ``int`` first, so a zero-padded part
    such as ``"000"`` and the integer ``0`` yield the same key.
    """
    part_number = int(part)
    return "%s_%d" % (doc_id, part_number)
def _build_prediction_map(predictions, subtoken_map):
    """Index predicted clusters by word position, one entry per document.

    Args:
        predictions: mapping from document key to a list of clusters; each
            cluster is an iterable of ``(start, end)`` pairs that index into
            ``subtoken_map[doc_key]`` (``end`` inclusive: equal mapped ends
            denote a single-word mention).
        subtoken_map: mapping from document key to a sequence translating
            those indices into word indices within the document.

    Returns:
        dict mapping each document key to ``(start_map, end_map, word_map)``.
        A cluster's id is its position in the document's cluster list.

        * ``start_map[w]``: ids of clusters whose multi-word mention starts at
          word ``w``, later-ending (longer) mentions first, so outer brackets
          open before inner ones.
        * ``end_map[w]``: ids of clusters whose multi-word mention ends at word
          ``w``, later-starting (shorter) mentions first, so inner brackets
          close before outer ones.
        * ``word_map[w]``: ids of clusters with a single-word mention at ``w``.
    """
    prediction_map = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)
        end_map = collections.defaultdict(list)
        word_map = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for start, end in mentions:
                # Map both mention boundaries from subtoken to word positions.
                start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end]
                if start == end:
                    word_map[start].append(cluster_id)
                else:
                    start_map[start].append((cluster_id, end))
                    end_map[end].append((cluster_id, start))
        # Replace each (cluster_id, other_end) list by cluster ids ordered by
        # the mention's other boundary, descending (stable for ties).
        for word, entries in start_map.items():
            start_map[word] = [
                cluster_id
                for cluster_id, _ in sorted(
                    entries, key=operator.itemgetter(1), reverse=True
                )
            ]
        for word, entries in end_map.items():
            end_map[word] = [
                cluster_id
                for cluster_id, _ in sorted(
                    entries, key=operator.itemgetter(1), reverse=True
                )
            ]
        prediction_map[doc_key] = (start_map, end_map, word_map)
    return prediction_map


def _coref_column(word_index, start_map, end_map, word_map):
    """Return the coreference-column value for the word at ``word_index``.

    Closing brackets come first, then single-word mentions, then opening
    brackets, all joined with ``"|"``. A cluster that ends and restarts on the
    same word is therefore written as ``"0)|(0"``. Returns ``"-"`` when no
    mention touches the word.
    """
    # .get() avoids inserting keys into the defaultdicts.
    parts = ["{})".format(cluster_id) for cluster_id in end_map.get(word_index, ())]
    parts.extend("({})".format(cluster_id) for cluster_id in word_map.get(word_index, ()))
    parts.extend("({}".format(cluster_id) for cluster_id in start_map.get(word_index, ()))
    return "|".join(parts) if parts else "-"


def output_conll(input_file, output_file, predictions, subtoken_map):
    """Copy a CoNLL file, replacing its last column with predicted coreference.

    Comment lines (starting with ``#``) are copied unchanged. Blank or
    whitespace-only lines are written as a bare newline. Token rows get their
    last column replaced by the predicted coreference annotation and are
    re-joined with single spaces.

    Args:
        input_file: readable text file in CoNLL format, iterated line by line.
        output_file: writable text file that receives the annotated copy.
        predictions: document key -> list of clusters of ``(start, end)``
            subtoken spans (see ``_build_prediction_map``).
        subtoken_map: document key -> subtoken-to-word index mapping.

    Raises:
        KeyError: a ``#begin document`` header names a document missing from
            ``predictions``, or a mention's document has no ``subtoken_map``
            entry.
        ValueError: a token row does not belong to the most recent
            ``#begin document`` header, or appears before any header.
    """
    prediction_map = _build_prediction_map(predictions, subtoken_map)
    doc_key = None
    start_map = end_map = word_map = None
    # Word indices run over the whole document part: they reset only at a
    # "#begin document" header, not at the blank lines between sentences.
    word_index = 0
    for line in input_file:
        row = line.split()
        if not row:
            output_file.write("\n")
        elif row[0].startswith("#"):
            begin_match = BEGIN_DOCUMENT_REGEX.match(line)
            if begin_match:
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            output_file.write(line)
        else:
            # Explicit check instead of `assert`, which is stripped under -O.
            row_key = get_doc_key(row[0], row[1])
            if row_key != doc_key:
                raise ValueError(
                    "Token row for document {!r} found while writing document {!r}".format(
                        row_key, doc_key
                    )
                )
            row[-1] = _coref_column(word_index, start_map, end_map, word_map)
            output_file.write(" ".join(row))
            output_file.write("\n")
            word_index += 1
def official_conll_eval(
    conll_scorer, gold_path, predicted_path, metric, official_stdout=False
):
    """Run the CoNLL scorer for one metric and return its summary scores.

    The scorer is invoked as ``conll_scorer metric gold_path predicted_path none``.

    Args:
        conll_scorer: path to the scorer executable.
        gold_path: path to the gold CoNLL file.
        predicted_path: path to the predicted CoNLL file.
        metric: metric name passed to the scorer, e.g. ``"muc"``.
        official_stdout: if true, echo the scorer's full stdout.

    Returns:
        dict of float percentages: ``"r"`` (recall), ``"p"`` (precision) and
        ``"f"`` (F1).

    Raises:
        RuntimeError: the scorer output contains no parsable
            "Coreference: Recall ... F1" summary line.
    """
    import os
    import sys

    cmd = [conll_scorer, metric, gold_path, predicted_path, "none"]
    # On POSIX, shell=True with a list hands only cmd[0] to the shell as the
    # command; the remaining items become arguments of the shell itself, so the
    # scorer would run without its arguments. On Windows the list is joined
    # into one command line for cmd.exe, so the shell is kept there unchanged.
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=(os.name == "nt"),
    )
    stdout = completed.stdout.decode("utf-8")
    stderr = completed.stderr.decode("utf-8", errors="replace")
    if stderr:
        print(stderr, file=sys.stderr)
    if official_stdout:
        print("Official result for {}".format(metric))
        print(stdout)
    coref_results_match = COREF_RESULTS_REGEX.match(stdout)
    if coref_results_match is None:
        raise RuntimeError(
            "Could not parse {} scores from CoNLL scorer output (exit code {}):\n{}\n{}".format(
                metric, completed.returncode, stdout, stderr
            )
        )
    recall, precision, f1 = (float(value) for value in coref_results_match.groups())
    return {"r": recall, "p": precision, "f": f1}
def evaluate_conll(
    conll_scorer,
    gold_path,
    predictions,
    subtoken_maps,
    prediction_path,
    all_metrics=False,
    official_stdout=False,
):
    """Write predictions in CoNLL format and score them against the gold file.

    Args:
        conll_scorer: path to the scorer executable.
        gold_path: path to the gold CoNLL file; it also serves as the template
            for the prediction file.
        predictions: document key -> list of predicted clusters.
        subtoken_maps: document key -> subtoken-to-word index mapping.
        prediction_path: where the predicted CoNLL file is written.
        all_metrics: accepted for API compatibility but currently unused; only
            "muc", "bcub" and "ceafe" are computed.
        official_stdout: echo the scorer's full output for each metric.

    Returns:
        dict mapping "muc", "bcub" and "ceafe" to the ``{"r", "p", "f"}``
        score dicts returned by ``official_conll_eval``.
    """
    # An explicit codec makes decoding/encoding independent of the platform's
    # default locale encoding.
    with open(gold_path, "r", encoding="utf-8") as gold_file, open(
        prediction_path, "w", encoding="utf-8"
    ) as prediction_file:
        output_conll(gold_file, prediction_file, predictions, subtoken_maps)
    # The scorer starts only after the prediction file is closed, so every
    # buffered write has been flushed to disk before it reads the file.
    return {
        metric: official_conll_eval(
            conll_scorer, gold_path, prediction_path, metric, official_stdout
        )
        for metric in ("muc", "bcub", "ceafe")
    }