"""Computes the repetition metric. Adapted from: https://raw.githubusercontent.com/ari-holtzman/degen/master/metrics/repetition.py"""
def repetition(tokenized_texts, tokenizer):
    """
    Computes the repetition metric (https://arxiv.org/pdf/1904.09751.pdf): for each
    example, detects whether the generation ends by repeating itself, and if so,
    which phrase is repeated and how many times.

    Args:
        tokenized_texts: (List[List[int]]) tokenized generated texts.
        tokenizer: the tokenizer used to produce `tokenized_texts`; its BOS token is
            treated as the separator and stripped from the end of each example.
    """
    SEP = tokenizer.encode(tokenizer.bos_token)[0]
    repetition_stats = []
    max_n = 90
    num_examples = len(tokenized_texts)
    n_repeated_examples = 0
    for tokenized_text in tokenized_texts:
        if tokenized_text and tokenized_text[-1] == SEP:
            tokenized_text.pop(-1)
        rev_gen = list(reversed(tokenized_text))
        last_n_repeats = [0] * max_n
        # For each phrase length n, count how many times the final n tokens repeat
        # back-to-back at the end of the generation.
        for n in range(1, max_n + 1):
            n_repeat = 1
            while (
                len(rev_gen[n * n_repeat : n * (n_repeat + 1)]) == n
                and rev_gen[n * n_repeat : n * (n_repeat + 1)] == rev_gen[:n]
            ):
                n_repeat += 1
            last_n_repeats[n - 1] = n_repeat
        max_repeated_n = max(range(max_n), key=lambda x: last_n_repeats[x])
        # Flag the example as repeating if a phrase of length >= 3 repeats at least
        # twice, or a shorter phrase repeats more than 50 times.
        if last_n_repeats[max_repeated_n] > 1 and (
            max_repeated_n + 1 >= 3 or last_n_repeats[max_repeated_n] > 50
        ):
            repetition_stats.append(
                {
                    "repeated_phrase": list(reversed(rev_gen[: max_repeated_n + 1])),
                    "repeated_times": last_n_repeats[max_repeated_n],
                    "repeated_phrase_length": max_repeated_n + 1,
                }
            )
            n_repeated_examples += 1
        else:
            repetition_stats.append({})
    return {
        "repetition": n_repeated_examples * 1.0 / num_examples
    }  # , "repetition_stats": repetition_stats}