Spaces:
Runtime error
Runtime error
import random | |
from tencentpretrain.utils.constants import * | |
def _sample_replacement_id(vocab_size, excluded_ids):
    """Draw random ids in ``[1, vocab_size - 1]`` until one is not in ``excluded_ids``."""
    while True:
        candidate = random.randint(1, vocab_size - 1)
        if candidate not in excluded_ids:
            return candidate


def _corrupt_positions(src, positions, prob, mask_id, vocab_size, excluded_ids):
    """Apply a single 80/10/10 corruption decision (``prob``) to every index in ``positions``."""
    if prob < 0.8:
        for pos in positions:
            src[pos] = mask_id
    elif prob < 0.9:
        for pos in positions:
            src[pos] = _sample_replacement_id(vocab_size, excluded_ids)
    # prob >= 0.9: the tokens are left as they are.


def mask_seq(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Corrupt a token-id sequence for masked-language-model training.

    Trailing padding is stripped, candidate positions are obtained from
    ``create_index``, and the (possibly re-tokenized) sequence is padded back
    to the input length. Candidates are shuffled and consumed until
    ``max(1, round(0.15 * unpadded_length))`` targets have been collected.
    Each selected token is replaced by the mask id with probability 0.8,
    by a random non-special id with probability 0.1, and kept otherwise.
    With whole-word masking the decision is drawn per token; with span
    masking a single decision is shared by the whole span.

    Args:
        src: List of token ids, optionally right-padded with the PAD id.
        tokenizer: Object exposing ``vocab`` (token -> id mapping); forwarded
            to ``create_index``.
        whole_word_masking: Select whole words as masking units. Takes
            precedence over ``span_masking``.
        span_masking: Select contiguous spans as masking units.
        span_geo_prob: Forwarded to ``create_index`` for span length sampling.
        span_max_length: Forwarded to ``create_index`` for span length sampling.

    Returns:
        A ``(src, tgt_mlm)`` tuple: the corrupted sequence and a list of
        ``(position, original_token_id)`` pairs sorted by position.
    """
    # Nothing to mask; the padding scan below needs at least one element.
    if len(src) == 0:
        return src[:0], []

    vocab = tokenizer.vocab
    pad_id = vocab.get(PAD_TOKEN)
    mask_id = vocab.get(MASK_TOKEN)
    vocab_size = len(vocab)
    # Ids that must never be used as a random replacement.
    excluded_ids = frozenset([vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), mask_id, pad_id])

    # Strip trailing padding; the first token is always kept.
    last = len(src) - 1
    while last > 0 and src[last] == pad_id:
        last -= 1
    src_no_pad = src[:last + 1]

    tokens_index, src_no_pad = create_index(
        src_no_pad, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length
    )
    # create_index may return a shorter sequence; restore the original length.
    missing = len(src) - len(src_no_pad)
    if missing > 0:
        src = src_no_pad + missing * [pad_id]
    else:
        src = src_no_pad

    random.shuffle(tokens_index)
    num_to_predict = max(1, int(round(len(src_no_pad) * 0.15)))
    tgt_mlm = []
    for index_set in tokens_index:
        if len(tgt_mlm) >= num_to_predict:
            break
        start = index_set[0]
        if whole_word_masking or span_masking:
            length = index_set[1]
            # Skip units that would overshoot the prediction budget.
            if len(tgt_mlm) + length > num_to_predict:
                continue
        else:
            length = 1
        positions = range(start, start + length)
        # Record the original ids before any of them is overwritten.
        tgt_mlm.extend((pos, src[pos]) for pos in positions)

        if whole_word_masking:
            # Independent corruption decision for every token of the word.
            for pos in positions:
                _corrupt_positions(src, (pos,), random.random(), mask_id, vocab_size, excluded_ids)
        else:
            # One decision for the whole span (or for the single token).
            _corrupt_positions(src, positions, random.random(), mask_id, vocab_size, excluded_ids)

    tgt_mlm.sort(key=lambda item: item[0])
    return src, tgt_mlm
def create_index(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Enumerate the units of ``src`` that are eligible for masking.

    Three modes, chosen by the flags:

    * whole-word masking: the ids are converted back to text, segmented into
      words with ``jieba.cut`` and re-tokenized word by word. Each word that
      ends before the original length yields ``[start, length]``. The
      re-tokenized sequence (truncated to the original length) is returned
      in place of ``src``.
    * span masking: non-overlapping spans whose lengths come from
      ``get_span_len`` yield ``[start, length]``. A span is clipped at the
      end of the sequence and at the first CLS/SEP/PAD id it would cover, so
      special tokens are never part of a span.
    * otherwise: every non-special position yields ``[index]``.

    Args:
        src: List of token ids (expected without trailing padding).
        tokenizer: Object exposing ``vocab``, ``convert_ids_to_tokens``,
            ``convert_tokens_to_ids`` and ``tokenize``.
        whole_word_masking: Enable whole-word mode (takes precedence).
        span_masking: Enable span mode.
        span_geo_prob: Second argument passed to ``get_span_len``.
        span_max_length: First argument passed to ``get_span_len``.

    Returns:
        A ``(tokens_index, src)`` tuple: the list of candidate units and the
        sequence they index into.
    """
    tokens_index = []
    vocab = tokenizer.vocab
    cls_id = vocab.get(CLS_TOKEN)
    sep_id = vocab.get(SEP_TOKEN)
    pad_id = vocab.get(PAD_TOKEN)

    if whole_word_masking:
        src_length = len(src)
        # Peel off the boundary tokens (guarding against empty input) so they
        # are not fed through the word segmenter.
        has_cls = src_length > 0 and src[0] == cls_id
        if has_cls:
            src = src[1:]
        has_sep = len(src) > 0 and src[-1] == sep_id
        if has_sep:
            src = src[:-1]
        sentence = "".join(tokenizer.convert_ids_to_tokens(src)).replace('[UNK]', '').replace('##', '')
        # Imported here so jieba is only loaded when this branch runs.
        import jieba

        src_wwm = [cls_id] if has_cls else []
        for word in jieba.cut(sentence):
            position = len(src_wwm)
            src_wwm += tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
            # Only words that end before the original length are candidates.
            if len(src_wwm) < src_length:
                tokens_index.append([position, len(src_wwm) - position])
        if has_sep:
            src_wwm.append(sep_id)
        # Re-tokenization may lengthen the sequence; never exceed the input length.
        return tokens_index, src_wwm[:src_length]

    special_ids = (cls_id, sep_id, pad_id)
    span_end_position = -1
    for i, token in enumerate(src):
        if token in special_ids:
            continue
        if not span_masking:
            tokens_index.append([i])
            continue
        # Still inside the previously sampled span.
        if i < span_end_position:
            continue
        span_len = get_span_len(span_max_length, span_geo_prob)
        span_end_position = min(i + span_len, len(src))
        # Stop the span before the first special token it would cover.
        for j in range(i + 1, span_end_position):
            if src[j] in special_ids:
                span_end_position = j
                break
        tokens_index.append([i, span_end_position - i])
    return tokens_index, src
def get_span_len(max_span_len, p):
    """Sample a span length in ``[1, max_span_len]``.

    Length ``k`` is weighted by ``p * (1 - p) ** (k - 1)``; the weights are
    normalized by their sum, so the result is always within the range.
    Exactly one call to ``random.random()`` is made.

    Args:
        max_span_len: Largest length that may be returned; must be >= 1.
        p: Parameter of the geometric weights.

    Returns:
        The sampled span length as an int.

    Raises:
        ValueError: If ``max_span_len`` is smaller than 1, or if the weights
            leave no length that can be selected (for example ``p == 0``).
    """
    if max_span_len < 1:
        raise ValueError("max_span_len must be >= 1, got {}".format(max_span_len))

    # cumulative[k] is the total weight of lengths 1..k.
    cumulative = [0.0]
    weight = p
    for _ in range(max_span_len):
        cumulative.append(cumulative[-1] + weight)
        weight *= (1 - p)

    threshold = cumulative[-1] * random.random()
    span_len = None
    for length in range(1, max_span_len + 1):
        if cumulative[length - 1] <= threshold < cumulative[length]:
            span_len = length
    if span_len is None:
        raise ValueError(
            "cannot sample a span length with max_span_len={} and p={}".format(max_span_len, p)
        )
    return span_len