# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
from distutils.version import LooseVersion
import io
import operator
import tempfile
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive

if operator.ge(torchtext.__version__, LooseVersion("0.10.0")):
    from torchtext.legacy.vocab import build_vocab_from_iterator
else:
    from torchtext.vocab import build_vocab_from_iterator

def _batchify(data, batch_size):
    data = torch.tensor(data)
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * batch_size)
    # Evenly divide the data across the bsz batches.
    data = data.view(batch_size, -1).t().contiguous()
    return data

def _get_total_batch_size(benchmark_config, model_specs):
    return model_specs["seq_len"] * benchmark_config["batch_size"]

DatasetsInfo = namedtuple("DatasetsInfo", ["ntokens", "train_dataset", "valid_dataset", "test_dataset"])

def get_real_datasets():
    """Download WikiText-2, tokenize it, and return the vocab size plus train/valid/test token tensors."""
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
    tmpdir = tempfile.TemporaryDirectory()
    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name))
    tokenizer = get_tokenizer("basic_english")
    vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8"))))

    def data_process(raw_text_iter):
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    train_dataset = data_process(iter(io.open(train_filepath, encoding="utf8")))
    valid_dataset = data_process(iter(io.open(valid_filepath, encoding="utf8")))
    test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))
    return DatasetsInfo(len(vocab.stoi), train_dataset, valid_dataset, test_dataset)

def get_dataloaders(datasets_info, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Wrap the given datasets in DataLoaders that collate flat token streams into (seq_len, batch_size) batches."""
    ntokens, train_dataset, valid_dataset, test_dataset = datasets_info

    def batchify(data):
        batch_size = benchmark_config["batch_size"]
        return _batchify(data, batch_size)

    total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=DistributedSampler(train_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=DistributedSampler(valid_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    test_dataloader = DataLoader(
        test_dataset,
        sampler=DistributedSampler(test_dataset, num_replicas=num_replicas, rank=rank),
        batch_size=total_batch_size,
        collate_fn=batchify,
    )
    return train_dataloader, valid_dataloader, test_dataloader

def get_real_dataloaders(args, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Return real dataloaders for training, testing and validation."""
    dataset_info = get_real_datasets()
    train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(
        dataset_info, benchmark_config, model_specs, num_replicas, rank
    )
    return dataset_info.ntokens, train_dataloader, valid_dataloader, test_dataloader

def get_synthetic_datasets():
    # vocab_size is 10000 and the length of the real data is 2049990.
    lm_dataset = torch.randint(1, 10000, (2049990,))
    return DatasetsInfo(10000, lm_dataset, lm_dataset, lm_dataset)

def get_synthetic_dataloaders(args, benchmark_config, model_specs, num_replicas=1, rank=0):
    """Return synthetic dataloaders for training, testing and validation."""
    return get_dataloaders(get_synthetic_datasets(), benchmark_config, model_specs, num_replicas, rank)
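

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative assumption, not part of the benchmark
    # harness): exercises the synthetic path end to end. The only config keys this
    # module reads are "batch_size" and "seq_len"; the values below are made up.
    benchmark_config = {"batch_size": 8}
    model_specs = {"seq_len": 32}
    train_dl, valid_dl, test_dl = get_synthetic_dataloaders(
        None, benchmark_config, model_specs, num_replicas=1, rank=0
    )
    batch = next(iter(train_dl))
    # Each full batch is collated by _batchify into shape (seq_len, batch_size).
    print(batch.shape)  # expected: torch.Size([32, 8])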