MEIRa / utils_evaluate.py
KawshikManikantan's picture
upload_trial
98e2ea5
raw
history blame
8.29 kB
import os
import logging
import pickle
import time
import json
import torch
from os import path
from collections import OrderedDict, Counter
from coref_utils.metrics import CorefEvaluator, F1Evaluator
from coref_utils.conll import evaluate_conll
from coref_utils.utils import get_mention_to_cluster, is_aligned, filter_clusters
from model.utils import action_sequences_to_clusters
from model.entity_ranking_model import EntityRankingModel
from omegaconf import DictConfig
from typing import Dict
from torch import Tensor
from collections import defaultdict
import time
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logger = logging.getLogger()
def get_log_file_name(
config,
dataset,
teacher_force,
gold_mentions,
split,
_iter,
):
log_dir = path.join(config.paths.model_dir, dataset)
## Used for special experiments where we want to save logs in a different directory --
if config.get("log_dir_add", None) is not None:
log_dir_add = config.log_dir_add
log_dir = path.join(log_dir, log_dir_add)
if not path.exists(log_dir):
os.makedirs(log_dir)
gold_ment_str = ""
if (
config.model.mention_params.use_gold_ments
): ## Mode where you train with golden mentions
gold_ment_str = "_gold"
tf_str = "" ## Teacher forced evaluation
if teacher_force == True:
tf_str = "_tf"
gold_str = "" ## Golden mentions in evaluation
if gold_mentions == True:
gold_str = "_gold(eval)"
ext_ment_str = "" ## External mention evaluation
if config.model.mention_params.ext_ment:
ext_ment_str = "_ext_ment"
log_file = path.join(
log_dir,
split + gold_ment_str + gold_str + tf_str + _iter + ext_ment_str + ".log.jsonl",
)
log_file_link = path.join(
log_dir,
split
+ gold_ment_str
+ gold_str
+ tf_str
+ _iter
+ ext_ment_str
+ ".link.jsonl",
)
print("Log file: ", log_file)
return log_file, log_file_link
def get_logs(example, raw_predicted_clusters, coref_scores):
log_example = dict(example)
log_example["predicted_clusters"] = raw_predicted_clusters
log_example["coref_scores"] = coref_scores
del log_example["tensorized_sent"]
for key in list(log_example.keys()):
if isinstance(log_example[key], Tensor):
del log_example[key]
return log_example
def full_coref_evaluation(
config: DictConfig,
model: EntityRankingModel,
data_iter_map: Dict,
dataset: str,
split="dev",
_iter="",
teacher_force=False,
gold_mentions=False,
final_eval=False,
conll_data_dir: Dict = None,
) -> Dict:
"""Function to evaluate full coreference chains.
Args:
config: Experiment configuration
model: Coreference model
data_iter_map: Data iterator
dataset: Name of the coreference dataset
split: Partition of the dataset - train/dev/test
final_eval: Whether this is a periodic evaluation or final evaluation
For final evaluation, official CoNLL scores can be calculated if possible.
conll_data_dir: Data directory dictionary which maps datasets to their gold CoNLL files.
Returns:
dict: Dictionary with results for all the metrics.
"""
# Capture the auxiliary action accuracy
total_actions = 0.0
evaluator = CorefEvaluator()
f1evaluator = F1Evaluator()
coref_predictions, subtoken_maps = {}, {}
logger.info(f"Evaluating on {len(data_iter_map[split][dataset])} examples")
log_file, log_file_link = get_log_file_name(
config,
dataset,
teacher_force,
gold_mentions,
split,
_iter,
)
f = open(log_file, "w")
f_link = open(log_file_link, "w")
for example in data_iter_map[split][dataset]:
## Get outputs:
(
pred_mentions,
pred_mentions_emb,
mention_scores,
gt_actions,
pred_actions,
coref_scores,
entity_cluster_states,
link_time,
) = model(example, teacher_force=teacher_force, gold_mentions=gold_mentions)
num_major_entities = len(example["representatives"])
raw_predicted_clusters = action_sequences_to_clusters(
pred_actions, pred_mentions, num_major_entities
)
assert (
len(raw_predicted_clusters)
== len(example["clusters"])
== num_major_entities + 1
), "Number of clusters should be equal to number of major entities + 1"
## Remove clusters less than the threshold of 1 and remove others from evaluation in MET here. Remove empty clustes for coref
predicted_clusters_coref = filter_clusters(raw_predicted_clusters, threshold=1)
## Keep cluster numbers same as the number of major entities.
predicted_clusters_f1 = filter_clusters(raw_predicted_clusters, threshold=0)
## Golden clusters cannot be empty so we can use the threshold as 1 But we remove the last cluster anyways
gold_clusters = filter_clusters(example["clusters"], threshold=1)
mention_to_predicted_coref = get_mention_to_cluster(predicted_clusters_coref)
mention_to_gold = get_mention_to_cluster(gold_clusters)
evaluator.update(
predicted_clusters_coref,
gold_clusters,
mention_to_predicted_coref,
mention_to_gold,
)
assert (
len(predicted_clusters_f1) == len(gold_clusters) == num_major_entities
), "Predicted and Gold clusters should be of same length and equal to number of major entities + 1"
f1evaluator.update(predicted_clusters_f1, gold_clusters)
coref_predictions[example["doc_key"]] = raw_predicted_clusters
if "orig_subtoken_map" in example:
subtoken_maps[example["doc_key"]] = example["orig_subtoken_map"]
else:
subtoken_maps[example["doc_key"]] = example["subtoken_map"]
total_actions += len(pred_actions)
max_coref_scores = [max(coref_score) for coref_score in coref_scores]
## Removed oracle clustering for now. Code is now at the bottom of this file.
log_example = get_logs(
example,
raw_predicted_clusters=raw_predicted_clusters,
coref_scores=max_coref_scores,
)
log_link_example = {
"doc_key": example["doc_key"],
"num_mentions": len(pred_mentions),
"link_time": link_time,
}
if _iter == "":
f.write(json.dumps(log_example) + "\n")
f_link.write(json.dumps(log_link_example) + "\n")
f.close()
f_link.close()
result_dict: Dict = OrderedDict()
perf_str: str = ""
# Print individual metrics
for indv_metric, indv_evaluator in zip(config.metrics, evaluator.evaluators):
perf_str += ", " + indv_metric + ": {}".format(indv_evaluator.get_f1() * 100)
result_dict[indv_metric] = OrderedDict()
result_dict[indv_metric]["recall"] = indv_evaluator.get_recall() * 100
result_dict[indv_metric]["precision"] = indv_evaluator.get_precision() * 100
result_dict[indv_metric]["fscore"] = indv_evaluator.get_f1() * 100
result_dict["fscore"] = evaluator.get_f1() * 100
result_dict["f1_macro"], result_dict["f1_micro"] = f1evaluator.get_numbers()
logger.info("F-score: %.1f %s" % (result_dict["fscore"], perf_str))
return result_dict
def coref_evaluation(
config: DictConfig,
model: EntityRankingModel,
data_iter_map: Dict,
dataset: str,
split="dev",
_iter="",
teacher_force=False,
gold_mentions=False,
final_eval=False,
conll_data_dir: Dict = None,
) -> Dict:
"""Evaluation function which calls the dataset-appropriate coreference evaluation function."""
return full_coref_evaluation(
config,
model,
data_iter_map,
dataset,
split=split,
_iter=_iter,
teacher_force=teacher_force,
gold_mentions=gold_mentions,
final_eval=final_eval,
conll_data_dir=conll_data_dir,
)