import random

from tencentpretrain.utils.constants import *


def mask_seq(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """BERT-style masked language modeling: select ~15% of the non-padding tokens.

    A selected position is replaced by [MASK] with probability 0.8, by a random
    non-special token with probability 0.1, or kept unchanged otherwise (for span
    masking one decision covers the whole span). Returns the masked sequence and
    the (position, original token id) prediction targets.
    """
    vocab = tokenizer.vocab
    PAD_ID = vocab.get(PAD_TOKEN)
    # Find the last non-padding token so padding is never selected for masking.
    for i in range(len(src) - 1, -1, -1):
        if src[i] != PAD_ID:
            break
    src_no_pad = src[:i + 1]

    tokens_index, src_no_pad = create_index(src_no_pad, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length)
    if len(src_no_pad) < len(src):
        src = src_no_pad + (len(src) - len(src_no_pad)) * [PAD_ID]
    else:
        src = src_no_pad

    random.shuffle(tokens_index)
    num_to_predict = max(1, int(round(len(src_no_pad) * 0.15)))
    tgt_mlm = []
    for index_set in tokens_index:
        if len(tgt_mlm) >= num_to_predict:
            break
        if whole_word_masking:
            # index_set = [start position, number of sub-tokens in the word];
            # every sub-token of the word draws its own 80/10/10 decision.
            i = index_set[0]
            mask_len = index_set[1]
            if len(tgt_mlm) + mask_len > num_to_predict:
                continue
            for j in range(mask_len):
                token = src[i + j]
                tgt_mlm.append((i + j, token))
                prob = random.random()
                if prob < 0.8:
                    src[i + j] = vocab.get(MASK_TOKEN)
                elif prob < 0.9:
                    while True:
                        rdi = random.randint(1, len(vocab) - 1)
                        if rdi not in [vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), vocab.get(MASK_TOKEN), PAD_ID]:
                            break
                    src[i + j] = rdi
        elif span_masking:
            # index_set = [start position, span length]; a single 80/10/10
            # decision is drawn for the whole span.
            i = index_set[0]
            span_len = index_set[1]
            if len(tgt_mlm) + span_len > num_to_predict:
                continue
            for j in range(span_len):
                token = src[i + j]
                tgt_mlm.append((i + j, token))
            prob = random.random()
            if prob < 0.8:
                for j in range(span_len):
                    src[i + j] = vocab.get(MASK_TOKEN)
            elif prob < 0.9:
                for j in range(span_len):
                    while True:
                        rdi = random.randint(1, len(vocab) - 1)
                        if rdi not in [vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), vocab.get(MASK_TOKEN), PAD_ID]:
                            break
                    src[i + j] = rdi
        else:
            # Token-level masking: index_set = [position].
            i = index_set[0]
            token = src[i]
            tgt_mlm.append((i, token))
            prob = random.random()
            if prob < 0.8:
                src[i] = vocab.get(MASK_TOKEN)
            elif prob < 0.9:
                while True:
                    rdi = random.randint(1, len(vocab) - 1)
                    if rdi not in [vocab.get(CLS_TOKEN), vocab.get(SEP_TOKEN), vocab.get(MASK_TOKEN), PAD_ID]:
                        break
                src[i] = rdi

    tgt_mlm = sorted(tgt_mlm, key=lambda x: x[0])
    return src, tgt_mlm


def create_index(src, tokenizer, whole_word_masking, span_masking, span_geo_prob, span_max_length):
    """Build candidate masking units: whole words, spans, or single tokens.

    Returns (tokens_index, src). With whole word masking the sequence is
    segmented into words with jieba and re-tokenized, so src may change.
    """
    tokens_index = []
    span_end_position = -1
    vocab = tokenizer.vocab
    PAD_ID = vocab.get(PAD_TOKEN)
    if whole_word_masking:
        src_wwm = []
        src_length = len(src)
        has_cls, has_sep = False, False
        # Strip [CLS]/[SEP] so they are not fed to the word segmenter.
        if src[0] == vocab.get(CLS_TOKEN):
            src = src[1:]
            has_cls = True
        if src[-1] == vocab.get(SEP_TOKEN):
            src = src[:-1]
            has_sep = True
        sentence = "".join(tokenizer.convert_ids_to_tokens(src)).replace('[UNK]', '').replace('##', '')
        # Lazy import: jieba is only needed when whole word masking is enabled.
        import jieba
        wordlist = jieba.cut(sentence)
        if has_cls:
            src_wwm += [vocab.get(CLS_TOKEN)]
        for word in wordlist:
            position = len(src_wwm)
            src_wwm += tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))
            if len(src_wwm) < src_length:
                tokens_index.append([position, len(src_wwm) - position])
        if has_sep:
            src_wwm += [vocab.get(SEP_TOKEN)]
        # Keep the original sequence length after re-tokenization.
        if len(src_wwm) > src_length:
            src = src_wwm[:src_length]
        else:
            src = src_wwm
    else:
        for (i, token) in enumerate(src):
            if token == vocab.get(CLS_TOKEN) or token == vocab.get(SEP_TOKEN) or token == PAD_ID:
                continue
            if not span_masking:
                tokens_index.append([i])
            else:
                if i < span_end_position:
                    continue
                span_len = get_span_len(span_max_length, span_geo_prob)
                span_end_position = i + span_len
                if span_end_position > len(src):
                    span_len = len(src) - i
                tokens_index.append([i, span_len])

    return tokens_index, src


def get_span_len(max_span_len, p):
    """Sample a span length from a geometric distribution truncated at max_span_len."""
    geo_prob_cum = [0.0]
    geo_prob = 1.0
    for i in range(max_span_len + 1):
        if i == 0:
            continue
        if i == 1:
            geo_prob *= p
            geo_prob_cum.append(geo_prob_cum[-1] + geo_prob)
        else:
            geo_prob *= (1 - p)
            geo_prob_cum.append(geo_prob_cum[-1] + geo_prob)

    prob = geo_prob_cum[-1] * random.random()
    for i in range(len(geo_prob_cum) - 1):
        if prob >= geo_prob_cum[i] and prob < geo_prob_cum[i + 1]:
            current_span_len = i + 1
    return current_span_len
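

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library): exercises mask_seq() with a
# tiny stand-in tokenizer so the masking logic can be run in isolation, assuming
# the tencentpretrain package is importable.  _ToyTokenizer, toy_vocab and the
# example sequence below are hypothetical test fixtures, not TencentPretrain
# APIs; only the attributes and methods that mask_seq()/create_index() actually
# use are provided.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab
            self.inv_vocab = {v: k for k, v in vocab.items()}

        def convert_ids_to_tokens(self, ids):
            return [self.inv_vocab[i] for i in ids]

        def convert_tokens_to_ids(self, tokens):
            return [self.vocab[t] for t in tokens]

        def tokenize(self, text):
            return list(text)

    toy_vocab = {PAD_TOKEN: 0, CLS_TOKEN: 1, SEP_TOKEN: 2, MASK_TOKEN: 3,
                 "a": 4, "b": 5, "c": 6, "d": 7, "e": 8, "f": 9}
    tok = _ToyTokenizer(toy_vocab)

    # [CLS] a b c d e f [SEP] followed by padding to length 12.
    src = [1, 4, 5, 6, 7, 8, 9, 2] + [0] * 4
    masked_src, tgt_mlm = mask_seq(src, tok, whole_word_masking=False,
                                   span_masking=False, span_geo_prob=0.2,
                                   span_max_length=10)
    print("masked src:", masked_src)
    print("MLM targets (position, original id):", tgt_mlm)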