File size: 1,950 Bytes
d758c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import torch

from bleu import list_bleu

def is_rank_0():
  """Return True if this process should do rank-0-only work.

  Returns True when torch.distributed is unavailable or has not been
  initialized (i.e. a single-process run); otherwise True only for global
  rank 0 and False for every other rank.
  """
  dist = torch.distributed
  # On builds without distributed support torch.distributed exposes only
  # is_available(); checking it first avoids an AttributeError on
  # is_initialized().
  if not dist.is_available() or not dist.is_initialized():
    return True
  return dist.get_rank() == 0

class TextGenerationScorer:
  """Compute exact-match accuracy and BLEU for generated token sequences.

  An instance is used as a metrics callable. It receives a prediction object
  holding flattened predicted and label token ids plus per-example lengths.
  It splits them into examples, trims each one with ``get_sequence`` and
  compares the results. On rank 0 (as reported by ``is_rank_0``) every
  decoded pair is also written to ``output_path`` as one JSON object per line.
  """

  def __init__(self, tokenizer, bos_id, eos_id, output_path):
    """Store the decoding configuration.

    Args:
      tokenizer: object with a ``decode(ids, skip_special_tokens=...,
        clean_up_tokenization_spaces=...)`` method returning a string.
      bos_id: token id dropped wherever it occurs before the first ``eos_id``.
      eos_id: token id at which a sequence is cut (it and all later ids are
        discarded).
      output_path: file that ``__call__`` overwrites on rank 0.
    """
    self.bos_id = bos_id
    self.eos_id = eos_id
    self.output_path = output_path
    self.tokenizer = tokenizer

  def __call__(self, prediction):
    """Score flattened predictions against their labels.

    Args:
      prediction: object exposing ``predictions`` and ``label_ids`` (flat,
        sliceable sequences of token ids) and ``predictions_size`` /
        ``label_size`` (per-example lengths used to split them). Examples are
        paired with ``zip``, so extra entries in the longer size list are
        ignored.

    Returns:
      dict with ``"bleu"`` (the result of ``list_bleu`` over the decoded
      texts), ``"accuracy"`` (fraction of examples whose trimmed id sequences
      match exactly; 0.0 when there are no examples), ``"correct"`` and
      ``"total"``.
    """
    write_output = is_rank_0()  # constant for the whole call; query it once
    correct, total = 0, 0
    ref, hyp = [], []
    fout = open(self.output_path, "w", encoding="utf-8") if write_output else None
    try:
      examples = zip(
        self._split_flat(prediction.predictions, prediction.predictions_size),
        self._split_flat(prediction.label_ids, prediction.label_size))
      for idx, (pred_ids, label_ids) in enumerate(examples):
        pred = self.get_sequence(pred_ids)
        label = self.get_sequence(label_ids)
        if pred == label:
          correct += 1
        total += 1
        if fout is not None:
          pred_text = self._decode(pred)
          label_text = self._decode(label)
          ref.append(label_text)
          hyp.append(pred_text)
          fout.write(
            json.dumps({
              "idx": idx,
              "pred": pred_text,
              "label": label_text}) + "\n")
    finally:
      # Close (and flush) the output file even if decoding or writing fails.
      if fout is not None:
        fout.close()
    # NOTE(review): ref/hyp are only filled on rank 0, so other ranks pass
    # empty lists to list_bleu here -- confirm callers only use rank 0's "bleu".
    score = list_bleu([ref], hyp)
    return {
      "bleu": score,
      # Guard the empty-eval-set case instead of raising ZeroDivisionError.
      "accuracy": correct / total if total else 0.0,
      "correct": correct,
      "total": total
    }

  @staticmethod
  def _split_flat(flat, sizes):
    """Yield consecutive slices of ``flat`` whose lengths are given by ``sizes``."""
    start = 0
    for size in sizes:
      end = start + size
      yield flat[start:end]
      start = end

  def _decode(self, ids):
    """Decode token ids to text with special tokens removed, then strip whitespace."""
    return self.tokenizer.decode(
      ids, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()

  def get_sequence(self, seq):
    """Trim a raw id sequence for comparison and decoding.

    Skips every ``bos_id`` and stops at the first ``eos_id`` (which is not
    included). Each kept id is converted with ``int``.

    Returns:
      list of int token ids.
    """
    processed_seq = []
    for idx in seq:
      if idx == self.bos_id:
        continue
      if idx == self.eos_id:
        break
      processed_seq.append(int(idx))
    return processed_seq