Spaces:
Runtime error
Runtime error
File size: 5,369 Bytes
7900c16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import bisect
import random

from tencentpretrain.utils.constants import *
def _random_token_id(vocab_size, excluded_ids):
    """Draw a uniform random id in [1, vocab_size - 1] that is not in excluded_ids."""
    while True:
        candidate = random.randint(1, vocab_size - 1)
        if candidate not in excluded_ids:
            return candidate


def _corrupt_positions(src, positions, prob, mask_id, vocab_size, excluded_ids):
    """Corrupt src in place at the given positions according to one draw `prob`.

    prob < 0.8 replaces every position with mask_id; 0.8 <= prob < 0.9 replaces
    every position with an independently drawn random id; otherwise the tokens
    are left unchanged.
    """
    if prob < 0.8:
        for pos in positions:
            src[pos] = mask_id
    elif prob < 0.9:
        for pos in positions:
            src[pos] = _random_token_id(vocab_size, excluded_ids)


def mask_seq(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Apply BERT-style masked-LM corruption to a (possibly padded) id sequence.

    Trailing PAD ids are stripped (at least one token is always kept).
    Candidate positions are then collected with create_index and visited in
    random order until about 15% of the unpadded sequence (at least one token)
    has been selected. Each selected position becomes the MASK id with
    probability 0.8, or a random id with probability 0.1. The random id is
    drawn from [1, len(vocab) - 1], excluding the CLS/SEP/MASK/PAD ids.
    Otherwise the token is kept.

    Whole-word and span candidates are taken all-or-nothing and are skipped
    when they would overshoot the budget. A span shares a single corruption
    decision, while each sub-token of a whole word gets its own.

    Args:
        src: Sequence of token ids, possibly right-padded with the PAD id.
        tokenizer: Object exposing ``vocab`` (mapping with ``.get``); it is
            also forwarded to create_index.
        whole_word_masking: Mask whole (jieba-segmented) words.
        span_masking: Mask geometric-length spans (ignored when
            whole_word_masking is set).
        span_geo_prob: Geometric parameter for span lengths.
        span_max_length: Maximum span length.

    Returns:
        ``(src, tgt_mlm)``. ``src`` is the corrupted sequence, re-padded with
        PAD ids up to the input length. ``tgt_mlm`` is a list of
        ``(position, original_id)`` pairs sorted by position. An empty input
        yields ``([], [])``.
    """
    if len(src) == 0:
        # Nothing to mask (previously this raised UnboundLocalError).
        return [], []

    vocab = tokenizer.vocab
    PAD_ID = vocab.get(PAD_TOKEN)
    mask_id = vocab.get(MASK_TOKEN)
    vocab_size = len(vocab)
    # Ids that must never be chosen as a random replacement token.
    excluded_ids = (vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), mask_id, PAD_ID)

    # Drop trailing padding, but keep at least the first token.
    end = len(src)
    while end > 1 and src[end - 1] == PAD_ID:
        end -= 1
    src_no_pad = src[:end]

    tokens_index, src_no_pad = create_index(src_no_pad, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length)
    # create_index may shorten the sequence (whole-word masking); re-pad it.
    pad_len = len(src) - len(src_no_pad)
    src = src_no_pad + pad_len * [PAD_ID] if pad_len > 0 else src_no_pad

    random.shuffle(tokens_index)
    num_to_predict = max(1, int(round(len(src_no_pad) * 0.15)))
    tgt_mlm = []
    for index_set in tokens_index:
        if len(tgt_mlm) >= num_to_predict:
            break
        start = index_set[0]
        if whole_word_masking or span_masking:
            group_len = index_set[1]
            # Groups are all-or-nothing: skip any that would exceed the budget.
            if len(tgt_mlm) + group_len > num_to_predict:
                continue

        if whole_word_masking:
            # Every sub-token of the word gets its own corruption draw.
            for pos in range(start, start + group_len):
                tgt_mlm.append((pos, src[pos]))
                prob = random.random()
                _corrupt_positions(src, (pos,), prob, mask_id, vocab_size, excluded_ids)
        elif span_masking:
            positions = range(start, start + group_len)
            # Record the original ids before any of them are overwritten.
            tgt_mlm.extend((pos, src[pos]) for pos in positions)
            # A single corruption decision is shared by the whole span.
            prob = random.random()
            _corrupt_positions(src, positions, prob, mask_id, vocab_size, excluded_ids)
        else:
            tgt_mlm.append((start, src[start]))
            prob = random.random()
            _corrupt_positions(src, (start,), prob, mask_id, vocab_size, excluded_ids)

    tgt_mlm.sort(key=lambda pair: pair[0])
    return src, tgt_mlm
def create_index(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Collect the candidate positions that mask_seq may corrupt.

    Args:
        src: Token ids with trailing padding already removed (as passed by
            mask_seq). Presumably a list of ints, since it is sliced and
            concatenated with lists.
        tokenizer: Object exposing ``vocab`` (mapping with ``.get``),
            ``convert_ids_to_tokens``, ``convert_tokens_to_ids`` and
            ``tokenize``.
        whole_word_masking: If true, re-segment the detokenised text with
            jieba and emit one ``[start, length]`` entry per word.
        span_masking: If true (and not whole-word), emit ``[start, length]``
            spans whose lengths come from get_span_len.
        span_geo_prob: Geometric parameter forwarded to get_span_len.
        span_max_length: Maximum span length forwarded to get_span_len.

    Returns:
        ``(tokens_index, src)``.

        ``tokens_index`` holds ``[position]`` entries for plain token masking,
        or ``[start, length]`` entries for whole-word or span masking.
        CLS/SEP/PAD positions are never used as start positions.

        ``src`` is returned unchanged, except under whole-word masking. There
        it is rebuilt from the re-tokenised words (``[UNK]`` pieces and ``##``
        markers are dropped before re-tokenising) and truncated to the input
        length.
    """
    vocab = tokenizer.vocab
    pad_id = vocab.get(PAD_TOKEN)
    cls_id = vocab.get(CLS_TOKEN)
    sep_id = vocab.get(SEP_TOKEN)
    tokens_index = []

    if whole_word_masking:
        src_length = len(src)
        # Guard the indexing so empty / [CLS]-only inputs don't raise IndexError.
        has_cls = len(src) > 0 and src[0] == cls_id
        if has_cls:
            src = src[1:]
        has_sep = len(src) > 0 and src[-1] == sep_id
        if has_sep:
            src = src[:-1]
        sentence = "".join(tokenizer.convert_ids_to_tokens(src)).replace('[UNK]', '').replace('##', '')
        import jieba  # optional dependency, only needed for whole-word masking
        src_wwm = [cls_id] if has_cls else []
        for word in jieba.cut(sentence):
            position = len(src_wwm)
            src_wwm += tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
            # Only index words that end strictly before the original length,
            # so they survive the truncation below.
            if len(src_wwm) < src_length:
                tokens_index.append([position, len(src_wwm) - position])
        if has_sep:
            src_wwm.append(sep_id)
        # Re-tokenisation can change the length; never grow past the input.
        return tokens_index, src_wwm[:src_length]

    span_end_position = -1
    for i, token in enumerate(src):
        if token in (cls_id, sep_id, pad_id):
            continue
        if not span_masking:
            tokens_index.append([i])
            continue
        # Positions already covered by the previous span cannot start a new one.
        if i < span_end_position:
            continue
        span_len = get_span_len(span_max_length, span_geo_prob)
        span_end_position = i + span_len
        # NOTE(review): spans are clipped only at the sequence end, so they can
        # cover special tokens (e.g. a trailing SEP) that follow the start.
        if span_end_position > len(src):
            span_len = len(src) - i
        tokens_index.append([i, span_len])
    return tokens_index, src
def get_span_len(max_span_len, p):
    """Sample a span length from a geometric distribution truncated at max_span_len.

    Length ``k`` in ``[1, max_span_len]`` is chosen with probability
    proportional to ``p * (1 - p) ** (k - 1)``. The weights are renormalised
    over the truncated support.

    Args:
        max_span_len: Largest length that may be returned; must be >= 1.
        p: Geometric success probability; must be in ``(0, 1]``.

    Returns:
        int: the sampled span length.

    Raises:
        ValueError: if ``max_span_len < 1`` or ``p`` is not in ``(0, 1]``
            (these inputs used to surface as an UnboundLocalError).
    """
    if max_span_len < 1:
        raise ValueError("max_span_len must be >= 1, got {!r}".format(max_span_len))
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1], got {!r}".format(p))

    # Cumulative, unnormalised weights. They are built with the same sequence
    # of float operations as before, so a given random state yields the same
    # length.
    cum_weights = []
    weight = p
    total = 0.0
    for _ in range(max_span_len):
        total += weight
        cum_weights.append(total)
        weight *= (1 - p)

    draw = total * random.random()
    # bisect_right finds k with cum[k-1] <= draw < cum[k]. The min() guards
    # against a draw landing exactly on the upper bound.
    index = min(bisect.bisect_right(cum_weights, draw), max_span_len - 1)
    return index + 1
|