import json

import torch
from bleu import list_bleu


def is_rank_0():
    """Return True when this process should do rank-0-only work.

    True when torch.distributed is not initialized (single-process run),
    or when it is initialized and this process has global rank 0.
    """
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    return True


class TextGenerationScorer:
    """Scores generated token sequences against reference labels.

    Computes corpus BLEU (via ``list_bleu``) and exact-match accuracy, and
    on rank 0 also writes one JSON record per example to ``output_path``.
    """

    def __init__(self, tokenizer, bos_id, eos_id, output_path):
        # bos_id tokens are dropped from sequences; eos_id truncates them.
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.output_path = output_path
        self.tokenizer = tokenizer

    def __call__(self, prediction):
        """Score one prediction bundle.

        Args:
            prediction: object exposing flat token arrays ``predictions`` /
                ``label_ids`` plus per-example length arrays
                ``predictions_size`` / ``label_size`` used to slice them.

        Returns:
            dict with keys ``"bleu"``, ``"accuracy"``, ``"correct"``,
            ``"total"``.
        """
        preds = prediction.predictions
        preds_size = prediction.predictions_size
        label_ids = prediction.label_ids
        label_size = prediction.label_size

        # Hoisted out of the loop: rank cannot change mid-call, and the
        # original re-queried torch.distributed once per example.
        rank0 = is_rank_0()

        correct, total = 0, 0
        ref, hyp = [], []

        # BUG FIX: the output file was never closed; try/finally guarantees
        # it is closed even if decoding or writing raises.
        fout = open(self.output_path, "w") if rank0 else None
        try:
            p_start, l_start = 0, 0
            for idx, (p_size, l_size) in enumerate(zip(preds_size, label_size)):
                p_end = p_start + p_size
                l_end = l_start + l_size
                pred = self.get_sequence(preds[p_start:p_end])
                label = self.get_sequence(label_ids[l_start:l_end])
                p_start = p_end
                l_start = l_end

                if pred == label:
                    correct += 1
                total += 1

                if rank0:
                    pred_text = self.tokenizer.decode(
                        pred,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True).strip()
                    label_text = self.tokenizer.decode(
                        label,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True).strip()
                    ref.append(label_text)
                    hyp.append(pred_text)
                    fout.write(
                        json.dumps({
                            "idx": idx,
                            "pred": pred_text,
                            "label": label_text}) + "\n")
        finally:
            if fout is not None:
                fout.close()

        # NOTE(review): on non-rank-0 processes ref/hyp stay empty, so
        # list_bleu is called on empty inputs — presumably only the rank-0
        # result is consumed; verify against the caller.
        score = list_bleu([ref], hyp)
        return {
            "bleu": score,
            # BUG FIX: guard against ZeroDivisionError on an empty
            # prediction set (original divided unconditionally).
            "accuracy": correct / total if total else 0.0,
            "correct": correct,
            "total": total,
        }

    def get_sequence(self, seq):
        """Drop BOS tokens and truncate at the first EOS token.

        Args:
            seq: iterable of token ids (may be a tensor/ndarray slice).

        Returns:
            list of plain ``int`` token ids, with every ``bos_id`` skipped
            (wherever it appears) and everything from the first ``eos_id``
            onward discarded.
        """
        processed_seq = []
        for idx in seq:
            if idx == self.bos_id:
                continue  # skip BOS wherever it occurs, not only at start
            if idx == self.eos_id:
                break  # stop at EOS; the tail is padding
            processed_seq.append(int(idx))
        return processed_seq