# yuyan-10b/megatron/data/dataset_utils.py
# (uploaded by Shawn001 in "Upload 131 files", revision 23bd7af, 36.5 kB)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import math
import os
import time
import collections
import numpy as np
import torch
import random
from megatron import (
get_tokenizer,
get_args,
mpu,
print_rank_0
)
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
# Identifiers for the dataset flavours accepted as `dataset_type` by
# _build_train_valid_test_datasets.
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_GLM = 'glm'
# Every accepted dataset_type value, in declaration order.
DSET_TYPES = [
    DSET_TYPE_BERT,
    DSET_TYPE_ICT,
    DSET_TYPE_T5,
    DSET_TYPE_GLM,
]
def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):
    """Split a blended ``data_prefix`` list into prefixes, weights and counts.

    Args:
        data_prefix: flat sequence ``weight-1, prefix-1, weight-2, prefix-2, ...``.
            Weights are converted with ``float``; prefixes are stripped.
        train_valid_test_num_samples: sample counts requested for the
            train/valid/test splits of the blended dataset.

    Returns:
        ``(prefixes, weights, datasets_train_valid_test_num_samples)`` where
        ``weights`` are normalized to sum to 1 and each entry of the last list
        holds that dataset's per-split counts, inflated by 0.5% and rounded up.
    """
    # Input must come in (weight, prefix) pairs.
    assert len(data_prefix) % 2 == 0
    raw_weights = [float(w) for w in data_prefix[0::2]]
    prefixes = [p.strip() for p in data_prefix[1::2]]

    total = sum(raw_weights, 0.0)
    assert total > 0.0
    weights = [w / total for w in raw_weights]

    # Add 0.5% (the 1.005 factor) so that, in case the blending dataset does
    # not distribute the samples uniformly, we still have samples left to
    # feed to the network.
    datasets_train_valid_test_num_samples = [
        [int(math.ceil(count * w * 1.005))
         for count in train_valid_test_num_samples]
        for w in weights
    ]
    return prefixes, weights, datasets_train_valid_test_num_samples
def compile_helper():
    """Compile the C++ dataset helpers at runtime by running ``make`` here.

    Must be invoked on a single process only. Exits the interpreter with
    status 1 if ``make`` fails.
    """
    import os
    import subprocess
    import sys

    helpers_dir = os.path.abspath(os.path.dirname(__file__))
    result = subprocess.run(['make', '-C', helpers_dir])
    if result.returncode == 0:
        return
    print("Making C++ dataset helpers module failed, exiting.")
    sys.exit(1)
def get_a_and_b_segments(sample, np_rng):
    """Divide a multi-sentence sample into segments A and B.

    Segment A takes the first ``a_end`` sentences and B the remainder. With
    three or more sentences ``a_end`` is drawn via ``np_rng.randint(1, n)``,
    otherwise it is 1. With probability 0.5 (``np_rng.random() < 0.5``) the
    two segments are swapped.

    Returns:
        ``(tokens_a, tokens_b, is_next_random)``.
    """
    n_sentences = len(sample)
    # Two sentences are the minimum needed to form an A/B pair.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # numpy's randint excludes its upper bound, so B always gets a sentence.
    a_end = np_rng.randint(1, n_sentences) if n_sentences >= 3 else 1

    tokens_a = [tok for sentence in sample[:a_end] for tok in sentence]
    tokens_b = [tok for sentence in sample[a_end:] for tok in sentence]

    is_next_random = bool(np_rng.random() < 0.5)
    if is_next_random:
        tokens_a, tokens_b = tokens_b, tokens_a
    return tokens_a, tokens_b, is_next_random
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Trim a pair of token lists in place until len_a + len_b fits.

    Each step removes one token from whichever segment is currently longer
    (segment B on ties), taking it from the front or the back with equal
    probability (``np_rng.random() < 0.5`` selects the front).

    Returns:
        True if any token was removed, False if the pair already fit.
    """
    assert len_a > 0
    truncated = False
    while len_a + len_b > max_num_tokens:
        truncated = True
        if len_a > len_b:
            target = tokens_a
            len_a -= 1
        else:
            target = tokens_b
            len_b -= 1
        drop_front = np_rng.random() < 0.5
        target.pop(0 if drop_front else -1)
    return truncated
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Build ``[CLS] A [SEP] B [SEP]`` together with its token-type ids.

    [CLS], segment A and the first [SEP] get type 0; segment B and its
    closing [SEP] get type 1. The closing [SEP] is only emitted when
    ``tokens_b`` is non-empty.

    Returns:
        ``(tokens, tokentypes)``, two lists of equal length.
    """
    tokens = [cls_id, *tokens_a, sep_id]
    tokentypes = [0] * len(tokens)

    segment_b = list(tokens_b)
    tokens += segment_b
    tokentypes += [1] * len(segment_b)
    if tokens_b:
        tokens.append(sep_id)
        tokentypes.append(1)
    return tokens, tokentypes
