File size: 1,750 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
from pathlib import Path
from typing import List

import torchaudio
from joblib import Parallel, delayed
from tqdm import tqdm

# Directory used for cached audio metadata unless overridden via set_cache_dir().
_default_cache_dir = Path.home().joinpath(".cache", "s3prl", "audio_info")


# Names exported by `from ... import *`.
__all__ = ["get_cache_dir", "set_cache_dir", "get_audio_info"]


def get_cache_dir():
    """
    Return the directory in which audio metadata is cached.

    The directory (including any missing parents) is created before it is
    returned, so callers can rely on it existing.
    """
    cache_root = _default_cache_dir
    cache_root.mkdir(parents=True, exist_ok=True)
    return cache_root


def set_cache_dir(cache_dir: str):
    """
    Change the directory returned by :code:`get_cache_dir`.

    Args:
        cache_dir (str): the new cache directory. It is not created here;
            :code:`get_cache_dir` creates it on its next call.
    """
    global _default_cache_dir
    new_cache_dir = Path(cache_dir)
    _default_cache_dir = new_cache_dir


def get_audio_info(
    audio_paths: List[str],
    audio_ids: List[str],
    cache_dir: str = None,
    num_workers: int = 6,
) -> List[dict]:
    """
    Use :code:`torchaudio.info` to retrieve the metadata from audio paths.

    The metadata of each audio is cached as :code:`{cache_dir}/{audio_id}.json`.
    When that file already exists it is read instead of opening the audio, so
    the cache is keyed on :code:`audio_id` alone: an id must always refer to the
    same audio content, otherwise stale metadata is returned.

    Args:
        audio_paths (List[str]): paths of the audio files to inspect
        audio_ids (List[str]): one id per path, used as the cache file name
        cache_dir (str): directory of the cache files. Defaults to
            :code:`get_cache_dir()`; created if it does not exist
        num_workers (int): :code:`n_jobs` passed to :code:`joblib.Parallel`

    Returns:
        List[dict]: one dict per audio, in the same order as :code:`audio_paths`,
        with the keys ``sample_rate``, ``num_frames``, ``num_channels``,
        ``bits_per_sample`` and ``encoding``

    Raises:
        ValueError: if :code:`audio_paths` and :code:`audio_ids` have different
            lengths
    """
    import contextlib
    import os
    import tempfile

    # Materialise the inputs so their lengths can be compared and reported.
    audio_paths = list(audio_paths)
    audio_ids = list(audio_ids)
    if len(audio_paths) != len(audio_ids):
        # zip() below would otherwise silently drop the tail of the longer list.
        raise ValueError(
            "audio_paths and audio_ids must have the same length, "
            f"got {len(audio_paths)} and {len(audio_ids)}"
        )

    cache_dir = Path(cache_dir or get_cache_dir())
    # get_cache_dir() creates its own directory, but a caller-supplied one may
    # not exist yet.
    cache_dir.mkdir(parents=True, exist_ok=True)

    def _save_cache(cache_file: Path, info: dict):
        # Write to a unique temporary file, then rename it into place. os.replace
        # is atomic, so no reader ever sees a half-written JSON file. Caching is
        # best-effort: an unwritable cache must not prevent returning metadata.
        tmp_name = None
        try:
            # audio_id may contain path separators, placing the file in a subdir.
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            fd, tmp_name = tempfile.mkstemp(
                dir=str(cache_file.parent), suffix=".tmp"
            )
            with os.fdopen(fd, "w") as f:
                json.dump(info, f)
            os.replace(tmp_name, str(cache_file))
        except OSError:
            if tmp_name is not None:
                with contextlib.suppress(OSError):
                    os.unlink(tmp_name)

    def _get_info(audio_path: str, audio_id: str) -> dict:
        cache_file = cache_dir / f"{audio_id}.json"
        if cache_file.is_file():
            try:
                with cache_file.open() as f:
                    return json.load(f)
            except json.JSONDecodeError:
                pass  # corrupt cache entry: recompute and overwrite it below

        # NOTE(review): set per call, presumably because joblib workers may be
        # separate processes that do not share the parent's backend setting.
        # set_audio_backend is deprecated in recent torchaudio — confirm.
        torchaudio.set_audio_backend("sox_io")
        torchaudio_info = torchaudio.info(audio_path)
        info = {
            "sample_rate": torchaudio_info.sample_rate,
            "num_frames": torchaudio_info.num_frames,
            "num_channels": torchaudio_info.num_channels,
            "bits_per_sample": torchaudio_info.bits_per_sample,
            "encoding": torchaudio_info.encoding,
        }
        _save_cache(cache_file, info)
        return info

    infos = Parallel(n_jobs=num_workers)(
        delayed(_get_info)(path, idx)
        for path, idx in tqdm(
            zip(audio_paths, audio_ids),
            total=len(audio_paths),
            desc="Get audio metadata",
        )
    )
    return infos