import random
from tencentpretrain.utils.constants import *


def mask_seq(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Dynamically mask a sequence of token ids for MLM training.

    Returns the corrupted sequence together with tgt_mlm, a list of
    (position, original_token_id) pairs sorted by position.
    """
    vocab = tokenizer.vocab
    PAD_ID = vocab.get(PAD_TOKEN)
    # Strip trailing padding so that only real tokens are considered.
    for i in range(len(src) - 1, -1, -1):
        if src[i] != PAD_ID:
            break
    src_no_pad = src[:i + 1]

    tokens_index, src_no_pad = create_index(src_no_pad, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length)
    # Restore the original length with padding if necessary.
    if len(src_no_pad) < len(src):
        src = src_no_pad + (len(src) - len(src_no_pad)) * [PAD_ID]
    else:
        src = src_no_pad

    random.shuffle(tokens_index)
    # Mask roughly 15% of the non-padding tokens, as in BERT.
    num_to_predict = max(1, int(round(len(src_no_pad) * 0.15)))
    tgt_mlm = []
    for index_set in tokens_index:
        if len(tgt_mlm) >= num_to_predict:
            break
        if whole_word_masking:
            # index_set is [start_position, word_length]; the whole word is
            # masked, with the 80/10/10 mask/random/keep choice drawn per token.
            i = index_set[0]
            mask_len = index_set[1]
            if len(tgt_mlm) + mask_len > num_to_predict:
                continue
            for j in range(mask_len):
                token = src[i + j]
                tgt_mlm.append((i + j, token))
                prob = random.random()
                if prob < 0.8:
                    src[i + j] = vocab.get(MASK_TOKEN)
                elif prob < 0.9:
                    # Replace with a random non-special token.
                    while True:
                        rdi = random.randint(1, len(vocab) - 1)
                        if rdi not in [vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), vocab.get(MASK_TOKEN), PAD_ID]:
                            break
                    src[i + j] = rdi
        elif span_masking:
            # index_set is [start_position, span_length]; the 80/10/10 choice
            # is drawn once and applied to the entire span.
            i = index_set[0]
            span_len = index_set[1]
            if len(tgt_mlm) + span_len > num_to_predict:
                continue
            for j in range(span_len):
                token = src[i + j]
                tgt_mlm.append((i + j, token))
            prob = random.random()
            if prob < 0.8:
                for j in range(span_len):
                    src[i + j] = vocab.get(MASK_TOKEN)
            elif prob < 0.9:
                for j in range(span_len):
                    while True:
                        rdi = random.randint(1, len(vocab) - 1)
                        if rdi not in [vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), vocab.get(MASK_TOKEN), PAD_ID]:
                            break
                    src[i + j] = rdi
        else:
            # Plain token-level masking: index_set is [position].
            i = index_set[0]
            token = src[i]
            tgt_mlm.append((i, token))
            prob = random.random()
            if prob < 0.8:
                src[i] = vocab.get(MASK_TOKEN)
            elif prob < 0.9:
                while True:
                    rdi = random.randint(1, len(vocab) - 1)
                    if rdi not in [vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), vocab.get(MASK_TOKEN), PAD_ID]:
                        break
                src[i] = rdi
    tgt_mlm = sorted(tgt_mlm, key=lambda x: x[0])
    return src, tgt_mlm
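

# A minimal usage sketch (illustrative, not part of the original module): on the
# plain token-level path (whole_word_masking=False, span_masking=False) the
# tokenizer only needs a ``vocab`` mapping, so a mock with a small, dense,
# invented vocabulary suffices.
def _demo_mask_seq():
    class _MockTokenizer:
        vocab = {PAD_TOKEN: 0, CLS_TOKEN: 1, SEP_TOKEN: 2, MASK_TOKEN: 3,
                 "hello": 4, "world": 5}

    src = [1, 4, 5, 2, 0, 0]  # [CLS] hello world [SEP] [PAD] [PAD]
    masked, tgt_mlm = mask_seq(src, _MockTokenizer(), False, False, 0.2, 10)
    print(masked)   # e.g. [1, 3, 5, 2, 0, 0]
    print(tgt_mlm)  # e.g. [(1, 4)] -- (position, original id) pairs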


def create_index(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Build the list of maskable units consumed by mask_seq.

    Each element of tokens_index is [position] for token-level masking,
    or [start_position, length] for whole-word and span masking.
    """
    tokens_index = []
    span_end_position = -1
    vocab = tokenizer.vocab
    PAD_ID = vocab.get(PAD_TOKEN)
    if whole_word_masking:
        src_wwm = []
        src_length = len(src)
        has_cls, has_sep = False, False
        if src[0] == vocab.get(CLS_TOKEN):
            src = src[1:]
            has_cls = True
        if src[-1] == vocab.get(SEP_TOKEN):
            src = src[:-1]
            has_sep = True
        # Detokenize (dropping [UNK] placeholders and WordPiece "##" prefixes),
        # then re-segment with jieba to recover word boundaries.
        sentence = "".join(tokenizer.convert_ids_to_tokens(src)).replace('[UNK]', '').replace('##', '')
        import jieba  # imported lazily so jieba is only required for whole-word masking
        wordlist = jieba.cut(sentence)
        if has_cls:
            src_wwm += [vocab.get(CLS_TOKEN)]
        for word in wordlist:
            position = len(src_wwm)
            src_wwm += tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
            # Only index words that still fit within the original length.
            if len(src_wwm) < src_length:
                tokens_index.append([position, len(src_wwm) - position])
        if has_sep:
            src_wwm += [vocab.get(SEP_TOKEN)]
        # Re-tokenization may change the length; truncate back if needed.
        if len(src_wwm) > src_length:
            src = src_wwm[:src_length]
        else:
            src = src_wwm
    else:
        for i, token in enumerate(src):
            # Special tokens are never masked.
            if token == vocab.get(CLS_TOKEN) or token == vocab.get(SEP_TOKEN) or token == PAD_ID:
                continue
            if not span_masking:
                tokens_index.append([i])
            else:
                if i < span_end_position:
                    continue
                span_len = get_span_len(span_max_length, span_geo_prob)
                span_end_position = i + span_len
                # Clip spans that would run past the end of the sequence.
                if span_end_position > len(src):
                    span_len = len(src) - i
                tokens_index.append([i, span_len])
    return tokens_index, src
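

# Illustrative sketch (not part of the original module): with both masking modes
# off, ``create_index`` returns one [position] entry per maskable token, skipping
# special tokens. The ids reuse the dense mock vocabulary from _demo_mask_seq.
def _demo_create_index():
    class _MockTokenizer:
        vocab = {PAD_TOKEN: 0, CLS_TOKEN: 1, SEP_TOKEN: 2, MASK_TOKEN: 3,
                 "hello": 4, "world": 5}

    index, _ = create_index([1, 4, 5, 2], _MockTokenizer(), False, False, 0.2, 10)
    print(index)  # [[1], [2]] -- only "hello" and "world" are maskable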


def get_span_len(max_span_len, p):
    """Sample a span length from a geometric distribution truncated at
    max_span_len, i.e. P(len = k) is proportional to p * (1 - p) ** (k - 1)."""
    # Build the cumulative (unnormalized) probabilities for lengths 1..max_span_len.
    geo_prob_cum = [0.0]
    geo_prob = 1.0
    for i in range(1, max_span_len + 1):
        if i == 1:
            geo_prob *= p
        else:
            geo_prob *= (1 - p)
        geo_prob_cum.append(geo_prob_cum[-1] + geo_prob)

    # Inverse-CDF sampling over the truncated distribution.
    prob = geo_prob_cum[-1] * random.random()
    for i in range(len(geo_prob_cum) - 1):
        if geo_prob_cum[i] <= prob < geo_prob_cum[i + 1]:
            current_span_len = i + 1
            break
    return current_span_len
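

# Quick empirical check (not part of the original module): sampled span lengths
# should follow the truncated geometric distribution, with length 1 the most
# frequent and each further length about (1 - p) times as likely as the previous.
if __name__ == "__main__":
    from collections import Counter

    counts = Counter(get_span_len(10, 0.2) for _ in range(10000))
    print(sorted(counts.items()))  # counts decay by roughly 0.8x per extra length
    _demo_mask_seq()
    _demo_create_index()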