asgeorges's picture
Upload folder using huggingface_hub
2fdce3c
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
import shutil
import tempfile
from torchvision.datasets import MNIST
# Root directory under which torchvision caches the MNIST dataset.
TEMPDIR = tempfile.gettempdir()


def setup_cached_mnist(max_attempts: int = 5) -> None:
    """Make sure the MNIST dataset is downloaded and cached under ``TEMPDIR``.

    The download is retried up to ``max_attempts`` times. After each failed
    attempt (a ``RuntimeError`` from torchvision), the ``<TEMPDIR>/MNIST``
    directory is wiped so that the next attempt starts from a clean state.

    Args:
        max_attempts: maximum number of download attempts (default 5).

    Raises:
        SystemExit: with code -1 if every attempt failed.
    """
    import sys

    # Monkey patch the resource URLs to work around a possible blacklist.
    # The mirror is put first exactly once, so repeated attempts or calls do
    # not keep growing the class-level mirror list with duplicates.
    mirror = "https://github.com/blefaudeux/mnist_dataset/raw/main/"
    MNIST.mirrors = [mirror] + [m for m in MNIST.mirrors if m != mirror]

    mnist_root = Path(TEMPDIR) / "MNIST"

    for _attempt in range(max_attempts):
        try:
            # This will automatically skip the download if the dataset is
            # already there, and check the checksum
            MNIST(transform=None, download=True, root=TEMPDIR)
        except RuntimeError as e:
            logging.warning(e)
            # Corrupted data, erase and restart. The directory may not exist
            # if the failure happened before anything was written, so a
            # missing path must not abort the remaining attempts.
            shutil.rmtree(str(mnist_root), ignore_errors=True)
        else:
            logging.info("Dataset downloaded")
            return

    logging.error("Could not download MNIST dataset")
    # sys.exit is always available, unlike the site-provided `exit` builtin;
    # both raise SystemExit(-1).
    sys.exit(-1)