File size: 1,918 Bytes
17ff0d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Computes the repetition metric. Adapted from: https://raw.githubusercontent.com/ari-holtzman/degen/master/metrics/repetition.py"""


def _trailing_repetition(tokens, max_n):
    """Find the most-repeated phrase at the end of ``tokens``.

    For every phrase length ``n`` in ``1..max_n`` this counts how many
    back-to-back copies of the final ``n`` tokens close the sequence (the
    final phrase itself counts as one copy).

    Args:
        tokens: list of token ids.
        max_n: longest phrase length to consider.

    Returns:
        A ``(phrase_length, repeat_count)`` tuple for the phrase length with
        the highest count. Ties go to the shortest phrase length. When
        nothing repeats (including empty input) the result is ``(1, 1)``.
    """
    # Work on the reversed sequence so that "ends with" becomes "starts with".
    rev = tokens[::-1]
    best_len, best_count = 1, 1
    # A phrase longer than half the text cannot occur twice, so its count is
    # always 1 and can never beat the current best; skip those lengths.
    for n in range(1, min(max_n, len(rev) // 2) + 1):
        phrase = rev[:n]
        count = 1
        # A short or empty window at the end of the text never equals the
        # full-length phrase, so this loop always terminates.
        while rev[n * count : n * (count + 1)] == phrase:
            count += 1
        # Strict comparison keeps the shortest phrase length on ties.
        if count > best_count:
            best_len, best_count = n, count
    return best_len, best_count


def repetition(tokenized_texts, tokenizer):
    """Compute the fraction of examples that end in a repetition loop.

    Implements the repetition metric from https://arxiv.org/pdf/1904.09751.pdf:
    an example counts as repeating when its most-repeated trailing phrase
    occurs more than once and either is at least 3 tokens long or occurs
    more than 50 times.

    Args:
        tokenized_texts: (List[List[int]]) tokenized generated texts. The
            inputs are not modified.
        tokenizer: object exposing ``bos_token`` and ``encode``; the first id
            of ``encode(bos_token)`` is treated as a separator and a single
            trailing occurrence is ignored in each example.

    Returns:
        ``{"repetition": float}`` -- the share of examples flagged as
        repeating. An empty ``tokenized_texts`` yields ``0.0``.
    """
    sep = tokenizer.encode(tokenizer.bos_token)[0]
    max_n = 90  # longest repeated-phrase length that is searched for
    num_examples = len(tokenized_texts)
    if num_examples == 0:
        # Nothing to score; avoid dividing by zero.
        return {"repetition": 0.0}

    n_repeated_examples = 0
    for tokenized_text in tokenized_texts:
        # Copy so that stripping the separator never mutates the caller's data.
        tokens = list(tokenized_text)
        if tokens and tokens[-1] == sep:
            tokens.pop()
        phrase_length, repeated_times = _trailing_repetition(tokens, max_n)
        if repeated_times > 1 and (phrase_length >= 3 or repeated_times > 50):
            n_repeated_examples += 1

    return {"repetition": n_repeated_examples / num_examples}