def create_tokens(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B into ``[CLS] A [SEP] B [SEP]``.

    Unlike create_tokens_and_tokentypes, no token-type ids are produced.
    The closing [SEP] is only added when ``tokens_b`` is non-empty.
    """
    merged = [cls_id]
    merged.extend(tokens_a)
    merged.append(sep_id)
    merged.extend(tokens_b)
    if tokens_b:
        merged.append(sep_id)
    return merged
# Record for one masked prediction: ``index`` is the masked position (or, for
# span records, the list of positions) and ``label`` the original token(s).
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", "index label")
def is_start_piece(piece):
    """Return True if ``piece`` begins a word (BERT WordPiece convention).

    When a word is split into WordPieces, only the continuation pieces carry
    a ``##`` prefix, so any piece without it starts a new word.
    """
    if piece.startswith("##"):
        return False
    return True
def _span_within_budget(cand_index_set, n, num_selected, num_to_predict):
    """Flatten the n-gram ``cand_index_set[n - 1]``, shrinking to shorter
    n-grams while ``num_selected`` plus the span would exceed
    ``num_to_predict``.

    The returned span may still be too long when even the unigram does not
    fit; callers must check that themselves.
    """
    index_set = sum(cand_index_set[n - 1], [])
    n -= 1
    while num_selected + len(index_set) > num_to_predict:
        if n == 0:
            break
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
    return index_set


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert"):
    """Creates the predictions for the masked LM objective.

    Note: Tokens here are vocab ids and not text tokens.

    Args:
        tokens: sequence of vocab ids, possibly containing cls_id/sep_id.
        vocab_id_list: ids used as random replacements in "bert" style.
        vocab_id_to_token_dict: id -> wordpiece string; used to spot ``##``
            continuation pieces for whole-word masking.
        masked_lm_prob: fraction of tokens to mask; 0 disables masking.
        cls_id, sep_id, mask_id: special token ids.
        max_predictions_per_seq: upper bound on the number of predictions.
        np_rng: numpy RandomState-like generator. All randomness, including
            the permutation pass, is drawn from it.
        max_ngrams: longest n-gram (in words) masked as one span.
        do_whole_word_mask: group ``##`` pieces with their preceding word.
        favor_longer_ngram: reverse the default short-n-gram preference.
        do_permutation: additionally permute a disjoint set of spans.
        geometric_dist: draw the masking span length from Geometric(0.2)
            (SpanBERT) instead of the 1/n distribution.
        masking_style: "bert" (80% [MASK], 10% keep, 10% random) or "t5"
            (always [MASK]).

    Returns:
        ``(output_tokens, masked_lm_positions, masked_lm_labels,
        token_boundary, masked_spans)``. The same 5-tuple is returned when
        ``masked_lm_prob == 0`` (with empty positions, labels and spans).

    Raises:
        ValueError: for an unknown ``masking_style``.
    """
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that if we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (do_whole_word_mask and len(cand_indexes) >= 1 and
                not is_start_piece(vocab_id_to_token_dict[token])):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)
    masked_lm_positions = []
    masked_lm_labels = []
    masked_spans = []

    if masked_lm_prob == 0:
        # Same 5-tuple shape as the normal return so callers can always
        # unpack the same number of values.
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary, masked_spans)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    # Note(mingdachen):
    # By default, we set the probabilities to favor shorter ngram sequences.
    # Computed unconditionally: the permutation pass samples with pvals even
    # when geometric_dist is enabled (previously a NameError).
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)
    if favor_longer_ngram:
        pvals = pvals[::-1]

    def _sample_ngram_len(cand_index_set):
        # Draw n from pvals renormalized over the n-grams available here.
        return np_rng.choice(ngrams[:len(cand_index_set)],
                             p=pvals[:len(cand_index_set)] /
                             pvals[:len(cand_index_set)].sum(keepdims=True))

    # One entry per candidate word: for n = 1..max_ngrams, the n-gram (a list
    # of word index-lists) starting at that word.
    ngram_indexes = [[cand_indexes[idx:idx + n] for n in ngrams]
                     for idx in range(len(cand_indexes))]
    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Overlap with already-covered positions is checked below, after the
        # span length has been sampled.
        if not geometric_dist:
            n = _sample_ngram_len(cand_index_set)
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        index_set = _span_within_budget(cand_index_set, n, len(masked_lms),
                                        num_to_predict)
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        if any(index in covered_indexes for index in index_set):
            continue

        for index in index_set:
            covered_indexes.add(index)
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                # 10% of the time, keep original
                elif np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")
            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    # Reshuffled unconditionally (even without do_permutation), so np_rng
    # advances the same way regardless of that flag.
    np_rng.shuffle(ngram_indexes)

    if do_permutation:
        select_indexes = set()
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Drawn from np_rng (not the global numpy RNG) so the result is
            # reproducible from the caller's seed.
            n = _sample_ngram_len(cand_index_set)
            index_set = _span_within_budget(cand_index_set, n,
                                            len(select_indexes),
                                            num_to_predict)
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            # Skip spans touching positions already masked or selected.
            if any(index in covered_indexes or index in select_indexes
                   for index in index_set):
                continue
            select_indexes.update(index_set)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Right-pad a sample to ``max_seq_length`` and convert it to numpy.

    Returns:
        ``(tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np)``,
        all int64 arrays of length ``max_seq_length``:
        ``tokens_np``/``tokentypes_np`` are padded with ``pad_id``;
        ``labels_np`` is -1 except at masked positions, where it holds the
        masked label; ``loss_mask_np`` is 1 exactly at those positions;
        ``padding_mask_np`` is 1 over real tokens and 0 over padding.
    """
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    # Sanity checks on the inputs.
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Token types are padded with pad_id too, not with 0.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    padding_mask_np = np.zeros(max_seq_length, dtype=np.int64)
    padding_mask_np[:num_tokens] = 1

    labels_np = np.full(max_seq_length, -1, dtype=np.int64)
    loss_mask_np = np.zeros(max_seq_length, dtype=np.int64)
    for position, label in zip(masked_positions, masked_labels):
        assert position < num_tokens
        labels_np[position] = label
        loss_mask_np[position] = 1

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert'):
    """Build train/valid/test datasets, blending several corpora if given.

    ``data_prefix`` is either ``[prefix]`` for a single corpus or a flat
    ``[weight-1, prefix-1, weight-2, prefix-2, ...]`` list. In the latter
    case each corpus is built on its own and the per-split results are
    combined with ``BlendableDataset``.

    Returns:
        ``(train, valid, test)``; an entry is None when no corpus produced a
        dataset for that split.
    """
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                dataset_type=dataset_type)

    # Blending dataset: parse the weights and per-corpus sample counts.
    prefixes, weights, datasets_train_valid_test_num_samples = \
        get_datasets_weights_and_num_samples(data_prefix,
                                             train_valid_test_num_samples)

    # Collect (dataset, weight) pairs per split so every weight stays attached
    # to its own corpus. Previously a corpus with no dataset for a split was
    # dropped from the dataset list but not from `weights`, misaligning them.
    per_split = ([], [], [])
    for prefix, weight, num_samples in zip(
            prefixes, weights, datasets_train_valid_test_num_samples):
        split_datasets = _build_train_valid_test_datasets(
            prefix, data_impl, splits_string,
            num_samples,
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, max_seq_length_dec,
            dataset_type=dataset_type)
        for pairs, dataset in zip(per_split, split_datasets):
            if dataset:
                pairs.append((dataset, weight))

    def _blend(pairs):
        # NOTE(review): when some corpora are missing for a split the
        # remaining weights no longer sum to 1; this assumes BlendableDataset
        # renormalizes its weights -- confirm.
        if not pairs:
            return None
        datasets, split_weights = zip(*pairs)
        return BlendableDataset(list(datasets), list(split_weights))

    return tuple(_blend(pairs) for pairs in per_split)
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     dataset_type='standard_bert'):
    """Build the train/valid/test datasets for a single data prefix.

    Documents of the indexed dataset are split contiguously according to
    ``splits_string``. Each non-empty split gets a dataset of type
    ``dataset_type`` built while the shared indexed dataset temporarily
    exposes only that split's doc-idx slice.

    Returns:
        ``(train, valid, test)``; an entry is None when its split is empty.

    Raises:
        ValueError: if ``dataset_type`` is not one of ``DSET_TYPES``.
    """
    if dataset_type not in DSET_TYPES:
        # Format the value into the message (passing it as a second argument
        # produced a tuple-valued message).
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    if dataset_type == DSET_TYPE_ICT:
        args = get_args()
        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                             data_impl,
                                             skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0(' sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        from megatron.data.bert_dataset import BertDataset
        from megatron.data.ict_dataset import ICTDataset
        from megatron.data.t5_dataset import T5Dataset
        from megatron.data.glm_dataset import GlmDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            try:
                # Build the dataset accordingly.
                kwargs = dict(
                    name=name,
                    data_prefix=data_prefix,
                    num_epochs=None,
                    max_num_samples=train_valid_test_num_samples[index],
                    max_seq_length=max_seq_length,
                    seed=seed,
                )

                if dataset_type == DSET_TYPE_ICT:
                    args = get_args()
                    dataset = ICTDataset(
                        block_dataset=indexed_dataset,
                        title_dataset=title_dataset,
                        query_in_block_prob=args.query_in_block_prob,
                        use_one_sent_docs=args.use_one_sent_docs,
                        binary_head=binary_head,
                        **kwargs
                    )
                elif dataset_type == DSET_TYPE_T5:
                    dataset = T5Dataset(
                        indexed_dataset=indexed_dataset,
                        masked_lm_prob=masked_lm_prob,
                        max_seq_length_dec=max_seq_length_dec,
                        short_seq_prob=short_seq_prob,
                        **kwargs
                    )
                elif dataset_type == DSET_TYPE_BERT:
                    dataset = BertDataset(
                        indexed_dataset=indexed_dataset,
                        masked_lm_prob=masked_lm_prob,
                        short_seq_prob=short_seq_prob,
                        binary_head=binary_head,
                        **kwargs
                    )
                elif dataset_type == DSET_TYPE_GLM:
                    dataset = GlmDataset(
                        indexed_dataset=indexed_dataset,
                        masked_lm_prob=masked_lm_prob,
                        short_seq_prob=short_seq_prob,
                        binary_head=binary_head,
                        **kwargs
                    )
                else:
                    raise NotImplementedError("Dataset type not fully implemented.")
            finally:
                # Restore the original pointer even if construction raised, so
                # the shared indexed dataset is never left with a split's view.
                indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Open the indexed dataset at ``data_prefix`` and log basic statistics.

    Asserts that the final ``doc_idx`` entry equals the number of sentences
    (the length of ``sizes``).
    """
    print_rank_0(' > building dataset index ...')
    t0 = time.time()
    dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    assert dataset.sizes.shape[0] == dataset.doc_idx[-1]
    elapsed = time.time() - t0
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(elapsed))

    # doc_idx carries one extra trailing entry, hence the -1.
    num_documents = dataset.doc_idx.shape[0] - 1
    num_sentences = dataset.sizes.shape[0]
    print_rank_0(' > indexed dataset stats:')
    print_rank_0(' number of documents: {}'.format(num_documents))
    print_rank_0(' number of sentences: {}'.format(num_sentences))
    return dataset
def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma or '/' separated string list.

    ``splits_string`` holds up to three relative weights (e.g. "949,50,1" or
    "8/1/1"); missing entries count as 0 and extra entries are ignored. The
    weights are normalized and applied to ``size`` documents.

    Returns:
        Four boundaries ``[0, train_end, valid_end, size]``; split ``k``
        covers ``[splits_index[k], splits_index[k + 1])``. Boundaries are
        non-decreasing and lie within ``[0, size]``.
    """
    if ',' in splits_string:
        splits = [float(s) for s in splits_string.split(',')]
    elif '/' in splits_string:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    # Pad to exactly three weights (train, valid, test).
    splits = (splits + [0.] * 3)[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]

    splits_index = [0]
    for split in splits:
        splits_index.append(splits_index[-1] +
                            int(round(split * float(size))))
    # Rounding may overshoot or undershoot `size`; shift every boundary after
    # the first so the last one lands exactly on `size`.
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    # That shift can push the first boundary below 0 (e.g. "0,1,1" with
    # size 3 gives [0, -1, 1, 3]), which would create a negative slice start.
    # Clamp so the boundaries stay non-decreasing; valid results are unchanged.
    for index in range(1, len(splits_index)):
        splits_index[index] = max(splits_index[index], splits_index[index - 1])
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Load (building on rank 0 if needed) the sample -> sentence mapping.

    The mapping is cached in a ``.npy`` file whose name is ``data_prefix``
    plus a suffix encoding name, epochs, max samples, max sequence length,
    short-sequence probability and seed. If the file is missing, rank 0
    builds it with ``helpers.build_mapping`` and saves it. All ranks then
    synchronize via all-reduce and open the file memory-mapped read-only.

    NOTE(review): ``binary_head`` changes what build_mapping produces but is
    not part of the cache filename, so a stale cache may be reused when it
    changes. Also, the file is loaded with ``allow_pickle=True``; only point
    this at trusted cache files.

    Raises:
        ValueError: if neither ``num_epochs`` nor ``max_num_samples`` is given.
    """
    epochs_sentinel = np.iinfo(np.int32).max - 1
    samples_sentinel = np.iinfo(np.int64).max - 1
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = epochs_sentinel
    if not max_num_samples:
        max_num_samples = samples_sentinel

    # Filename of the index mapping.
    name_parts = [data_prefix, '_{}_indexmap'.format(name)]
    if num_epochs != epochs_sentinel:
        name_parts.append('_{}ep'.format(num_epochs))
    if max_num_samples != samples_sentinel:
        name_parts.append('_{}mns'.format(max_num_samples))
    name_parts.append('_{}msl'.format(max_seq_length))
    name_parts.append('_{:0.2f}ssp'.format(short_seq_prob))
    name_parts.append('_{}s'.format(seed))
    name_parts.append('.npy')
    indexmap_filename = ''.join(name_parts)

    # Only rank 0 builds the mapping, and only when it is not cached yet.
    is_rank_zero = torch.distributed.get_rank() == 0
    if is_rank_zero and not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32

        build_start = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        # First compile and then import.
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length,
            short_seq_prob,
            seed,
            is_rank_zero,
            2 if binary_head else 1)
        print_rank_0(' > done building samples index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - build_start))

    # Make sure all the ranks have built the mapping.
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    load_start = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - load_start))
    print_rank_0(' total number of samples: {}'.format(
        samples_mapping.shape[0]))
    return samples_mapping
class MaskEncoder(object):
    """Masking and noising helpers for Chinese text.

    Token ids are translated with the tokenizer returned by
    ``get_tokenizer()``. Word boundaries come from ``jieba_fast.lcut``
    segmentation of the detokenized text.
    """

    # Inclusive Unicode code-point ranges counted as Chinese characters.
    _CJK_RANGES = (
        (0x4E00, 0x9FFF),
        (0x3400, 0x4DBF),
        (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF),
        (0x2F800, 0x2FA1F),
    )

    def __init__(self):
        tokenizer = get_tokenizer()
        self.vocab_size = tokenizer.vocab_size
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
        # Imported lazily; within this module only this class uses jieba_fast.
        import jieba_fast
        self.zh_tokenizer = jieba_fast.lcut
        # Fraction of masked start tokens replaced by a random id instead of [MASK].
        self.random_ratio = 0

    @classmethod
    def _is_chinese_char(cls, c):
        """Return True if ``c`` is non-empty and every character is CJK.

        Korean Hangul and Japanese Hiragana/Katakana live in other Unicode
        blocks and are therefore not treated as Chinese. An empty string
        (which previously raised TypeError) is not Chinese.
        """
        if not c:
            return False
        return all(any(lo <= ord(ch) <= hi for lo, hi in cls._CJK_RANGES)
                   for ch in c)

    @staticmethod
    def _align_linear(atokens, btokens):
        """Map each token of ``atokens`` to the indexes of ``btokens`` that
        share its characters.

        Both lists must cover the same character string; only token lengths
        are used. Returns one list per ``atokens`` entry, in ``list(set(...))``
        order. Raises IndexError if ``atokens`` covers more characters.
        """
        char_to_b = []
        for b_index, tok in enumerate(btokens):
            char_to_b.extend([b_index] * len(tok))
        a2b = []
        offset = 0
        for tok in atokens:
            a2b.append(list(set(char_to_b[ci]
                                for ci in range(offset, offset + len(tok)))))
            offset += len(tok)
        return a2b

    def word_starts(self, source):
        """Locate word starts in ``source`` using jieba segmentation.

        ``source`` is assumed to be a 1-D tensor of token ids. Its first and
        last tokens (presumably [CLS]/[SEP]) are kept out of the text given
        to jieba.

        Returns:
            ``(is_word_start, word_starts)``. ``is_word_start`` is a float
            tensor shaped like ``source`` with 1 at every word start, with the
            first and last positions forced to 0. ``word_starts`` is a
            ``(k, 1)`` long tensor of the start positions of words whose
            tokens are all Chinese characters.
        """
        raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()]
        words = ([raw_tokens[0]]
                 + self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True)
                 + [raw_tokens[-1]])
        raw_to_word_align = self._align_linear(raw_tokens, words)

        is_word_start = torch.zeros(source.size())
        chinese_starts = []
        # True when the current word's start must not be popped: it was never
        # appended (non-Chinese start) or it has already been removed.
        skip_cur_word = True
        for i in range(1, len(raw_to_word_align)):
            if raw_to_word_align[i - 1] == raw_to_word_align[i]:
                # Continuation of the current word. A non-Chinese piece means
                # the word is not purely Chinese, so drop its start.
                if not skip_cur_word and not self._is_chinese_char(raw_tokens[i]):
                    chinese_starts.pop(-1)
                    skip_cur_word = True
                continue
            is_word_start[i] = 1
            if self._is_chinese_char(raw_tokens[i]):
                chinese_starts.append(i)
                skip_cur_word = False
            else:
                # Nothing was appended for this word, so a later non-Chinese
                # continuation must not pop the previous word's start.
                skip_cur_word = True
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start, torch.tensor(chinese_starts).long().view(-1, 1)

    def add_whole_word_mask(self, source, p, replace_length=1):
        """Apply BART-style whole-word masking to Chinese text.

        ``ceil(k * p)`` whole-word masks plus ``ceil(k * p * 0.1)``
        single-token masks are drawn, where ``k`` is the number of purely
        Chinese word starts. If that exceeds ``k``, every word start becomes
        a candidate instead. A chosen start token becomes [MASK], or with
        probability ``self.random_ratio`` a random id in ``[1, vocab_size)``.

        Args:
            source: 1-D tensor of token ids. It may be modified in place
                before the filtered copy is returned.
            p: masking ratio.
            replace_length: 0 deletes the chosen words entirely. -1 replaces
                every token of a masked word with [MASK]. Any other value
                (default 1) keeps the start as [MASK] and deletes the
                continuation tokens, so each masked word collapses to one
                token.

        Returns:
            The masked sequence (possibly shorter than ``source``).
        """
        is_word_start, word_starts = self.word_starts(source)
        num_to_mask_word = int(math.ceil(word_starts.size(0) * p))
        num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1))
        num_to_mask = num_to_mask_word + num_to_mask_char
        if num_to_mask > word_starts.size(0):
            # Not enough purely Chinese words: fall back to every word start.
            word_starts = is_word_start.nonzero(as_tuple=False)
        if num_to_mask == 0:
            return source

        assert is_word_start[-1] == 0
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        # There may be fewer candidates than requested.
        num_to_mask = min(num_to_mask, len(indices))
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        # Acts as a long length, so spans don't go over the end of doc.
        is_word_start[-1] = 255
        if replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_id
            source[indices[mask_random]] = torch.randint(
                1, self.vocab_size, size=(mask_random.sum(),)
            )

        # Only the first num_to_mask_word indices are whole-word masks; the
        # remaining single-token masks were fully handled above. Extend each
        # word mask over its continuation tokens.
        indices = indices[:num_to_mask_word]
        mask_random = mask_random[:num_to_mask_word]
        while indices.size(0) > 0:
            uncompleted = is_word_start[indices + 1] == 0
            indices = indices[uncompleted] + 1
            mask_random = mask_random[uncompleted]
            if replace_length != -1:
                # delete token
                to_keep[indices] = 0
            else:
                # keep index, but replace it with [MASK]
                source[indices] = self.mask_id
                source[indices[mask_random]] = torch.randint(
                    1, self.vocab_size, size=(mask_random.sum(),)
                )
            assert source_length - 1 not in indices
        return source[to_keep]

    def shif_chinese_word(self, tokens, tokens_bf_mask):
        """Refill runs of changed positions with their original tokens, shuffled.

        Positions where ``tokens`` differs from ``tokens_bf_mask`` (the
        sequence before masking) are grouped into maximal contiguous runs.
        Each run of length > 1 is filled with its original tokens in an order
        chosen by Python's ``random.shuffle``. The final run is now flushed
        too; it used to be skipped. ``tokens`` is modified in place and
        returned.

        NOTE(review): single-position runs are left unchanged. This follows
        the ``len(buff_list) != 1`` guard of the original, whose indentation
        had to be reconstructed -- confirm.
        """
        assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function'
        buff_list = []
        buff_list_index = []

        def _flush():
            if len(buff_list) > 1:
                random.shuffle(buff_list)
                tokens[buff_list_index[0]:buff_list_index[-1] + 1] = buff_list
            del buff_list[:]
            del buff_list_index[:]

        for i in range(len(tokens)):
            if tokens[i] == tokens_bf_mask[i]:
                if buff_list:
                    _flush()
            else:
                buff_list.append(tokens_bf_mask[i])
                buff_list_index.append(i)
        if buff_list:
            _flush()
        return tokens

    def mass_style_mask(self, tokens):
        """Return a copy of ``tokens`` with a MASS-style span masked.

        A ratio ``p`` is drawn with ``random.uniform(0.3, 0.5)``. Then
        ``int(len * p)`` tokens starting at ``int((1 - p) / 2 * len)`` are
        replaced with [MASK].
        """
        tokens = tokens[:]
        p = random.uniform(0.3, 0.5)
        num_to_mask = int(len(tokens) * p)
        start_index = int((1 - p) / 2 * len(tokens))
        tokens[start_index:start_index + num_to_mask] = [self.mask_id] * num_to_mask
        return tokens

    def delete_chinese_word(self, tokens, tokens_bf_mask):
        """Return the tokens of ``tokens`` that equal ``tokens_bf_mask`` at the
        same position, i.e. drop every changed (masked) position."""
        assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function'
        return [tok for tok, orig in zip(tokens, tokens_bf_mask) if tok == orig]