# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import re
import shutil
from string import ascii_uppercase
from tqdm.auto import tqdm
from model.third_party.HMNet.Evaluation.OldROUGEEval import rouge
from model.third_party.HMNet.ThirdParty.ROUGE import pyrouge
from shutil import copyfile
from mpi4py import MPI
import torch
import logging
import json


def write_json_res(
    output_file, tokenizers, x_ids, y_ids, x_tokens, y_tokens, predictions, gts
):
    data = []

    # for x_id, y_id, x_token, y_token, preds, gt in zip(x_ids, y_ids, x_tokens, y_tokens, predictions, gts):
    #     x_id = tokenizers[0].decode(x_id, skip_special_tokens=False) if x_id.dim() == 1 else tokenizers[0].convert_tokens_to_string(x_token)
    #     y_id = tokenizers[1].decode(y_id, skip_special_tokens=False) if y_id.dim() == 1 else tokenizers[1].convert_tokens_to_string(y_token)
    for x_token, y_token, preds, gt in zip(x_tokens, y_tokens, predictions, gts):
        data.append(
            {
                # 'x_ids': x_id,
                # 'y_ids': y_id,
                "x_tokens": x_token if isinstance(x_token, str) else " ".join(x_token),
                "y_tokens": y_token if isinstance(y_token, str) else " ".join(y_token),
                "predictions": preds,
                "gt": gt,
            }
        )

    json.dump(data, output_file, indent=4, ensure_ascii=False)
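
# write_json_res dumps one record per example, roughly:
#   [{"x_tokens": "<input tokens>", "y_tokens": "<reference tokens>",
#     "predictions": ["<beam 1>", ...], "gt": "<reference string>"}]
# x_ids / y_ids are still accepted for compatibility but are currently unused
# (see the commented-out code above).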

logger = logging.getLogger(__name__)

"""
This code can only be run inside the "rouge" docker image, because it relies on the
Perl implementation of ROUGE (rouge-perl).
"""

""" In ROUGE parlance, your summaries are ‘system’ summaries and the gold standard summaries
are ‘model’ summaries. The summaries should be in separate folders, whose paths are set with
the system_dir and model_dir variables. All summaries should contain one sentence per line."""
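
# Illustration (not part of the original code): with the conventions above, the directories
# produced by ROUGEEval.eval() below end up looking roughly like this, matching the filename
# patterns later configured on pyrouge.Rouge155:
#
#   <save_dir>/predictions/000000_pred.txt     'system' summary for example 0, one sentence per line
#   <save_dir>/groundtruths/A.000000_gt.txt    first 'model' (reference) summary for example 0
#   <save_dir>/groundtruths/B.000000_gt.txt    second reference, if one was provided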


class ROUGEEval:
    """
    Wrapper class for pyrouge.
    Compute ROUGE given predictions and references for summarization evaluation.
    """

    def __init__(self, run_dir, save_dir, opt):
        self.run_dir = run_dir
        self.save_dir = save_dir
        self.opt = opt

        # use relative path to make it work on Philly
        self.pyrouge_dir = os.path.join(
            os.path.dirname(__file__), "../ThirdParty/ROUGE/ROUGE-1.5.5/"
        )

        self.eval_batches_num = self.opt.get("EVAL_BATCHES_NUM", float("Inf"))
        self.best_score = -float("Inf")
        self.best_res = {}

    def reset_best_score(self, set_high=False):
        if set_high:
            self.best_score = float("Inf")
        else:
            self.best_score = -float("Inf")

    def make_html_safe(self, s):
        # escape angle brackets so raw '<' and '>' do not break ROUGE's SGML/HTML processing
        s = s.replace("<", "&lt;")
        s = s.replace(">", "&gt;")
        return s
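
    # Note on the two helpers below: they write each summary to its own file, one sentence per
    # line. The regex split is a heuristic sentence splitter (it avoids breaking on abbreviations
    # such as "e.g." or "Mr."). When split_chars is set, every CJK character is replaced by a
    # numeric id taken from the shared char dictionary, so that the word-based Perl ROUGE scorer
    # treats each character as a separate token.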

    def print_to_rouge_dir(
        self, summaries, dir, suffix, split_chars, special_char_dict=None
    ):
        for idx, summary in enumerate(summaries):
            fname = os.path.join(dir, "%06d_%s.txt" % (idx, suffix))
            with open(fname, "wb") as f:
                sents = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", summary)
                for i, sent in enumerate(sents):
                    if split_chars:
                        # sent = re.sub(r'([\u4e00-\u9fff])', r' \1 ', sent)
                        for x in re.finditer(r"([\u4e00-\u9fff])", sent):
                            if x.group(1) not in special_char_dict:
                                special_char_dict[x.group(1)] = len(special_char_dict)
                            sent = sent.replace(
                                x.group(1), " {} ".format(special_char_dict[x.group(1)])
                            )
                    if i == len(sents) - 1:
                        to_print = sent.encode("utf-8")
                    else:
                        to_print = sent.encode("utf-8") + "\n".encode("utf-8")
                    f.write(to_print)

    def print_to_rouge_dir_gt(self, summaries, dir, suffix, split_chars):
        if split_chars:
            char_dict = {}

        for idx, summary in enumerate(summaries):
            for ref_idx, sub_summary in enumerate(summary.split(" ||| ")):
                fname = os.path.join(
                    dir, "%s.%06d_%s.txt" % (ascii_uppercase[ref_idx], idx, suffix)
                )
                with open(fname, "wb") as f:
                    sents = re.split(
                        r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", sub_summary
                    )
                    for i, sent in enumerate(sents):
                        if split_chars:
                            for x in re.finditer(r"([\u4e00-\u9fff])", sent):
                                if x.group(1) not in char_dict:
                                    char_dict[x.group(1)] = len(char_dict)
                                sent = sent.replace(
                                    x.group(1), " {} ".format(char_dict[x.group(1)])
                                )
                        if i == len(sents) - 1:
                            to_print = sent.encode("utf-8")
                        else:
                            to_print = sent.encode("utf-8") + "\n".encode("utf-8")
                        f.write(to_print)

        if split_chars:
            return char_dict
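
    # A ground-truth string may contain several references separated by " ||| "; each reference
    # is written to its own file prefixed with A., B., C., ... so that pyrouge can associate
    # multiple 'model' summaries with a single 'system' summary via the "[A-Z].#ID#_gt.txt"
    # pattern used in eval() below.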

    # def filter_empty(self, predictions, groundtruths):
    #     new_predicitons = []
    #     new_groundtruths = []
    #
    #     for pred, gt in zip(predictions, groundtruths):
    #         if len(gt) == 0:
    #             continue
    #         new_groundtruths.append(gt)
    #         if len(pred) == 0:
    #             new_predicitons.append('<ept>')
    #         else:
    #             new_predicitons.append(pred)
    #     return new_predicitons, new_groundtruths

    def _convert_tokens_to_string(self, tokenizer, tokens):
        if "EVAL_TOKENIZED" in self.opt:
            tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
        if "EVAL_LOWERCASE" in self.opt:
            tokens = [t.lower() for t in tokens]
        if "EVAL_TOKENIZED" in self.opt:
            return " ".join(tokens)
        else:
            return tokenizer.decode(
                tokenizer.convert_tokens_to_ids(tokens), skip_special_tokens=True
            )
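
    # _convert_tokens_to_string has two modes, controlled by self.opt:
    #   - with "EVAL_TOKENIZED" set, special tokens are dropped and the remaining tokens are
    #     joined with spaces (evaluation on the tokenized surface form);
    #   - otherwise the tokens are converted back to ids and decoded by the tokenizer,
    #     producing normal detokenized text.
    # "EVAL_LOWERCASE" additionally lowercases the tokens before either path.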

    def eval_batches(self, module, dev_batches, save_folder, label=""):
        max_sent_len = int(self.opt["MAX_GEN_LENGTH"])

        logger.info(
            "Decoding current model ... \nSaving folder is {}".format(save_folder)
        )

        predictions = []  # predicted tokens from the model
        x_tokens = []  # input tokens
        y_tokens = []  # ground-truth tokens
        x_ids = []  # input token ids
        y_ids = []  # ground-truth token ids
        gts = []  # ground-truth strings
        got_better_score = False
        # err = 0
        if not isinstance(module.tokenizer, list):
            encoder_tokenizer = module.tokenizer
            decoder_tokenizer = module.tokenizer
        elif len(module.tokenizer) == 1:
            encoder_tokenizer = module.tokenizer[0]
            decoder_tokenizer = module.tokenizer[0]
        elif len(module.tokenizer) == 2:
            encoder_tokenizer = module.tokenizer[0]
            decoder_tokenizer = module.tokenizer[1]
        else:
            assert False, "len(module.tokenizer) > 2 is not supported"

        with torch.no_grad():
            for j, dev_batch in enumerate(dev_batches):
                for b in dev_batch:
                    if torch.is_tensor(dev_batch[b]):
                        dev_batch[b] = dev_batch[b].to(self.opt["device"])

                beam_search_res = module(
                    dev_batch, beam_search=True, max_sent_len=max_sent_len
                )
                pred = [
                    [t[0] for t in x] if len(x) > 0 else [[]] for x in beam_search_res
                ]
                predictions.extend(
                    [
                        [
                            self._convert_tokens_to_string(decoder_tokenizer, tt)
                            for tt in t
                        ]
                        for t in pred
                    ]
                )

                gts.extend(
                    [
                        self._convert_tokens_to_string(decoder_tokenizer, t)
                        for t in dev_batch["decoder_tokens"]
                    ]
                )
                x_tokens.extend(dev_batch["encoder_tokens"])
                y_tokens.extend(dev_batch["decoder_tokens"])

                if ("DEBUG" in self.opt and j >= 10) or j >= self.eval_batches_num:
                    # in debug mode decode only the first 10 batches; otherwise decode
                    # the first self.eval_batches_num batches
                    break

        # use MPI to gather results from all processes / GPUs
        # the result of the gather operation is a list of sublists,
        # each sublist corresponding to the list built on one MPI process (GPU);
        # we then flatten it into a single flat list
        assert len(predictions) == len(
            gts
        ), "len(predictions): {0}, len(gts): {1}".format(len(predictions), len(gts))

        comm = MPI.COMM_WORLD
        predictions = comm.gather(predictions, root=0)
        x_tokens = comm.gather(x_tokens, root=0)
        y_tokens = comm.gather(y_tokens, root=0)
        # if the number of GPUs is high (>= 8), gathering x_ids and y_ids on rank 0 can run out of memory
        # x_ids = comm.gather(x_ids, root=0)
        # y_ids = comm.gather(y_ids, root=0)
        gts = comm.gather(gts, root=0)

        if self.opt["rank"] == 0:
            # flatten lists
            predictions = [item for sublist in predictions for item in sublist]
            y_tokens = [item for sublist in y_tokens for item in sublist]
            x_tokens = [item for sublist in x_tokens for item in sublist]
            # x_ids = [item for sublist in x_ids for item in sublist]
            # y_ids = [item for sublist in y_ids for item in sublist]
            gts = [item for sublist in gts for item in sublist]
            # import pdb; pdb.set_trace()
            assert (
                len(predictions) == len(y_tokens) == len(x_tokens) == len(gts)
            ), "len(predictions): {0}, len(y_tokens): {1}, len(x_tokens): {2}, len(gts): {3}".format(
                len(predictions), len(y_tokens), len(x_tokens), len(gts)
            )

            # write intermediate results only on rank 0
            if not os.path.isdir(os.path.join(save_folder, "intermediate_results")):
                os.makedirs(os.path.join(save_folder, "intermediate_results"))
            top_1_predictions = [pred[0] for pred in predictions]
            with open(
                os.path.join(
                    save_folder, "intermediate_results", "res_" + label + ".json"
                ),
                "w",
                encoding="utf-8",
            ) as output_file:
                write_json_res(
                    output_file,
                    [encoder_tokenizer, decoder_tokenizer],
                    x_ids,
                    y_ids,
                    x_tokens,
                    y_tokens,
                    predictions,
                    gts,
                )

            try:
                result = self.eval(top_1_predictions, gts)
            except Exception as e:
                logger.exception("ROUGE Eval ERROR")
                result = {}
                score = -float("Inf")
                pass  # this happens when there is no overlap between predictions and ground truths
            else:
                rouge_su4 = rouge(top_1_predictions, gts)  # f, prec, recall
                result = {
                    "ROUGE_1": result["rouge_1_f_score"] * 100.0,
                    "ROUGE_1_Prc": result["rouge_1_precision"] * 100.0,
                    "ROUGE_1_Rcl": result["rouge_1_recall"] * 100.0,
                    "ROUGE_2": result["rouge_2_f_score"] * 100.0,
                    "ROUGE_2_Prc": result["rouge_2_precision"] * 100.0,
                    "ROUGE_2_Rcl": result["rouge_2_recall"] * 100.0,
                    "ROUGE_L": result["rouge_l_f_score"] * 100.0,
                    "ROUGE_L_Prc": result["rouge_l_precision"] * 100.0,
                    "ROUGE_L_Rcl": result["rouge_l_recall"] * 100.0,
                    "ROUGE_SU4": rouge_su4["rouge_su4_f_score"] * 100.0,
                }
                score = result["ROUGE_1"]
                if score > self.best_score:
                    copyfile(
                        os.path.join(
                            save_folder,
                            "intermediate_results",
                            "res_" + label + ".json",
                        ),
                        os.path.join(
                            save_folder,
                            "intermediate_results",
                            "res_" + label + ".best.json",
                        ),
                    )
                    self.best_score = score
                    self.best_res = result
                    got_better_score = True
        else:
            result = {}
            score = -float("Inf")
            got_better_score = False

        return result, score, got_better_score
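
    # Sketch of how a caller might consume eval_batches (hypothetical names, not part of this file):
    #
    #   result, score, improved = evaluator.eval_batches(module, dev_batches, save_folder, label="epoch3")
    #   if improved:
    #       # ROUGE-1 F1 beat self.best_score; res_<label>.best.json was written on rank 0
    #       ...
    #
    # On non-zero MPI ranks result is {} and score is -inf; only rank 0 runs the ROUGE scripts.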

    def eval(self, predictions, groundtruths):
        # predictions, groundtruths = self.filter_empty(predictions, groundtruths)
        predictions = [self.make_html_safe(w) for w in predictions]
        groundtruths = [self.make_html_safe(w) for w in groundtruths]

        pred_dir = os.path.join(self.save_dir, "predictions")
        if os.path.exists(pred_dir):
            shutil.rmtree(pred_dir)
        os.makedirs(pred_dir)

        gt_dir = os.path.join(self.save_dir, "groundtruths")
        if os.path.exists(gt_dir):
            shutil.rmtree(gt_dir)
        os.makedirs(gt_dir)

        special_char_dict = self.print_to_rouge_dir_gt(
            groundtruths, gt_dir, "gt", "SPLIT_CHARS_FOR_EVAL" in self.opt
        )
        self.print_to_rouge_dir(
            predictions,
            pred_dir,
            "pred",
            "SPLIT_CHARS_FOR_EVAL" in self.opt,
            special_char_dict,
        )

        r = pyrouge.Rouge155(self.pyrouge_dir)
        r.system_dir = pred_dir
        r.model_dir = gt_dir
        r.system_filename_pattern = r"(\d+)_pred.txt"
        r.model_filename_pattern = "[A-Z].#ID#_gt.txt"
        results = r.output_to_dict(r.convert_and_evaluate())
        return results
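
# Minimal usage sketch (assumptions: the bundled ROUGE-1.5.5 Perl scripts are runnable, the
# process was launched under MPI, and the names below are illustrative rather than taken from
# this repository's configs):
#
#   opt = {"MAX_GEN_LENGTH": 100, "rank": 0, "device": "cuda"}
#   evaluator = ROUGEEval(run_dir="runs/meeting", save_dir="runs/meeting/eval", opt=opt)
#   scores = evaluator.eval(["the cat sat ."], ["the cat sat on the mat ."])
#   # scores is the dict returned by pyrouge's output_to_dict, e.g. scores["rouge_1_f_score"]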