# yuyan-10b/megatron/data/glm_dataset.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Style dataset."""
import numpy as np
import torch
from megatron import (
get_args,
get_tokenizer,
mpu,
print_rank_0
)
from megatron.data.dataset_utils import (
get_samples_mapping,
get_a_and_b_segments,
truncate_segments,
create_tokens_and_tokentypes,
create_tokens,
create_masked_lm_predictions,
MaskEncoder
)
class DummyBertDataset(torch.utils.data.Dataset):
def __init__(self, name, num_samples, max_seq_length):
self.name = name
self.num_samples = num_samples
self.max_seq_length = max_seq_length
self.np_rng = np.random.RandomState(seed=0)
# self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512))
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
tokens = self.np_rng.randint(1000, 2000, self.max_seq_length)
        masked_position = np.arange(int(tokens.shape[0] * 0.15))  # mask the first 15% of positions (deterministic dummy data)
tokens = tokens.astype(np.int64)
labels = tokens[masked_position]
label_np = np.full_like(tokens, -1)
label_np[masked_position] = labels
tokens[masked_position] = self.mask_id
loss_mask_np = np.zeros_like(tokens)
loss_mask_np[masked_position] = 1
train_sample = {
'text': tokens,
'types': np.zeros_like(tokens),
'labels': label_np,
'is_random': 0,
'loss_mask': loss_mask_np,
'padding_mask': np.ones_like(tokens),
'truncated': 0
}
return train_sample
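
# Illustrative sketch (not part of the original file): shows how DummyBertDataset
# might be used, assuming megatron's global args and tokenizer have already been
# initialized (e.g. via initialize_megatron during pretraining). The dataset name
# and sizes below are hypothetical.
def _example_dummy_bert_dataset():
    ds = DummyBertDataset(name='dummy', num_samples=4, max_seq_length=512)
    sample = ds[0]
    # every field is a fixed-length vector over the dummy sequence
    assert sample['text'].shape == (512,)
    # the first 15% of positions are masked, so the loss mask sums to int(512 * 0.15)
    assert sample['loss_mask'].sum() == int(512 * 0.15)
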
class GlmDataset(torch.utils.data.Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed, binary_head):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens
short_seq_prob,
self.seed,
self.name,
self.binary_head)
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1.
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng,
self.binary_head)
def sent_level_task(binary_head, sample, target_seq_length, max_seq_length, np_rng):
    """Split `sample` into A and B segments (for the sentence-order task when
    `binary_head` is set) and truncate them to `target_seq_length`."""
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
    # Truncate to `target_seq_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), max_num_tokens, np_rng)
return is_next_random, truncated, max_num_tokens, tokens_a, tokens_b
def generate_decoder_input_and_output(tokens, pad_id, sep_id):
"""
decoder input [SEP] [CSL] A B C D
decoder output [CLS] A B C D E
"""
decoder_output = tokens[:]
decoder_input = [0] * len(decoder_output)
decoder_input[0] = sep_id # match the preprocessing in fairseq
decoder_input[1:] = decoder_output[:-1]
"""
decoder input [CSL] A B C D [SEP]
decoder output A B C D [SEP] [PAD]
"""
# decoder_input = tokens[:]
# decoder_output = [0] * len(decoder_input)
# decoder_output[:-1] = decoder_input[1:]
# decoder_output[-1] = pad_id
return decoder_input, decoder_output
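
# Illustrative sketch (not part of the original file, hypothetical token ids):
# demonstrates the teacher-forcing shift performed above, where the decoder
# input is the decoder output shifted right by one position with [SEP] prepended.
def _example_decoder_shift():
    cls_id, sep_id, pad_id = 101, 102, 0  # assumed ids, for illustration only
    tokens = [cls_id, 7, 8, 9, sep_id]    # [CLS] A B C [SEP]
    dec_in, dec_out = generate_decoder_input_and_output(tokens, pad_id, sep_id)
    assert dec_in == [sep_id, cls_id, 7, 8, 9]   # [SEP] [CLS] A B C
    assert dec_out == [cls_id, 7, 8, 9, sep_id]  # [CLS] A B C [SEP]
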
def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng, binary_head):
"""
sent-level task
"""
is_next_random, truncated, max_num_tokens, tokens_a, tokens_b = sent_level_task(
binary_head, sample, target_seq_length, max_seq_length, np_rng)
tokens_bf_mask = create_tokens(tokens_a, tokens_b, cls_id, sep_id)
if is_next_random:
raw_tokens = create_tokens(tokens_b, tokens_a, cls_id, sep_id)
else:
raw_tokens = tokens_bf_mask[:]
"""
decoder-input and output
"""
decoder_input, decoder_output = generate_decoder_input_and_output(raw_tokens, pad_id, sep_id)
    # important part: decide which losses this sample contributes to
encoder_loss_flag = 0
decoder_loss_flag = 0
sent_loss_flag = 1
encoder_rng = torch.rand(1).item()
me = MaskEncoder()
    if encoder_rng < 1.1:
        # NOTE: encoder_rng is uniform in [0, 1), so this branch is always
        # taken and the decoder-only branch below is effectively disabled.
        # Train with both encoder and decoder.
        # Masking.
        if 0:  # disabled: t5-style masking path
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, _, _, _, _) = create_masked_lm_predictions(
tokens_bf_mask, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, masking_style="t5")
        if 1:  # active path: whole-word masking on 15% of words
tokens = torch.LongTensor(tokens_bf_mask)
tokens = me.add_whole_word_mask(tokens, 0.15, -1)
tokens = tokens.tolist()
shift_rng = torch.rand(1).item()
        if shift_rng < 0.0:  # disabled: shift_rng is never negative
tokens = me.shif_chinese_word(tokens, tokens_bf_mask)
encoder_loss_flag = 1
decoder_loss_flag = 1
else:
# train only with decoder
tokens = torch.LongTensor(tokens_bf_mask)
decoder_rng = torch.rand(1).item()
if decoder_rng < 0.4:
            # whole-word mask on 30% of words
tokens = me.add_whole_word_mask(tokens, 0.3, -1)
tokens = tokens.tolist()
        elif 0.4 <= decoder_rng < 0.6:
            # MASS-style span mask
            tokens = me.mass_style_mask(tokens_bf_mask)
        else:
            # decoder_rng >= 0.6: whole-word mask, then delete the masked words
            tokens = me.add_whole_word_mask(tokens, 0.3, -1)
            tokens = tokens.tolist()
            tokens = me.delete_chinese_word(tokens, tokens_bf_mask)
tmp_tt = get_tokenizer()
# print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask)))
# print("encoder input ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens)))
# print("------\n\n")
decoder_loss_flag = 1
# tmp_tt = get_tokenizer()
# print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask)))
# print("encoder input ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens)))
# print("decoder input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_input)))
# print("decoder output", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_output)))
tokentypes = []
encoder_labels = []
encoder_labels_mask = []
padding_mask = []
    append_type_id = 0
    if len(tokens) == len(tokens_bf_mask):
        # encoder and decoder can train together
        for index in range(len(tokens)):
            padding_mask.append(1)
            # token type ids: 0 before the first [SEP], 1 from the first [SEP] onward
            if tokens[index] == sep_id:
                append_type_id = 1
            tokentypes.append(append_type_id)
            if tokens[index] == tokens_bf_mask[index]:
                encoder_labels.append(-1)
                encoder_labels_mask.append(0)
            else:
                encoder_labels.append(tokens_bf_mask[index])
                encoder_labels_mask.append(1)
    else:
        # only train the decoder
        for index in range(len(tokens)):
            padding_mask.append(1)
            if tokens[index] == sep_id:
                append_type_id = 1
            tokentypes.append(append_type_id)
            encoder_labels.append(-1)
            encoder_labels_mask.append(0)
tokens_np = pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id)
tokentypes_np = pad_and_convert_to_numpy_light(tokentypes, max_seq_length, pad_id)
    padding_mask_np = pad_and_convert_to_numpy_light(padding_mask, max_seq_length, 0)  # mask padding is 0
    encoder_labels_np = pad_and_convert_to_numpy_light(encoder_labels, max_seq_length, -1)
    encoder_labels_mask_np = pad_and_convert_to_numpy_light(encoder_labels_mask, max_seq_length, 0)  # mask padding is 0
decoder_input_np = pad_and_convert_to_numpy_light(decoder_input, max_seq_length, pad_id)
decoder_output_np = pad_and_convert_to_numpy_light(decoder_output, max_seq_length, pad_id)
# print(tokens_np)
# print(encoder_labels_np)
# print(padding_mask_np)
# print(encoder_labels_mask_np)
    # assemble the training sample
train_sample = {
'text': tokens_np, # encoder_input
'types': tokentypes_np, # token_type
'is_random': int(is_next_random), #sop_labels
'truncated': int(truncated), # if truncated
'labels': encoder_labels_np, #encoder_labels
'loss_mask': encoder_labels_mask_np, # mlm_mask
'padding_mask': padding_mask_np, # padding_mask
'decoder_input': decoder_input_np, # decoder_input
'decoder_output': decoder_output_np, #decoder_output
'encoder_loss_flag': int(encoder_loss_flag),
'decoder_loss_flag': int(decoder_loss_flag),
'sent_loss_flag': int(sent_loss_flag),
}
# print(tokens_np.shape)
# print(tokens_np)
# print(tokentypes_np.shape)
# print(tokentypes_np)
# print(encoder_labels_np.shape)
# print(encoder_labels_np)
# print(encoder_labels_mask_np.shape)
# print(encoder_labels_mask_np)
# print(padding_mask_np.shape)
# print(padding_mask_np)
# print(decoder_input_np.shape)
# print(decoder_input_np)
# print(decoder_output_np.shape)
# print(decoder_output_np)
# print("=====\n\n\n")
# import sys;sys.exit(0)
return train_sample
def pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id):
    """Right-pad `tokens` with `pad_id` up to `max_seq_length` and return an int64 numpy array."""
    padding_length = max_seq_length - len(tokens)
assert padding_length >= 0
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
return tokens_np
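
# Illustrative sketch (not part of the original file, hypothetical values):
# pad_and_convert_to_numpy_light simply right-pads with the given filler id.
def _example_pad_light():
    padded = pad_and_convert_to_numpy_light([5, 6, 7], max_seq_length=6, pad_id=0)
    assert padded.tolist() == [5, 6, 7, 0, 0, 0]
    assert padded.dtype == np.int64
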
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
    # Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
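
# Illustrative sketch (not part of the original file, hypothetical token ids):
# shows how pad_and_convert_to_numpy pads tokens/token types, builds the padding
# mask, and scatters the masked labels into a -1-filled label vector.
def _example_pad_and_convert():
    tokens_np, types_np, labels_np, pad_mask_np, loss_mask_np = pad_and_convert_to_numpy(
        tokens=[101, 9, 103, 102], tokentypes=[0, 0, 0, 0],
        masked_positions=[2], masked_labels=[42],
        pad_id=0, max_seq_length=6)
    assert tokens_np.tolist() == [101, 9, 103, 102, 0, 0]
    assert types_np.tolist() == [0, 0, 0, 0, 0, 0]
    assert pad_mask_np.tolist() == [1, 1, 1, 1, 0, 0]
    assert labels_np.tolist() == [-1, -1, 42, -1, -1, -1]
    assert loss_mask_np.tolist() == [0, 0, 1, 0, 0, 0]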