# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
import shutil
import sys
import tempfile

from torchvision.datasets import MNIST
# Root directory handed to the MNIST dataset constructor; using the system
# temp directory lets repeated runs on the same machine reuse the download.
TEMPDIR = tempfile.gettempdir()
def setup_cached_mnist():
    """Make sure the MNIST dataset is available under ``TEMPDIR``.

    Instantiates ``MNIST`` with ``download=True``, which skips the download
    if the dataset is already there and checks the checksum. If that raises
    ``RuntimeError`` the ``MNIST`` folder is erased and the download is
    retried, up to 5 attempts in total.

    Raises:
        SystemExit: with code ``-1`` when every attempt failed.
    """
    max_attempts = 5

    # Monkey patch the resource URLs to work around a possible blacklist.
    # Done once, and only if missing, so that retries and repeated calls do
    # not keep growing the class-level mirror list with duplicates.
    extra_mirror = "https://github.com/blefaudeux/mnist_dataset/raw/main/"
    if extra_mirror not in MNIST.mirrors:
        MNIST.mirrors = [extra_mirror] + MNIST.mirrors

    mnist_root = Path(TEMPDIR) / "MNIST"

    for _ in range(max_attempts):
        try:
            MNIST(transform=None, download=True, root=TEMPDIR)
        except RuntimeError as e:
            logging.warning(e)
            # Corrupted data, erase and restart. The folder may not exist if
            # the failure happened before anything was written, so a missing
            # directory must not abort the retry loop.
            shutil.rmtree(str(mnist_root), ignore_errors=True)
        else:
            logging.info("Dataset downloaded")
            return

    logging.error("Could not download MNIST dataset")
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive shell and is not guaranteed to exist.
    sys.exit(-1)