File size: 1,882 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Create pseudo data

Authors
  * Leo 2022
"""

import random
import shutil
import tempfile
from pathlib import Path
from typing import Any, List

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

SAMPLE_RATE = 16000

__all__ = [
    "pseudo_audio",
    "get_pseudo_wavs",
]


class pseudo_audio:
    """
    This context manager returns filepaths (List[str]) and num_samples (List[int]) on entering
    """

    def __init__(self, secs: List[float], sample_rate: int = SAMPLE_RATE):
        self.tempdir = Path(tempfile.TemporaryDirectory().name)
        self.tempdir.mkdir(parents=True, exist_ok=True)
        self.num_samples = []
        for n, sec in enumerate(secs):
            wav = torch.randn(1, round(sample_rate * sec))
            torchaudio.save(
                str(self.tempdir / f"audio_{n}.wav"), wav, sample_rate=sample_rate
            )
            self.num_samples.append(wav.size(-1))
        self.filepaths = [
            str(self.tempdir / f"audio_{i}.wav") for i in range(len(secs))
        ]

    def __enter__(self):
        return self.filepaths, self.num_samples

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        shutil.rmtree(self.tempdir)


def get_pseudo_wavs(
    seed: int = 0,
    n: int = 2,
    min_secs: int = 1,
    max_secs: int = 3,
    sample_rate: int = SAMPLE_RATE,
    device: str = "cpu",
    padded: bool = False,
):
    random.seed(seed)
    torch.manual_seed(seed)

    wavs = []
    wavs_len = []
    for _ in range(n):
        wav_length = random.randint(min_secs * sample_rate, max_secs * sample_rate)
        wav = torch.randn(wav_length, requires_grad=True).to(device)
        wavs_len.append(wav_length)
        wavs.append(wav)

    if not padded:
        return wavs
    else:
        return pad_sequence(wavs, batch_first=True), torch.LongTensor(wavs_len)