File size: 1,386 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import logging
import os
import time
from multiprocessing import Process
from pathlib import Path
from typing import Optional

import torch

from s3prl.util.download import _urls_to_filepaths
# Module-level logger; used to announce each phase of the download test.
logger = logging.getLogger(name=__name__)
# Remote checkpoint that every download in this module fetches.
URL = "https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt"
def _download_with_timeout(timeout: Optional[float], num_process: int):
    """Run ``num_process`` concurrent downloads of ``URL``, stopping them on timeout.

    Each worker process calls ``_urls_to_filepaths(URL, refresh=True)``.
    All workers share a single deadline of ``timeout`` seconds, measured from
    the moment the last worker was started; ``None`` waits indefinitely.
    Before returning (or raising), every worker that is still running is
    terminated, and all workers are joined. Callers can therefore inspect
    the filesystem without racing a process that is still shutting down.

    Args:
        timeout: Wall-clock budget in seconds shared by all workers, or
            ``None`` for no limit.
        num_process: Number of download processes to start concurrently.

    Raises:
        AssertionError: If the workers did not all end with the same exit
            code (``None`` meaning "still running at the deadline").
    """
    processes = []
    try:
        for _ in range(num_process):
            process = Process(
                target=_urls_to_filepaths, args=(URL,), kwargs=dict(refresh=True)
            )
            process.start()
            # Only started processes are recorded, so cleanup below never
            # joins a process that was never started.
            processes.append(process)

        # One shared deadline: joining sequentially with the raw timeout would
        # give the k-th worker up to k * timeout seconds.
        deadline = None if timeout is None else time.monotonic() + timeout
        exitcodes = []
        for process in processes:
            remaining = (
                None if deadline is None else max(0.0, deadline - time.monotonic())
            )
            process.join(timeout=remaining)
            # exitcode is None if the worker is still running at the deadline.
            exitcodes.append(process.exitcode)

        assert len(set(exitcodes)) == 1, (
            f"download workers ended in different states: {exitcodes}"
        )
    finally:
        # Kill anything still running (e.g. timed out) and reap every worker,
        # so no process keeps writing to the cache after we return.
        for process in processes:
            if process.is_alive():
                process.terminate()
            process.join()
def test_download():
    """Check that an interrupted concurrent download leaves no file, then that a full one loads.

    The checkpoint's cache location is resolved with
    ``_urls_to_filepaths(URL, download=False)`` (presumably without fetching
    anything; confirm against s3prl). Any existing copy is deleted so the test
    starts from an empty cache.

    Phase 1 starts two downloads with a 0.1 s timeout, which is expected to be
    far too short. Afterwards no checkpoint file may exist at the cache path,
    i.e. an interrupted download must not leave an incomplete file behind.

    Phase 2 starts two downloads with no timeout and then requires the
    resulting file to be readable by ``torch.load``.
    """
    checkpoint = Path(_urls_to_filepaths(URL, download=False))
    if checkpoint.is_file():
        # Path.unlink is the pathlib spelling of os.remove.
        checkpoint.unlink()

    # Phase 1: a download that must be cut short.
    logger.info("This should timeout")
    _download_with_timeout(0.1, 2)
    assert not checkpoint.is_file(), (
        "The download should failed due to the too short timeout second: 0.1 sec, "
        "and hence there should not be any corrupted (incomplete) file"
    )

    # Phase 2: an unrestricted download must yield a loadable checkpoint.
    logger.info("This should success")
    _download_with_timeout(None, 2)
    torch.load(checkpoint, map_location="cpu")
|