import math
import torch
import numpy as np
from typing import List, Tuple
from itertools import zip_longest, cycle
from functools import partial

from megatron import mpu, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data.pairwise_dataset import PairwiseDataset
from megatron.data.samplers import DistributedBatchSampler


def make_data_loader(dataset, neox_args):
    """Build a dataloader given an input dataset."""
    if dataset is None:
        return None

    # Data-parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = neox_args.batch_size * world_size
    num_workers = neox_args.num_workers

    # Use a sequential sampler wrapped in a distributed batch sampler so that
    # each data-parallel rank receives its own shard of every global batch.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(
        sampler=sampler,
        batch_size=global_batch_size,
        drop_last=True,
        rank=rank,
        world_size=world_size,
    )

    return torch.utils.data.DataLoader(
        dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True
    )
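# Illustrative usage (a sketch, not part of the module; assumes an existing
# dataset `train_ds` and a populated `neox_args`):
#
#     train_loader = make_data_loader(train_ds, neox_args=neox_args)
#     batch = next(iter(train_loader))  # one micro-batch for this data-parallel rank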


def build_the_dataset(
    data_prefix,
    pos_data_prefix,
    neg_data_prefix,
    name,
    data_impl,
    pack_impl,
    dataset_impl,
    allow_chopped,
    num_samples,
    num_epochs,
    seq_length,
    seed,
    skip_warmup,
    build_index_mappings=True,
    label_prefix=None,
    pos_label_prefix=None,
    neg_label_prefix=None,
    precompute_model_name=None,
    reward_prefix=None,
):
    """Build a single dataset for one split (GPT2 or pairwise)."""
    if dataset_impl == "gpt2":
        indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
        if label_prefix is None:
            label_dataset = None
        else:
            label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
        if precompute_model_name is not None:
            # Precomputed reference-model data is stored under
            # "<data_prefix>_<precompute_model_name>".
            precompute_indexed_dataset = make_indexed_dataset(
                data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
            )
        else:
            precompute_indexed_dataset = None
        if reward_prefix is not None:
            reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup)
        else:
            reward_dataset = None
    elif dataset_impl == "pairwise":
        pos_indexed_dataset = make_indexed_dataset(
            pos_data_prefix, data_impl, skip_warmup
        )
        neg_indexed_dataset = make_indexed_dataset(
            neg_data_prefix, data_impl, skip_warmup
        )
        if pos_label_prefix is None:
            pos_label_dataset = None
            # Pairwise data expects either both or neither label prefix.
            assert neg_label_prefix is None
            neg_label_dataset = None
        else:
            pos_label_dataset = make_indexed_dataset(
                pos_label_prefix, data_impl, skip_warmup
            )
            assert neg_label_prefix is not None
            neg_label_dataset = make_indexed_dataset(
                neg_label_prefix, data_impl, skip_warmup
            )
        if precompute_model_name is None:
            pos_ref_dataset = None
            neg_ref_dataset = None
        else:
            pos_ref_dataset = make_indexed_dataset(
                pos_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
            )
            neg_ref_dataset = make_indexed_dataset(
                neg_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
            )
    else:
        raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented")

    total_num_of_documents = (
        indexed_dataset.sizes.shape[0]
        if dataset_impl == "gpt2"
        else pos_indexed_dataset.sizes.shape[0]
    )
    print_rank_0(" {}:".format(name))
    print_rank_0(" no. of documents: {}".format(total_num_of_documents))
    dataset = None
    documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
    if dataset_impl == "gpt2":
        dataset = GPT2Dataset(
            name,
            data_prefix,
            documents,
            indexed_dataset,
            num_samples,
            num_epochs,
            seq_length,
            seed,
            pack_impl=pack_impl,
            allow_chopped=allow_chopped,
            build_index_mappings=build_index_mappings,
            label_dataset=label_dataset,
            reward_dataset=reward_dataset,
            ref_dataset=precompute_indexed_dataset,
        )
    elif dataset_impl == "pairwise":
        dataset = PairwiseDataset(
            name,
            pos_data_prefix,
            documents,
            pos_indexed_dataset,
            neg_indexed_dataset,
            num_samples,
            seq_length,
            seed,
            pack_impl=pack_impl,
            allow_chopped=allow_chopped,
            build_index_mappings=build_index_mappings,
            pos_label_dataset=pos_label_dataset,
            neg_label_dataset=neg_label_dataset,
            pos_ref_dataset=pos_ref_dataset,
            neg_ref_dataset=neg_ref_dataset,
        )
    return dataset
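# A minimal sketch of how this helper is invoked for a plain GPT2 split; the
# prefix and numeric values below are hypothetical, not taken from any config:
#
#     train_ds = build_the_dataset(
#         data_prefix="data/mycorpus_text_document",  # hypothetical prefix
#         pos_data_prefix=None,
#         neg_data_prefix=None,
#         name="train_0",
#         data_impl="mmap",
#         pack_impl="packed",
#         dataset_impl="gpt2",
#         allow_chopped=True,
#         num_samples=10000,
#         num_epochs=None,
#         seq_length=2048,
#         seed=1234,
#         skip_warmup=True,
#     )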


def build_train_valid_test_datasets(
    data_prefix,
    use_shared_fs,
    data_impl,
    pack_impl,
    allow_chopped,
    splits_string,
    train_valid_test_num_samples,
    train_valid_test_epochs,
    seq_length,
    seed,
    skip_warmup,
):
    """Build train, valid, and test datasets."""

    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    print_rank_0(" > dataset split:")

    def print_split_stats(name, index):
        print_rank_0(" {}:".format(name))
        print_rank_0(
            " document indices in [{}, {}) total of {} "
            "documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(
                start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
            )
            dataset = GPT2Dataset(
                name,
                data_prefix,
                documents,
                indexed_dataset,
                train_valid_test_num_samples[index],
                train_valid_test_epochs[index],
                seq_length,
                seed,
                pack_impl=pack_impl,
                allow_chopped=allow_chopped,
                use_shared_fs=use_shared_fs,
            )
        return dataset

    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")

    return train_dataset, valid_dataset, test_dataset


def get_train_valid_test_split_(splits_string, size):
    """Get dataset split indices from a comma- or '/'-separated string of weights."""

    splits = []
    if splits_string.find(",") != -1:
        splits = [float(s) for s in splits_string.split(",")]
    elif splits_string.find("/") != -1:
        splits = [float(s) for s in splits_string.split("/")]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.0)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] + int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
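# Worked example (verifiable by hand): "969,30,1" over 1000 documents normalizes
# to fractions [0.969, 0.03, 0.001] and yields the boundary indices below.
#
#     >>> get_train_valid_test_split_("969,30,1", 1000)
#     [0, 969, 999, 1000]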


def get_normalized_weights_and_num_samples(
    weights: List[float], num_samples: int
) -> Tuple[List[float], List[int]]:
    # Normalize the weights so they sum to one.
    weight_sum = sum(weights)
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]
    if num_samples is not None:
        # Add 0.5% head-room (the 1.005 factor) so that if the blended dataset
        # does not draw samples exactly in proportion to the weights, each
        # component dataset still has enough samples left at the end.
        weighted_num_samples = []
        for weight in weights:
            weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
    else:
        weighted_num_samples = [None for _ in weights]
    return weights, weighted_num_samples
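# Worked example (illustrative): weights [2.0, 1.0, 1.0] over 1000 requested
# samples normalize to [0.5, 0.25, 0.25]; with the 0.5% head-room each component
# requests ceil(1000 * w * 1.005) samples, i.e. [503, 252, 252].
#
#     >>> get_normalized_weights_and_num_samples([2.0, 1.0, 1.0], 1000)
#     ([0.5, 0.25, 0.25], [503, 252, 252])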


def build_weighted_datasets(
    neox_args,
    train_num_samples,
    valid_num_samples,
    test_num_samples,
    train_epochs,
    valid_epochs,
    test_epochs,
    build_index_mappings=True,
):
    """Build per-path train, valid, and test datasets.

    The configured data paths (plain, label, reward, and pairwise pos/neg
    variants) are zipped together; missing entries are padded with None by
    zip_longest.
    """
    train_datasets, valid_datasets, test_datasets = [], [], []
    for i, (
        train_path,
        train_label_path,
        train_reward_path,
        valid_path,
        valid_label_path,
        valid_reward_path,
        test_path,
        test_label_path,
        test_reward_path,
        pos_train_path,
        neg_train_path,
        pos_train_label_path,
        neg_train_label_path,
        pos_valid_path,
        neg_valid_path,
        pos_valid_label_path,
        neg_valid_label_path,
        pos_test_path,
        neg_test_path,
        pos_test_label_path,
        neg_test_label_path,
    ) in enumerate(
        zip_longest(
            neox_args.train_data_paths if neox_args.train_data_paths else [],
            neox_args.train_label_data_paths if neox_args.train_label_data_paths else [],
            neox_args.train_reward_data_paths if neox_args.train_reward_data_paths else [],
            neox_args.valid_data_paths if neox_args.valid_data_paths else [],
            neox_args.valid_label_data_paths if neox_args.valid_label_data_paths else [],
            neox_args.valid_reward_data_paths if neox_args.valid_reward_data_paths else [],
            neox_args.test_data_paths if neox_args.test_data_paths else [],
            neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
            neox_args.test_reward_data_paths if neox_args.test_reward_data_paths else [],
            neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
            neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
            neox_args.pos_train_label_data_paths if neox_args.pos_train_label_data_paths else [],
            neox_args.neg_train_label_data_paths if neox_args.neg_train_label_data_paths else [],
            neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [],
            neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [],
            neox_args.pos_valid_label_data_paths if neox_args.pos_valid_label_data_paths else [],
            neox_args.neg_valid_label_data_paths if neox_args.neg_valid_label_data_paths else [],
            neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [],
            neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [],
            neox_args.pos_test_label_data_paths if neox_args.pos_test_label_data_paths else [],
            neox_args.neg_test_label_data_paths if neox_args.neg_test_label_data_paths else [],
        )
    ):
        if train_path or pos_train_path:
            train_datasets.append(
                build_the_dataset(
                    data_prefix=train_path,
                    name=f"train_{i}",
                    data_impl=neox_args.data_impl,
                    pack_impl=neox_args.pack_impl,
                    allow_chopped=neox_args.allow_chopped,
                    num_samples=train_num_samples[i],
                    num_epochs=train_epochs,
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=(not neox_args.mmap_warmup),
                    build_index_mappings=build_index_mappings,
                    label_prefix=train_label_path,
                    dataset_impl=neox_args.dataset_impl,
                    pos_data_prefix=pos_train_path,
                    neg_data_prefix=neg_train_path,
                    pos_label_prefix=pos_train_label_path,
                    neg_label_prefix=neg_train_label_path,
                    precompute_model_name=neox_args.precompute_model_name,
                    reward_prefix=train_reward_path,
                )
            )

        if valid_path or pos_valid_path:
            valid_datasets.append(
                build_the_dataset(
                    data_prefix=valid_path,
                    name=f"valid_{i}",
                    data_impl=neox_args.data_impl,
                    pack_impl=neox_args.pack_impl,
                    allow_chopped=neox_args.allow_chopped,
                    num_samples=valid_num_samples[i],
                    num_epochs=valid_epochs,
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=(not neox_args.mmap_warmup),
                    build_index_mappings=build_index_mappings,
                    label_prefix=valid_label_path,
                    dataset_impl=neox_args.dataset_impl,
                    pos_data_prefix=pos_valid_path,
                    neg_data_prefix=neg_valid_path,
                    pos_label_prefix=pos_valid_label_path,
                    neg_label_prefix=neg_valid_label_path,
                    precompute_model_name=neox_args.precompute_model_name,
                    reward_prefix=valid_reward_path,
                )
            )

        if test_path or pos_test_path:
            test_datasets.append(
                build_the_dataset(
                    data_prefix=test_path,
                    name=f"test_{i}",
                    data_impl=neox_args.data_impl,
                    pack_impl=neox_args.pack_impl,
                    allow_chopped=neox_args.allow_chopped,
                    num_samples=test_num_samples[i],
                    num_epochs=test_epochs,
                    seq_length=neox_args.seq_length,
                    seed=neox_args.seed,
                    skip_warmup=(not neox_args.mmap_warmup),
                    build_index_mappings=build_index_mappings,
                    label_prefix=test_label_path,
                    dataset_impl=neox_args.dataset_impl,
                    pos_data_prefix=pos_test_path,
                    neg_data_prefix=neg_test_path,
                    pos_label_prefix=pos_test_label_path,
                    neg_label_prefix=neg_test_label_path,
                    precompute_model_name=neox_args.precompute_model_name,
                    reward_prefix=test_reward_path,
                )
            )
    return train_datasets, valid_datasets, test_datasets


def weights_by_num_docs(l: list, alpha=0.3):
    """
    Builds weights from a multinomial distribution over groups of data according to the number of
    samples in each group.

    We sample from a group according to the probability p(L) ∝ |L| ** α,
    where p(L) is the probability of sampling from a given group,
    |L| is the number of examples in that group,
    and α is a coefficient that acts to upsample data from underrepresented groups.

    Hence α (`alpha`) allows us to control how much to 'boost' the probability of training on low-resource groups.

    See https://arxiv.org/abs/1911.02116 for more details.
    """
    if len(l) == 1:
        return [1.0]

    total_n_docs = sum(l)
    unbiased_sample_probs = [i / total_n_docs for i in l]

    probs = [i**alpha for i in unbiased_sample_probs]

    # Normalize.
    total = sum(probs)
    probs = [i / total for i in probs]

    # Scale by (1 - p), the complement of each group's unbiased probability,
    # which further down-weights large groups.
    unbiased_sample_probs_inverse = [1 - p for p in unbiased_sample_probs]
    weights = [p * p2 for p, p2 in zip(probs, unbiased_sample_probs_inverse)]

    # Normalize again so the weights sum to one.
    total = sum(weights)
    weights = [i / total for i in weights]

    return weights
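# Worked example (approximate, for illustration): with document counts [100, 900]
# and alpha=0.3 the unbiased probabilities are [0.1, 0.9]; raising to the 0.3
# power, normalizing, scaling by (1 - p), and renormalizing gives weights of
# roughly [0.82, 0.18], i.e. the smaller group is strongly upsampled.
#
#     >>> [round(w, 2) for w in weights_by_num_docs([100, 900], alpha=0.3)]
#     [0.82, 0.18]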


def validate_train_epochs(neox_args):
    """Check for unsupported neox_args when using train_epochs instead of train_iters."""
    if neox_args.train_epochs is None:
        return

    if neox_args.train_epochs and neox_args.train_iters:
        raise ValueError(
            "Cannot specify both train epochs and train iters simultaneously"
        )

    if neox_args.pack_impl != "packed":
        raise ValueError(
            "Packing implementations other than 'packed' are currently unsupported with train_epochs"
        )

    if neox_args.weight_by_num_documents:
        raise ValueError(
            "Weighting by number of documents is currently unsupported with train_epochs"
        )

    if neox_args.train_data_weights and (
        not all(weight == 1.0 for weight in neox_args.train_data_weights)
    ):
        raise ValueError(
            "train_data_weights != None is currently unsupported with train_epochs"
        )

    if neox_args.dataset_impl != "gpt2":
        raise ValueError(
            "Non-gpt2 datasets are currently unsupported with train_epochs"
        )


def build_train_valid_test_data_loaders(neox_args):
    """Build train, validation, and test dataloaders."""

    validate_train_epochs(neox_args)

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0("> building train, validation, and test datasets ...")

    # With pipeline parallelism, only the first and last stages need data.
    if neox_args.is_pipe_parallel:
        is_first_stage = mpu.get_pipe_parallel_rank() == 0
        is_last_stage = (
            mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1
        )
        pipe_load = is_first_stage or is_last_stage
    else:
        pipe_load = True

    # Build the datasets only on rank 0 of each model-parallel group.
    if mpu.get_model_parallel_rank() == 0 and pipe_load:
        # Number of train/valid/test samples (or epochs, when training by epochs).
        if neox_args.train_iters is not None:
            train_iters = neox_args.train_iters
            eval_iters = (
                train_iters // neox_args.eval_interval + 1
            ) * neox_args.eval_iters
            test_iters = neox_args.eval_iters
            # e.g. train_iters=1000, eval_interval=100, eval_iters=10, and
            # train_batch_size=32 give 32000 train, 3520 eval, and 320 test samples.
            train_val_test_num_samples = [
                train_iters * neox_args.train_batch_size,
                eval_iters * neox_args.train_batch_size,
                test_iters * neox_args.train_batch_size,
            ]
            train_val_test_epochs = [None, None, None]
        elif neox_args.train_epochs is not None:
            train_val_test_num_samples = [None, None, None]
            train_val_test_epochs = [1, 1, 1]

        if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
            # One dataset per configured path, blended according to the data weights.
            train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                neox_args.train_data_weights, train_val_test_num_samples[0]
            )
            valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                neox_args.valid_data_weights, train_val_test_num_samples[1]
            )
            test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                neox_args.test_data_weights, train_val_test_num_samples[2]
            )

            train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                neox_args,
                train_num_samples,
                valid_num_samples,
                test_num_samples,
                train_val_test_epochs[0],
                train_val_test_epochs[1],
                train_val_test_epochs[2],
                build_index_mappings=not neox_args.weight_by_num_documents,
            )

            if neox_args.weight_by_num_documents:
                # Re-derive the blending weights from the number of documents
                # in each component dataset.
                get_num_docs_list = lambda datasets: [
                    dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
                ]
                train_num_docs, valid_num_docs, test_num_docs = (
                    get_num_docs_list(train_datasets),
                    get_num_docs_list(valid_datasets),
                    get_num_docs_list(test_datasets),
                )

                fn = partial(
                    weights_by_num_docs, alpha=neox_args.weighted_sampler_alpha
                )
                train_weights, valid_weights, test_weights = (
                    fn(train_num_docs),
                    fn(valid_num_docs),
                    fn(test_num_docs),
                )
                (
                    train_weights,
                    train_num_samples,
                ) = get_normalized_weights_and_num_samples(
                    train_weights, train_val_test_num_samples[0]
                )
                (
                    valid_weights,
                    valid_num_samples,
                ) = get_normalized_weights_and_num_samples(
                    valid_weights, train_val_test_num_samples[1]
                )
                test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                    test_weights, train_val_test_num_samples[2]
                )

                # Rebuild the datasets (with index mappings) using the new sample counts.
                train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                    neox_args,
                    train_num_samples,
                    valid_num_samples,
                    test_num_samples,
                    train_val_test_epochs[0],
                    train_val_test_epochs[1],
                    train_val_test_epochs[2],
                )

            if train_datasets:
                train_ds = BlendableDataset(train_datasets, train_weights)
            if valid_datasets:
                valid_ds = BlendableDataset(valid_datasets, valid_weights)
            if test_datasets:
                test_ds = BlendableDataset(test_datasets, test_weights)
        else:
            # Single data path: build the splits from neox_args.data_path.
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
                data_prefix=neox_args.data_path,
                use_shared_fs=neox_args.use_shared_fs,
                data_impl=neox_args.data_impl,
                splits_string=neox_args.split,
                train_valid_test_num_samples=train_val_test_num_samples,
                train_valid_test_epochs=train_val_test_epochs,
                seq_length=neox_args.seq_length,
                seed=neox_args.seed,
                skip_warmup=(not neox_args.mmap_warmup),
                pack_impl=neox_args.pack_impl,
                allow_chopped=neox_args.allow_chopped,
            )

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
        valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
        test_dataloader = make_data_loader(test_ds, neox_args=neox_args)

        # Flags to indicate whether we should train, validate, and test.
        if neox_args.train_epochs:
            do_train = train_dataloader is not None
            do_valid = valid_dataloader is not None
            do_test = test_dataloader is not None
        else:
            do_train = train_dataloader is not None and neox_args.train_iters > 0
            do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
            do_test = test_dataloader is not None and neox_args.eval_iters > 0

        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags so every rank agrees on what to run.
    if neox_args.is_pipe_parallel:
        # With pipeline parallelism, broadcast globally from rank 0.
        torch.distributed.broadcast(flags, src=0)
    else:
        torch.distributed.broadcast(
            flags,
            mpu.get_model_parallel_src_rank(),
            group=mpu.get_model_parallel_group(),
        )
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()
    data_loaders = {
        "train": train_dataloader,
        "valid": valid_dataloader,
        "test": test_dataloader,
    }
    return data_loaders


def shift_and_wrap_data_loaders(neox_args, data_loaders, loop=True):
    """Shift start iteration and wrap data_loaders in iterators."""
    train_dataloader = data_loaders["train"]
    valid_dataloader = data_loaders["valid"]
    test_dataloader = data_loaders["test"]

    # Shift the start iterations so that resuming from a checkpoint continues
    # from the right position in the data; e.g. iteration=150 with
    # gradient_accumulation_steps=4 and 500 batches per epoch resumes at
    # batch (150 * 4) % 500 = 100.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = (
            neox_args.iteration * neox_args.gradient_accumulation_steps
        ) % len(train_dataloader)
        print_rank_0(
            "setting training data start iteration to {}".format(
                train_dataloader.batch_sampler.start_iter
            )
        )
    if valid_dataloader is not None:
        start_iter_val = (
            (neox_args.iteration * neox_args.gradient_accumulation_steps)
            // neox_args.eval_interval
        ) * neox_args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % len(
            valid_dataloader
        )
        print_rank_0(
            "setting validation data start iteration to {}".format(
                valid_dataloader.batch_sampler.start_iter
            )
        )

    def loop_iterator(data_loader):
        # Restart the dataloader from iteration 0 once it is exhausted.
        while True:
            for x in data_loader:
                yield x
            data_loader.start_iter = 0

    # Wrap the dataloaders: cycle endlessly when loop=True, otherwise iterate once.
    if train_dataloader is not None:
        if loop:
            train_data_iterator = cycle(train_dataloader)
        else:
            train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        if loop:
            valid_data_iterator = cycle(valid_dataloader)
        else:
            valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        if loop:
            test_data_iterator = cycle(test_dataloader)
        else:
            test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
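# Putting the two helpers together (a sketch; assumes a fully populated NeoXArgs
# instance named `neox_args`, as set up by the training entry point):
#
#     data_loaders = build_train_valid_test_data_loaders(neox_args)
#     train_iter, valid_iter, test_iter = shift_and_wrap_data_loaders(
#         neox_args, data_loaders
#     )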


def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess

    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(["make", "-C", path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys

        sys.exit(1)
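# Illustrative usage (a sketch, not part of this module): build the helpers once
# from a single process and make the other ranks wait, e.g.
#
#     if torch.distributed.get_rank() == 0:
#         compile_helper()
#     torch.distributed.barrier()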