File size: 1,950 Bytes
d758c99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import json
import torch
from bleu import list_bleu
def is_rank_0():
    """Return True when this process should perform logging/file IO.

    True for a single-process (uninitialized torch.distributed) run,
    or for the process holding global rank 0 in a distributed run.
    """
    dist = torch.distributed
    if not dist.is_initialized():
        # Non-distributed execution: the sole process acts as rank 0.
        return True
    return dist.get_rank() == 0
class TextGenerationScorer:
    """Scores generated token-id sequences against reference labels.

    Computes exact-match accuracy over all examples; on rank 0 it also
    decodes predictions/labels to text, writes one JSON record per
    example to ``output_path``, and computes a corpus BLEU score with
    ``list_bleu``.
    """

    def __init__(self, tokenizer, bos_id, eos_id, output_path):
        # tokenizer: object exposing a HuggingFace-style ``decode`` method.
        # bos_id / eos_id: special ids stripped by ``get_sequence``.
        # output_path: file rank 0 writes JSON-lines records to.
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.output_path = output_path
        self.tokenizer = tokenizer

    def __call__(self, prediction):
        """Score a flattened batch.

        ``prediction`` carries flat id arrays (``predictions``,
        ``label_ids``) plus per-example lengths (``predictions_size``,
        ``label_size``) used to slice them back into sequences.

        Returns a dict with ``bleu``, ``accuracy``, ``correct``, ``total``.
        """
        preds = prediction.predictions
        preds_size = prediction.predictions_size
        label_ids = prediction.label_ids
        label_size = prediction.label_size
        p_start, l_start = 0, 0
        correct, total = 0, 0
        ref = []
        hyp = []
        # Rank cannot change mid-loop, so check once instead of per example.
        rank0 = is_rank_0()
        # Bug fix: the output file was previously opened but never closed
        # (descriptor leak, possibly unflushed data). Close it in `finally`
        # so it is released even if decoding raises.
        fout = open(self.output_path, "w") if rank0 else None
        try:
            for idx, (p_size, l_size) in enumerate(zip(preds_size, label_size)):
                p_end = p_start + p_size
                l_end = l_start + l_size
                pred = self.get_sequence(preds[p_start: p_end])
                label = self.get_sequence(label_ids[l_start: l_end])
                p_start = p_end
                l_start = l_end
                if pred == label:
                    correct += 1
                total += 1
                if rank0:
                    pred_text = self.tokenizer.decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
                    label_text = self.tokenizer.decode(label, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
                    ref.append(label_text)
                    hyp.append(pred_text)
                    fout.write(
                        json.dumps({
                            "idx": idx,
                            "pred": pred_text,
                            "label": label_text}) + "\n")
        finally:
            if fout is not None:
                fout.close()
        # NOTE(review): on non-rank-0 processes ref/hyp stay empty, so this
        # BLEU value is meaningless there — presumably only the rank-0
        # result is consumed; verify against the caller.
        score = list_bleu([ref], hyp)
        return {
            "bleu": score,
            # Guard against ZeroDivisionError on an empty batch.
            "accuracy": correct / total if total else 0.0,
            "correct": correct,
            "total": total
        }

    def get_sequence(self, seq):
        """Return ``seq`` as a list of ints, dropping every ``bos_id``
        token and truncating at the first ``eos_id`` token."""
        processed_seq = []
        for idx in seq:
            if idx == self.bos_id:
                continue
            if idx == self.eos_id:
                break
            processed_seq.append(int(idx))
        return processed_seq
|