# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.
# %% auto 0
__all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
'readlines']
# %% ../nbs/D. Common dataset utilities.ipynb 1
import os
import random
import torch
import torchaudio
from pathlib import Path
import webdataset as wds
from contextlib import contextmanager
import torch.nn.functional as F
# %% ../nbs/D. Common dataset utilities.ipynb 2
def shard_glob(input):
    if isinstance(input, str) and '{' in input:
        return wds.shardlists.expand_urls(input)
    if isinstance(input, (Path, str)):
        path = Path(input)
        if path.is_dir():
            glob = '*.tar.gz'
        else:
            glob = path.name
            path = path.parent
        input = Path(path).glob(glob)
    else:
        raise ValueError("input should be a string or a Path, optionally with a glob specifier")
    return [str(x) for x in input]
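
# Illustrative usage sketch (not from the original notebook; the shard paths are hypothetical):
#
#   shard_glob('data/librilight-audio-{000000..000127}.tar')  # brace pattern, expanded by webdataset
#   shard_glob('data/')                                        # a directory globs all *.tar.gz files inside
#   shard_glob('data/librilight-audio-*.tar')                  # an explicit glob in the file name part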
# %% ../nbs/D. Common dataset utilities.ipynb 3
class join_datasets(torch.utils.data.IterableDataset):
    """Interleave samples from several datasets, drawing from each one with a probability
    proportional to its optional `weight` attribute (default 1)."""
    def __init__(self, datasets):
        self.datasets = datasets

    def __iter__(self):
        probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
        its = [iter(ds) for ds in self.datasets]
        while True:
            try:
                yield next(its[torch.multinomial(probs, 1)])
            except StopIteration:
                return

    def __len__(self):
        return sum([ds.total_samples for ds in self.datasets])
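
# Illustrative usage sketch (hypothetical datasets): every joined dataset is expected to expose
# a `total_samples` attribute (used by __len__) and may set `weight` to bias the sampling:
#
#   small_ds.weight = 3                          # draw from the smaller dataset 3x more often
#   joint = join_datasets([small_ds, large_ds])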
# %% ../nbs/D. Common dataset utilities.ipynb 5
def resampler(newsr = 24000, key = 'samples_24k'):
    _last_sr = None
    tform = None

    def _resample(samples):
        nonlocal _last_sr, tform
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                # only rebuild the Resample transform when the input sample rate changes
                if sr != _last_sr:
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    _last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s

    return _resample
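
# Illustrative usage sketch: resampler() returns a sample-stream filter, so it composes into a
# webdataset pipeline like the other filters in this file. It expects samples with 'samples'
# and 'sample_rate' keys, e.g. the chunks produced by split_to_chunks/vad_dataset below:
#
#   ds = vad_dataset(shards).compose(resampler(16000, 'samples_16k'))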
# %% ../nbs/D. Common dataset utilities.ipynb 6
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
    dir = Path(dir) if dir else Path(input).parent
    return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
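
# Illustrative example (hypothetical shard name): map an "audio" shard to its derived shard:
#
#   derived_name('data/librilight-audio-000000.tar', 'vad')
#   # -> 'data/librilight-vad-000000.tar.gz'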
# %% ../nbs/D. Common dataset utilities.ipynb 7
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
    def deriver(url):
        url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
        return wds.WebDataset(
            wds.SimpleShardList([url])
        ).decode(*decoders)
    return deriver
# %% ../nbs/D. Common dataset utilities.ipynb 8
def merge_in(dataset_fun):
    """Merge another dataset into the current one, returning samples with the union of keys.

    Pass in a function that takes the URL of a sample and returns a dataset for it
    (it is called every time the URL changes). Both datasets are required (and validated)
    to have the same ordering of keys, so this has to be applied before any sample
    shuffling. Shard shuffling is fine.
    """
    def merge_loop(main_samples):
        merged_samples = None
        cur_url = None
        for s in main_samples:
            url = s['__url__']
            if url != cur_url:
                # this opens a new file when we get the first sample with a new __url__
                merged_samples = iter(dataset_fun(url))
                cur_url = url
            try:
                merge_s = next(merged_samples)
            except StopIteration:
                # if the original shard got repeated we won't observe a __url__ change,
                # so restart the merged dataset from the beginning
                merged_samples = iter(dataset_fun(url))
                merge_s = next(merged_samples)
            assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
            news = {}
            news.update(merge_s)
            news.update(s)
            yield news
    return merge_loop
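
# Illustrative usage sketch, mirroring vad_dataset below (hypothetical shard list):
#
#   ds = wds.WebDataset(shards).compose(
#       wds.decode(wds.torch_audio),
#       merge_in(derived_dataset('vad')),  # add the keys from the matching *-vad-* shards
#   )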
# %% ../nbs/D. Common dataset utilities.ipynb 9
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
    for s in stream:
        audio, sr = s['audio']
        imax = len(s[ikey]) - 1
        for i,(ts,te) in enumerate(s[ikey]):
            samples = audio[0,int(ts*sr):int(te*sr)]
            padding = lpad = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr-samples.shape[-1]
                lpad = random.randint(0, padding) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            subs = {"__key__": s['__key__'] + f"_{i:03d}",
                    "src_key": s['__key__'],
                    "__url__": s['__url__'],
                    "i": i, "imax": imax,
                    "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                    "lpad": lpad, "rpad": padding-lpad,
                    "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                    "samples": samples, "sample_rate": sr}
            for k in metakeys:
                subs[k] = s[k][i]
            yield subs
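
# Illustrative note: split_to_chunks is composed into a pipeline via a lambda, as in vad_dataset
# below. Each metakey (the 'spk_emb.npy' name is a hypothetical example) should hold one value
# per VAD segment so it can be indexed with the chunk index:
#
#   lambda x: split_to_chunks(x, ikey='vad.npy', metakeys=['spk_emb.npy'])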
# %% ../nbs/D. Common dataset utilities.ipynb 10
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
    return wds.WebDataset(shards).compose(
        wds.decode(wds.torch_audio),
        merge_in(derived_dataset(kind)),
        wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
        wds.rename(audio="flac;mp3;wav;ogg"),
        lambda x: split_to_chunks(x, ikey=ikey),
    )
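
# Illustrative usage sketch (hypothetical paths):
#
#   ds = vad_dataset(shard_glob('data/librilight-audio-*.tar'))
#   for chunk in ds:
#       print(chunk['__key__'], chunk['samples'].shape, chunk['sample_rate'])
#       break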
# %% ../nbs/D. Common dataset utilities.ipynb 11
@contextmanager
def AtomicTarWriter(name, throwaway=False):
    tmp = name+".tmp"
    with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
        yield sink
    if not throwaway:
        os.rename(tmp, name)
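
# Illustrative usage sketch (hypothetical output name and sample): the shard is written to a
# .tmp file and only renamed to its final name when the `with` block exits without an error:
#
#   with AtomicTarWriter('data/librilight-vad-000000.tar.gz') as sink:
#       sink.write({'__key__': 'sample-000', 'vad.npy': segments})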
# %% ../nbs/D. Common dataset utilities.ipynb 12
def readlines(fname):
    with open(fname) as file:
        return [line.rstrip() for line in file]