import json
import torch

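def is_rank_0():
  """Return True on the main process (rank 0), or when not running distributed."""
  # is_available() guards builds of torch compiled without distributed support.
  if torch.distributed.is_available() and torch.distributed.is_initialized():
    return torch.distributed.get_rank() == 0
  return True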

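class MatchSequenceScorer:
  """Exact-match accuracy over generated token sequences.

  Each prediction and label is truncated at the first `eos_id`. An example
  counts as correct only if the two truncated sequences are identical.
  Per-example results are written as JSON lines to `output_path` (rank 0 only).
  """

  def __init__(self, bos_id, eos_id, output_path):
    # Accept either a single BOS id or a list of them.
    if isinstance(bos_id, list):
      self.bos_ids = bos_id
    else:
      self.bos_ids = [bos_id]
    self.eos_id = eos_id
    self.output_path = output_path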

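  def __call__(self, prediction):
    # `predictions` and `label_ids` hold every example's tokens concatenated
    # into one flat sequence. `predictions_size` and `label_size` give the
    # length of each example's slice, so running offsets are tracked below.
    preds = prediction.predictions
    preds_size = prediction.predictions_size
    label_ids = prediction.label_ids
    label_size = prediction.label_size
    p_start, l_start = 0, 0
    correct, total = 0, 0
    # Only rank 0 writes the per-example dump, so that multiple processes
    # do not write the same file during distributed evaluation.
    fout = open(self.output_path, "w") if is_rank_0() else None
    try:
      for idx, (p_size, l_size) in enumerate(zip(preds_size, label_size)):
        p_end = p_start + p_size
        l_end = l_start + l_size
        pred = self.get_sequence(preds[p_start: p_end])
        label = self.get_sequence(label_ids[l_start: l_end])
        p_start = p_end
        l_start = l_end
        # Exact match: every token up to EOS must be identical.
        if pred == label:
          correct += 1
        total += 1
        if fout is not None:
          fout.write(
            json.dumps({
              "idx": idx,
              "pred": pred,
              "label": label}) + "\n")
    finally:
      # Close the dump file even if scoring raises.
      if fout is not None:
        fout.close()
    return {
      # Report 0.0 instead of dividing by zero when there are no examples.
      "accuracy": correct / total if total else 0.0,
      "correct": correct,
      "total": total
    }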


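  def get_sequence(self, seq):
    """Convert a slice of token ids to a list of ints, stopping before the first EOS."""
    processed_seq = []
    for idx in seq:
      # BOS stripping is disabled, so BOS tokens are kept and compared like
      # any other token (self.bos_ids is only referenced here).
      # if idx in self.bos_ids:
      #   continue
      if idx == self.eos_id:
        break
      processed_seq.append(int(idx))
    return processed_seq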
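

if __name__ == "__main__":
  # Minimal smoke test: a sketch, not part of the original module. It assumes
  # the caller passes an object with `predictions`, `predictions_size`,
  # `label_ids` and `label_size` attributes, where the token arrays are flat
  # concatenations of all examples. `SimpleNamespace` stands in for that
  # object here, and the token ids (BOS=1, EOS=2, pad=0) are hypothetical.
  import os
  import tempfile
  from types import SimpleNamespace

  prediction = SimpleNamespace(
    # Example 0 is [5, 6, EOS, pad] and example 1 is [7, 8, EOS].
    predictions=[5, 6, 2, 0, 7, 8, 2],
    predictions_size=[4, 3],
    # Example 0 matches after EOS truncation. Example 1 differs ([7, 9]).
    label_ids=[5, 6, 2, 7, 9, 2],
    label_size=[3, 3],
  )
  with tempfile.TemporaryDirectory() as tmp_dir:
    scorer = MatchSequenceScorer(bos_id=1, eos_id=2,
                                 output_path=os.path.join(tmp_dir, "preds.jsonl"))
    metrics = scorer(prediction)
  print(metrics)  # {'accuracy': 0.5, 'correct': 1, 'total': 2}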