import collections
import operator
import re
import subprocess

BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")
COREF_RESULTS_REGEX = re.compile(
    r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) "
    r"([0-9.]+)%\tF1: ([0-9.]+)%.*",
    re.DOTALL,
)


def get_doc_key(doc_id, part):
    return "{}_{}".format(doc_id, int(part))


def output_conll(input_file, output_file, predictions, subtoken_map):
    # Index each document's predicted mentions by the word where they start,
    # end, or both (single-word mentions), after mapping subtoken indices
    # back to word indices via subtoken_map.
    prediction_map = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)
        end_map = collections.defaultdict(list)
        word_map = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for start, end in mentions:
                start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end]
                if start == end:
                    word_map[start].append(cluster_id)
                else:
                    start_map[start].append((cluster_id, end))
                    end_map[end].append((cluster_id, start))
        # Order brackets so nested spans close correctly: at a given word,
        # open the span that ends latest first and close the span that
        # started latest first.
        for k, v in start_map.items():
            start_map[k] = [
                cluster_id
                for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)
            ]
        for k, v in end_map.items():
            end_map[k] = [
                cluster_id
                for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)
            ]
        prediction_map[doc_key] = (start_map, end_map, word_map)

    # Copy the gold file line by line, rewriting only the final (coreference)
    # column with the predicted cluster brackets.
    word_index = 0
    for line in input_file:
        row = line.split()
        if len(row) == 0:
            output_file.write("\n")
        elif row[0].startswith("#"):
            begin_match = BEGIN_DOCUMENT_REGEX.match(line)
            if begin_match:
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            output_file.write(line)
            # output_file.write("\n")
        else:
            assert get_doc_key(row[0], row[1]) == doc_key
            coref_list = []
            if word_index in end_map:
                for cluster_id in end_map[word_index]:
                    coref_list.append("{})".format(cluster_id))
            if word_index in word_map:
                for cluster_id in word_map[word_index]:
                    coref_list.append("({})".format(cluster_id))
            if word_index in start_map:
                for cluster_id in start_map[word_index]:
                    coref_list.append("({}".format(cluster_id))
            if len(coref_list) == 0:
                row[-1] = "-"
            else:
                row[-1] = "|".join(coref_list)
            output_file.write(" ".join(row))
            output_file.write("\n")
            word_index += 1


def official_conll_eval(
    conll_scorer, gold_path, predicted_path, metric, official_stdout=False
):
    # Run the official CoNLL-2012 scorer script for one metric and parse the
    # recall/precision/F1 line from its output.
    cmd = [conll_scorer, metric, gold_path, predicted_path, "none"]
    # Pass the argument list directly (no shell=True): combining a list with
    # shell=True would hand only the first element to the shell and silently
    # drop the remaining arguments. Pipe stderr so it can actually be reported.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    stdout = stdout.decode("utf-8")
    if stderr:
        print(stderr.decode("utf-8"))

    if official_stdout:
        print("Official result for {}".format(metric))
        print(stdout)

    coref_results_match = COREF_RESULTS_REGEX.match(stdout)
    recall = float(coref_results_match.group(1))
    precision = float(coref_results_match.group(2))
    f1 = float(coref_results_match.group(3))
    return {"r": recall, "p": precision, "f": f1}


def evaluate_conll(
    conll_scorer,
    gold_path,
    predictions,
    subtoken_maps,
    prediction_path,
    all_metrics=False,  # accepted for API compatibility; unused here
    official_stdout=False,
):
    # Write the predictions into a copy of the gold CoNLL file, then score it.
    # Both files must be closed (flushing the predictions to disk) before the
    # external scorer reads them, so scoring happens outside the with-blocks.
    with open(prediction_path, "w") as prediction_file:
        with open(gold_path, "r") as gold_file:
            output_conll(gold_file, prediction_file, predictions, subtoken_maps)
    result = {
        metric: official_conll_eval(
            conll_scorer, gold_path, prediction_path, metric, official_stdout
        )
        for metric in ("muc", "bcub", "ceafe")
    }
    return result
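

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The scorer path, gold file,
# and the tiny predictions/subtoken_maps below are hypothetical placeholders;
# substitute your own scorer location, gold CoNLL file, and model output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # One document keyed as "<doc_id>_<part>", with a single predicted cluster
    # containing a two-subtoken mention (0, 1) and a one-subtoken mention (5, 5).
    predictions = {"bc/cctv/00/cctv_0000_0": [[(0, 1), (5, 5)]]}
    # Identity subtoken map: subtoken i corresponds to word i.
    subtoken_maps = {"bc/cctv/00/cctv_0000_0": list(range(20))}

    results = evaluate_conll(
        conll_scorer="conll-2012/scorer/v8.01/scorer.pl",  # hypothetical path
        gold_path="dev.english.v4_gold_conll",             # hypothetical path
        predictions=predictions,
        subtoken_maps=subtoken_maps,
        prediction_path="dev.predictions.conll",
    )
    # The standard CoNLL-2012 score is the average F1 of MUC, B-cubed, and CEAF-e.
    avg_f1 = sum(r["f"] for r in results.values()) / len(results)
    print("Average F1 (conll): {:.2f}%".format(avg_f1))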