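"""Dataset utilities for the language-model benchmark.

Builds the WikiText-2 ("real") and synthetic token datasets and wraps them in
the DistributedSampler-backed DataLoaders consumed by the benchmark.
"""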
from collections import namedtuple
from distutils.version import LooseVersion
import io
import operator
import tempfile

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive

# torchtext >= 0.10 keeps the older Vocab-based build_vocab_from_iterator
# (the variant exposing .stoi, used below) under the legacy namespace.
if operator.ge(torchtext.__version__, LooseVersion("0.10.0")):
    from torchtext.legacy.vocab import build_vocab_from_iterator
else:
    from torchtext.vocab import build_vocab_from_iterator


def _batchify(data, batch_size):
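    """Trim ``data`` to a multiple of ``batch_size`` and lay it out column-wise.

    Returns a tensor of shape (len(data) // batch_size, batch_size) where each
    column is a contiguous stream of tokens.
    """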
    data = torch.tensor(data)
    # Work out how many full batches the data can be split into.
    nbatch = data.size(0) // batch_size
    # Trim off any leftover elements that would not fill a batch.
    data = data.narrow(0, 0, nbatch * batch_size)
    # Divide the data evenly across batch_size columns.
    data = data.view(batch_size, -1).t().contiguous()
    return data


def _get_total_batch_size(benchmark_config, model_specs):
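    """Number of tokens drawn per DataLoader batch: one full sequence per sample in the batch."""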
    return model_specs["seq_len"] * benchmark_config["batch_size"]


# Bundle of the vocabulary size and the three token datasets returned by the constructors below.
DatasetsInfo = namedtuple("DatasetsInfo", ["ntokens", "train_dataset", "valid_dataset", "test_dataset"])


def get_real_datasets():
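    """Download WikiText-2, build a vocabulary from the training split, and return tokenized datasets."""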
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
    tmpdir = tempfile.TemporaryDirectory()
    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name))
    tokenizer = get_tokenizer("basic_english")

    def data_process(raw_text_iter):
        # Map each line to a tensor of token ids, drop empty lines, and flatten into one sequence.
        # Note: closes over `vocab`, which is built below before the first call.
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8"))))

    train_dataset = data_process(iter(io.open(train_filepath, encoding="utf8")))
    valid_dataset = data_process(iter(io.open(valid_filepath, encoding="utf8")))
    test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))
    return DatasetsInfo(len(vocab.stoi), train_dataset, valid_dataset, test_dataset)


def get_dataloaders(datasets_info, benchmark_config, model_specs, num_replicas=1, rank=0):
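    """Wrap the train/valid/test datasets in DistributedSampler-backed DataLoaders."""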
    ntokens, train_dataset, valid_dataset, test_dataset = datasets_info

    def batchify(data):
        batch_size = benchmark_config["batch_size"]
        return _batchify(data, batch_size)

    # Each DataLoader batch holds seq_len * batch_size tokens, which batchify reshapes into columns.
    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=DistributedSampler(train_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=DistributedSampler(valid_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    test_dataloader = DataLoader(
        test_dataset,
        sampler=DistributedSampler(test_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    return train_dataloader, valid_dataloader, test_dataloader


def get_real_dataloaders(args, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Return real dataloaders for training, testing and validation."""
    dataset_info = get_real_datasets()
    train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(
        dataset_info, benchmark_config, model_specs, num_replicas, rank
    )
    return dataset_info.ntokens, train_dataloader, valid_dataloader, test_dataloader


def get_synthetic_datasets():
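    """Return a random token stream (vocab size 10000) reused for the train, valid and test splits."""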
    lm_dataset = torch.randint(1, 10000, (2049990,))
    return DatasetsInfo(10000, lm_dataset, lm_dataset, lm_dataset)


def get_synthetic_dataloaders(args, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Return synthetic dataloaders for training, testing and validation."""
    return get_dataloaders(get_synthetic_datasets(), benchmark_config, model_specs, num_replicas, rank)
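

# Minimal usage sketch (hypothetical config values; the benchmark driver supplies
# its own benchmark_config / model_specs dicts):
#
#   benchmark_config = {"batch_size": 8}
#   model_specs = {"seq_len": 32}
#   train_dl, valid_dl, test_dl = get_synthetic_dataloaders(None, benchmark_config, model_specs)
#   first = next(iter(train_dl))  # shape: (seq_len, batch_size) == (32, 8)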