File size: 1,386 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import logging
import os
import time
from multiprocessing import Process
from pathlib import Path
from typing import Optional

import torch

from s3prl.util.download import _urls_to_filepaths

# Logger for progress messages emitted by the tests in this module.
logger = logging.getLogger(__name__)

# Checkpoint used as the download target throughout this module.
URL = (
    "https://dl.fbaipublicfiles.com/librilight/"
    "CPC_checkpoints/60k_epoch4-d0f474de.pt"
)


def _download_with_timeout(
    timeout: Optional[float], num_process: int
) -> Optional[int]:
    """Run ``num_process`` concurrent downloads of ``URL`` under one timeout.

    Every process calls ``_urls_to_filepaths(URL, refresh=True)`` on the same
    URL, so the downloads race for the same target file.

    Args:
        timeout: Overall wall-clock budget in seconds shared by all
            processes, or ``None`` to wait until every process finishes.
        num_process: Number of downloader processes to start.

    Returns:
        The exit code shared by all processes, as observed when the budget
        ran out: ``0`` on success, ``None`` if they were still running (they
        are then terminated), or a non-zero code otherwise.

    Raises:
        AssertionError: If the processes did not all end in the same state.
    """
    processes = [
        Process(target=_urls_to_filepaths, args=(URL,), kwargs=dict(refresh=True))
        for _ in range(num_process)
    ]
    for process in processes:
        process.start()

    deadline = None if timeout is None else time.monotonic() + timeout
    exitcodes = []
    try:
        for process in processes:
            # Join against a single shared deadline. Joining each process with
            # the full ``timeout`` would let later processes run for up to
            # ``num_process * timeout`` seconds in total.
            if deadline is None:
                remaining = None
            else:
                remaining = max(0.0, deadline - time.monotonic())
            process.join(timeout=remaining)
            # ``exitcode`` is None while the process is still running.
            exitcodes.append(process.exitcode)
    finally:
        # Always reap stragglers, even if joining raised. terminate() only
        # sends the signal, so join() afterwards waits until the process is
        # really gone before the caller inspects the download target.
        for process in processes:
            if process.is_alive():
                process.terminate()
            process.join()

    assert len(set(exitcodes)) == 1, f"Processes ended inconsistently: {exitcodes}"
    return exitcodes[0]


def test_download():
    """Check that a timed-out download leaves no file and a full one loads.

    The test first deletes any existing file at the path reported by
    ``_urls_to_filepaths(URL, download=False)``. It then runs two concurrent
    downloads with a timeout that is too short to finish, and asserts that no
    (partial) file was left behind. Finally it downloads again without a
    timeout and loads the result with ``torch.load`` to confirm the file is
    complete.
    """
    filepath = Path(_urls_to_filepaths(URL, download=False))
    # Start clean so the timed-out run cannot be masked by an earlier copy.
    if filepath.is_file():
        os.remove(filepath)

    short_timeout = 0.1
    logger.info("This should time out")
    _download_with_timeout(short_timeout, 2)
    assert not filepath.is_file(), (
        f"The download should fail because the {short_timeout} sec timeout is "
        "too short, so no corrupted (incomplete) file should be left behind"
    )

    logger.info("This should succeed")
    _download_with_timeout(None, 2)
    assert filepath.is_file(), (
        f"The download without a timeout did not produce {filepath}"
    )
    torch.load(filepath, map_location="cpu")