"""BERT Style dataset.""" |
|
import numpy as np
import torch

from megatron import (
    get_args,
    get_tokenizer,
    mpu,
    print_rank_0
)
from megatron.data.dataset_utils import (
    get_samples_mapping,
    get_a_and_b_segments,
    truncate_segments,
    create_tokens_and_tokentypes,
    create_tokens,
    create_masked_lm_predictions,
    MaskEncoder
)


class DummyBertDataset(torch.utils.data.Dataset):
|
    def __init__(self, name, num_samples, max_seq_length):
        self.name = name
        self.num_samples = num_samples
        self.max_seq_length = max_seq_length
        self.np_rng = np.random.RandomState(seed=0)

        # Vocabulary-related ids.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random token ids in [1000, 2000); the first 15% of positions are
        # replaced by the mask id and used as MLM targets.
        tokens = self.np_rng.randint(1000, 2000, self.max_seq_length)
        masked_position = np.arange(int(tokens.shape[0] * 0.15))
        tokens = tokens.astype(np.int64)
        labels = tokens[masked_position]
        label_np = np.full_like(tokens, -1)
        label_np[masked_position] = labels
        tokens[masked_position] = self.mask_id
        loss_mask_np = np.zeros_like(tokens)
        loss_mask_np[masked_position] = 1
        train_sample = {
            'text': tokens,
            'types': np.zeros_like(tokens),
            'labels': label_np,
            'is_random': 0,
            'loss_mask': loss_mask_np,
            'padding_mask': np.ones_like(tokens),
            'truncated': 0
        }
        return train_sample
|
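# Illustrative usage sketch (not part of the training pipeline): once the
# Megatron tokenizer/args globals are initialized, a DummyBertDataset can be
# wrapped in a standard DataLoader; the batch size below is arbitrary.
#
#     from torch.utils.data import DataLoader
#     dataset = DummyBertDataset('dummy', num_samples=128, max_seq_length=512)
#     loader = DataLoader(dataset, batch_size=8)
#     batch = next(iter(loader))  # dict of int64 tensors, e.g. batch['text'] is [8, 512]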
|
|
class GlmDataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Underlying indexed dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 3,  # leave room for special tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   self.binary_head)

        # Vocabulary-related ids.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Numpy requires the seed to be in [0, 2**32), hence the modulo.
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng,
                                     self.binary_head)
|
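# Illustrative wiring sketch (hypothetical values; the real builder lives in the
# training code): GlmDataset expects an already-loaded indexed dataset plus the
# sampling parameters below.
#
#     dataset = GlmDataset(name='train', indexed_dataset=indexed_dataset,
#                          data_prefix='/path/to/data', num_epochs=None,
#                          max_num_samples=1000000, masked_lm_prob=0.15,
#                          max_seq_length=512, short_seq_prob=0.1,
#                          seed=1234, binary_head=True)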
|
|
def sent_level_task(binary_head, sample, target_seq_length, max_seq_length, np_rng):
    """Split the sample into the A/B segments used by the sentence-level task."""
    if binary_head:
        # The sentence-pair task needs at least two sentences in the sample.
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    if binary_head:
        # Divide the sample into two segments; B may be swapped with A so that
        # it becomes a "random" (non-consecutive) continuation.
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
    else:
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate both segments to the target sequence length.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), max_num_tokens, np_rng)
    return is_next_random, truncated, max_num_tokens, tokens_a, tokens_b
|
|
|
def generate_decoder_input_and_output(tokens, pad_id, sep_id):
    """Build the shifted decoder input and the decoder output (target).

    decoder input   [SEP] [CLS] A B C D
    decoder output  [CLS] A B C D E
    """
    decoder_output = tokens[:]
    # Right-shift the target by one position and use [SEP] as the start token.
    decoder_input = [0] * len(decoder_output)
    decoder_input[0] = sep_id
    decoder_input[1:] = decoder_output[:-1]

    # Alternative layout (not used here):
    #     decoder input   [CLS] A B C D [SEP]
    #     decoder output  A B C D [SEP] [PAD]

    return decoder_input, decoder_output
|
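# Worked example for generate_decoder_input_and_output (hypothetical ids,
# cls=101, sep=102):
#
#     tokens         = [101, 7, 8, 9, 102]
#     decoder_output = [101, 7, 8, 9, 102]   # unshifted target
#     decoder_input  = [102, 101, 7, 8, 9]   # right-shifted, starts with [SEP]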
|
def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, binary_head):

    # Sentence-level task: get the (possibly swapped) A/B segments, truncated
    # to the target length.
    is_next_random, truncated, max_num_tokens, tokens_a, tokens_b = sent_level_task(
        binary_head, sample, target_seq_length, max_seq_length, np_rng)
    tokens_bf_mask = create_tokens(tokens_a, tokens_b, cls_id, sep_id)
    if is_next_random:
        raw_tokens = create_tokens(tokens_b, tokens_a, cls_id, sep_id)
    else:
        raw_tokens = tokens_bf_mask[:]

    # Decoder input and output: the unmasked sequence, shifted by one position.
    decoder_input, decoder_output = generate_decoder_input_and_output(raw_tokens, pad_id, sep_id)

    # Decide which losses this sample contributes to.
    encoder_loss_flag = 0
    decoder_loss_flag = 0
    sent_loss_flag = 1
    encoder_rng = torch.rand(1).item()
    me = MaskEncoder()
    if encoder_rng < 1.1:
        # Encoder branch (the 1.1 threshold makes this branch always taken).
        if 0:
            # Disabled alternative: t5-style masking.
            max_predictions_per_seq = masked_lm_prob * max_num_tokens
            (tokens, _, _, _, _) = create_masked_lm_predictions(
                tokens_bf_mask, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
                cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, masking_style="t5")
        if 1:
            # Whole-word masking at a 15% ratio.
            tokens = torch.LongTensor(tokens_bf_mask)
            tokens = me.add_whole_word_mask(tokens, 0.15, -1)
            tokens = tokens.tolist()
            shift_rng = torch.rand(1).item()
            if shift_rng < 0.0:
                # Disabled: word shuffling (the probability is zero).
                tokens = me.shif_chinese_word(tokens, tokens_bf_mask)
        encoder_loss_flag = 1
        decoder_loss_flag = 1
    else:
        # Decoder-only branch: heavier corruption, no encoder MLM loss.
        tokens = torch.LongTensor(tokens_bf_mask)
        decoder_rng = torch.rand(1).item()
        if decoder_rng < 0.4:
            # Whole-word masking at a 30% ratio.
            tokens = me.add_whole_word_mask(tokens, 0.3, -1)
            tokens = tokens.tolist()
        elif decoder_rng < 0.6:
            # MASS-style span masking.
            tokens = me.mass_style_mask(tokens_bf_mask)
        else:
            # Whole-word masking followed by word deletion.
            tokens = me.add_whole_word_mask(tokens, 0.3, -1)
            tokens = tokens.tolist()
            tokens = me.delete_chinese_word(tokens, tokens_bf_mask)
        decoder_loss_flag = 1

    # Build token types, encoder labels, loss mask and padding mask.
    tokentypes = []
    encoder_labels = []
    encoder_labels_mask = []
    padding_mask = []
    append_type_id = 0

    if len(tokens) == len(tokens_bf_mask):
        # Masking preserved the length: recover the MLM labels by comparing the
        # corrupted tokens against the originals.
        for index in range(len(tokens)):
            padding_mask.append(1)
            if tokens[index] == sep_id:
                append_type_id = 1
            tokentypes.append(append_type_id)
            if tokens[index] == tokens_bf_mask[index]:
                encoder_labels.append(-1)
                encoder_labels_mask.append(0)
            else:
                encoder_labels.append(tokens_bf_mask[index])
                encoder_labels_mask.append(1)
    else:
        # Length changed (e.g. tokens were deleted): position-wise MLM labels
        # cannot be recovered, so the encoder loss is masked out entirely.
        for index in range(len(tokens)):
            padding_mask.append(1)
            if tokens[index] == sep_id:
                append_type_id = 1
            tokentypes.append(append_type_id)
            encoder_labels.append(-1)
            encoder_labels_mask.append(0)

    # Pad everything to max_seq_length and convert to numpy.
    tokens_np = pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id)
    tokentypes_np = pad_and_convert_to_numpy_light(tokentypes, max_seq_length, pad_id)
    padding_mask_np = pad_and_convert_to_numpy_light(padding_mask, max_seq_length, pad_id)
    encoder_labels_np = pad_and_convert_to_numpy_light(encoder_labels, max_seq_length, -1)
    encoder_labels_mask_np = pad_and_convert_to_numpy_light(encoder_labels_mask, max_seq_length, pad_id)
    decoder_input_np = pad_and_convert_to_numpy_light(decoder_input, max_seq_length, pad_id)
    decoder_output_np = pad_and_convert_to_numpy_light(decoder_output, max_seq_length, pad_id)

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'is_random': int(is_next_random),
        'truncated': int(truncated),
        'labels': encoder_labels_np,
        'loss_mask': encoder_labels_mask_np,
        'padding_mask': padding_mask_np,
        'decoder_input': decoder_input_np,
        'decoder_output': decoder_output_np,
        'encoder_loss_flag': int(encoder_loss_flag),
        'decoder_loss_flag': int(decoder_loss_flag),
        'sent_loss_flag': int(sent_loss_flag),
    }
    return train_sample
|
|
|
def pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id):
    """Pad a single list to max_seq_length and convert it to numpy int64."""
    padding_length = max_seq_length - len(tokens)
    assert padding_length >= 0
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    return tokens_np
|
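# Quick illustration of the padding helper (values are arbitrary):
#
#     pad_and_convert_to_numpy_light([5, 6, 7], 6, 0)
#     -> array([5, 6, 7, 0, 0, 0])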
|
|
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
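# Worked example (hypothetical ids: cls=101, sep=102, mask=103, pad=0) with
# max_seq_length=6 and one masked position whose original token was 7:
#
#     pad_and_convert_to_numpy([101, 103, 102], [0, 0, 0], [1], [7], 0, 6)
#     -> tokens       [101 103 102  0  0  0]
#        tokentypes   [  0   0   0  0  0  0]
#        labels       [ -1   7  -1 -1 -1 -1]
#        padding_mask [  1   1   1  0  0  0]
#        loss_mask    [  0   1   0  0  0  0]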
|