import os
import logging
import pickle
import time
import json

import torch
from os import path
from collections import OrderedDict, Counter, defaultdict
from typing import Dict

from omegaconf import DictConfig
from torch import Tensor

from coref_utils.metrics import CorefEvaluator, F1Evaluator
from coref_utils.conll import evaluate_conll
from coref_utils.utils import get_mention_to_cluster, is_aligned, filter_clusters
from model.utils import action_sequences_to_clusters
from model.entity_ranking_model import EntityRankingModel

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logger = logging.getLogger()


def get_log_file_name(
    config,
    dataset,
    teacher_force,
    gold_mentions,
    split,
    _iter,
):
    """Construct the paths of the prediction log file and the link-time log file."""
    log_dir = path.join(config.paths.model_dir, dataset)
    # Used for special experiments where the logs should go to a separate subdirectory
    if config.get("log_dir_add", None) is not None:
        log_dir = path.join(log_dir, config.log_dir_add)

    if not path.exists(log_dir):
        os.makedirs(log_dir)

    gold_ment_str = ""
    if config.model.mention_params.use_gold_ments:
        # Mode where the model is trained with gold mentions
        gold_ment_str = "_gold"

    tf_str = ""  # Teacher-forced evaluation
    if teacher_force:
        tf_str = "_tf"

    gold_str = ""  # Gold mentions used during evaluation
    if gold_mentions:
        gold_str = "_gold(eval)"

    ext_ment_str = ""  # Evaluation with externally provided mentions
    if config.model.mention_params.ext_ment:
        ext_ment_str = "_ext_ment"

    log_file = path.join(
        log_dir,
        split + gold_ment_str + gold_str + tf_str + _iter + ext_ment_str + ".log.jsonl",
    )
    log_file_link = path.join(
        log_dir,
        split + gold_ment_str + gold_str + tf_str + _iter + ext_ment_str + ".link.jsonl",
    )

    print("Log file: ", log_file)
    return log_file, log_file_link
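
# For example, with split="dev", teacher_force=True, gold_mentions=True, _iter="",
# and both mention flags off, the function returns
# "<model_dir>/<dataset>/dev_gold(eval)_tf.log.jsonl" together with the matching
# ".link.jsonl" path.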


def get_logs(example, raw_predicted_clusters, coref_scores):
    """Build a JSON-serializable copy of the example with predictions attached."""
    log_example = dict(example)
    log_example["predicted_clusters"] = raw_predicted_clusters
    log_example["coref_scores"] = coref_scores

    # Drop fields that cannot be serialized to JSON (tensorized text and tensors)
    del log_example["tensorized_sent"]
    for key in list(log_example.keys()):
        if isinstance(log_example[key], Tensor):
            del log_example[key]

    return log_example


def full_coref_evaluation(
    config: DictConfig,
    model: EntityRankingModel,
    data_iter_map: Dict,
    dataset: str,
    split="dev",
    _iter="",
    teacher_force=False,
    gold_mentions=False,
    final_eval=False,
    conll_data_dir: Dict = None,
) -> Dict:
    """Evaluate full coreference chains.

    Args:
        config: Experiment configuration
        model: Coreference model
        data_iter_map: Data iterator
        dataset: Name of the coreference dataset
        split: Partition of the dataset - train/dev/test
        final_eval: Whether this is a periodic evaluation or the final evaluation.
            For the final evaluation, official CoNLL scores can be calculated if possible.
        conll_data_dir: Dictionary mapping datasets to their gold CoNLL files.

    Returns:
        dict: Dictionary with results for all the metrics.
    """
    # Track the total number of predicted actions (auxiliary statistic)
    total_actions = 0.0

    evaluator = CorefEvaluator()
    f1evaluator = F1Evaluator()
    coref_predictions, subtoken_maps = {}, {}

    logger.info(f"Evaluating on {len(data_iter_map[split][dataset])} examples")

    log_file, log_file_link = get_log_file_name(
        config,
        dataset,
        teacher_force,
        gold_mentions,
        split,
        _iter,
    )
    f = open(log_file, "w")
    f_link = open(log_file_link, "w")
    for example in data_iter_map[split][dataset]:
        # Run the model to get mention and coreference predictions
        (
            pred_mentions,
            pred_mentions_emb,
            mention_scores,
            gt_actions,
            pred_actions,
            coref_scores,
            entity_cluster_states,
            link_time,
        ) = model(example, teacher_force=teacher_force, gold_mentions=gold_mentions)

        num_major_entities = len(example["representatives"])
        raw_predicted_clusters = action_sequences_to_clusters(
            pred_actions, pred_mentions, num_major_entities
        )
        assert (
            len(raw_predicted_clusters)
            == len(example["clusters"])
            == num_major_entities + 1
        ), "Number of clusters should be equal to the number of major entities + 1"

        # For the coreference metrics, remove clusters below the threshold of 1,
        # i.e. drop empty clusters.
        predicted_clusters_coref = filter_clusters(raw_predicted_clusters, threshold=1)
        # For the F1 evaluation, keep the number of clusters equal to the number of major entities.
        predicted_clusters_f1 = filter_clusters(raw_predicted_clusters, threshold=0)
        # Gold clusters cannot be empty, so a threshold of 1 is safe;
        # the last cluster is removed either way.
        gold_clusters = filter_clusters(example["clusters"], threshold=1)

        mention_to_predicted_coref = get_mention_to_cluster(predicted_clusters_coref)
        mention_to_gold = get_mention_to_cluster(gold_clusters)
        evaluator.update(
            predicted_clusters_coref,
            gold_clusters,
            mention_to_predicted_coref,
            mention_to_gold,
        )

        assert (
            len(predicted_clusters_f1) == len(gold_clusters) == num_major_entities
        ), "Predicted and gold clusters should have the same length, equal to the number of major entities"
        f1evaluator.update(predicted_clusters_f1, gold_clusters)

        coref_predictions[example["doc_key"]] = raw_predicted_clusters
        if "orig_subtoken_map" in example:
            subtoken_maps[example["doc_key"]] = example["orig_subtoken_map"]
        else:
            subtoken_maps[example["doc_key"]] = example["subtoken_map"]

        total_actions += len(pred_actions)
        max_coref_scores = [max(coref_score) for coref_score in coref_scores]

        # Oracle clustering has been removed for now; its code is at the bottom of this file.
        log_example = get_logs(
            example,
            raw_predicted_clusters=raw_predicted_clusters,
            coref_scores=max_coref_scores,
        )
        log_link_example = {
            "doc_key": example["doc_key"],
            "num_mentions": len(pred_mentions),
            "link_time": link_time,
        }
        if _iter == "":
            f.write(json.dumps(log_example) + "\n")
        f_link.write(json.dumps(log_link_example) + "\n")

    f.close()
    f_link.close()
    result_dict: Dict = OrderedDict()
    perf_str: str = ""

    # Print individual metrics
    for indv_metric, indv_evaluator in zip(config.metrics, evaluator.evaluators):
        perf_str += ", " + indv_metric + ": {}".format(indv_evaluator.get_f1() * 100)
        result_dict[indv_metric] = OrderedDict()
        result_dict[indv_metric]["recall"] = indv_evaluator.get_recall() * 100
        result_dict[indv_metric]["precision"] = indv_evaluator.get_precision() * 100
        result_dict[indv_metric]["fscore"] = indv_evaluator.get_f1() * 100

    result_dict["fscore"] = evaluator.get_f1() * 100
    result_dict["f1_macro"], result_dict["f1_micro"] = f1evaluator.get_numbers()
    logger.info("F-score: %.1f %s" % (result_dict["fscore"], perf_str))

    return result_dict


def coref_evaluation(
    config: DictConfig,
    model: EntityRankingModel,
    data_iter_map: Dict,
    dataset: str,
    split="dev",
    _iter="",
    teacher_force=False,
    gold_mentions=False,
    final_eval=False,
    conll_data_dir: Dict = None,
) -> Dict:
    """Evaluation function which calls the dataset-appropriate coreference evaluation function."""
    return full_coref_evaluation(
        config,
        model,
        data_iter_map,
        dataset,
        split=split,
        _iter=_iter,
        teacher_force=teacher_force,
        gold_mentions=gold_mentions,
        final_eval=final_eval,
        conll_data_dir=conll_data_dir,
    )
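

# Usage sketch (not part of the original module): a minimal, hypothetical example of
# how the entry point above might be called. It assumes an OmegaConf experiment config
# exposing the fields referenced in this file (paths.model_dir, metrics,
# model.mention_params.*), an already loaded EntityRankingModel checkpoint, and a data
# iterator of the form {split: {dataset: [example, ...]}}; the config path and dataset
# name below are illustrative placeholders, not guaranteed entry points of this repo.
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.load("conf/experiment.yaml")  # hypothetical config path
#     model.eval()
#     with torch.no_grad():
#         result_dict = coref_evaluation(
#             config, model, data_iter_map, dataset="litbank", split="dev"
#         )
#     print(result_dict["fscore"], result_dict["f1_macro"], result_dict["f1_micro"])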