unpairedelectron07 commited on
Commit
e418eef
·
verified ·
1 Parent(s): f78362b

Upload 7 files

Browse files
audiocraft/data/audio.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Audio IO methods are defined in this module (info, read, write),
9
+ We rely on av library for faster read when possible, otherwise on torchaudio.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ import logging
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import soundfile
19
+ import torch
20
+ from torch.nn import functional as F
21
+
22
+ import av
23
+ import subprocess as sp
24
+
25
+ from .audio_utils import f32_pcm, normalize_audio
26
+
27
+
28
+ _av_initialized = False
29
+
30
+
31
+ def _init_av():
32
+ global _av_initialized
33
+ if _av_initialized:
34
+ return
35
+ logger = logging.getLogger('libav.mp3')
36
+ logger.setLevel(logging.ERROR)
37
+ _av_initialized = True
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class AudioFileInfo:
42
+ sample_rate: int
43
+ duration: float
44
+ channels: int
45
+
46
+
47
+ def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
48
+ _init_av()
49
+ with av.open(str(filepath)) as af:
50
+ stream = af.streams.audio[0]
51
+ sample_rate = stream.codec_context.sample_rate
52
+ duration = float(stream.duration * stream.time_base)
53
+ channels = stream.channels
54
+ return AudioFileInfo(sample_rate, duration, channels)
55
+
56
+
57
+ def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
58
+ info = soundfile.info(filepath)
59
+ return AudioFileInfo(info.samplerate, info.duration, info.channels)
60
+
61
+
62
+ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
63
+ # torchaudio no longer returns useful duration informations for some formats like mp3s.
64
+ filepath = Path(filepath)
65
+ if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
66
+ # ffmpeg has some weird issue with flac.
67
+ return _soundfile_info(filepath)
68
+ else:
69
+ return _av_info(filepath)
70
+
71
+
72
+ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73
+ """FFMPEG-based audio file reading using PyAV bindings.
74
+ Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
75
+
76
+ Args:
77
+ filepath (str or Path): Path to audio file to read.
78
+ seek_time (float): Time at which to start reading in the file.
79
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
80
+ Returns:
81
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate
82
+ """
83
+ _init_av()
84
+ with av.open(str(filepath)) as af:
85
+ stream = af.streams.audio[0]
86
+ sr = stream.codec_context.sample_rate
87
+ num_frames = int(sr * duration) if duration >= 0 else -1
88
+ frame_offset = int(sr * seek_time)
89
+ # we need a small negative offset otherwise we get some edge artifact
90
+ # from the mp3 decoder.
91
+ af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
92
+ frames = []
93
+ length = 0
94
+ for frame in af.decode(streams=stream.index):
95
+ current_offset = int(frame.rate * frame.pts * frame.time_base)
96
+ strip = max(0, frame_offset - current_offset)
97
+ buf = torch.from_numpy(frame.to_ndarray())
98
+ if buf.shape[0] != stream.channels:
99
+ buf = buf.view(-1, stream.channels).t()
100
+ buf = buf[:, strip:]
101
+ frames.append(buf)
102
+ length += buf.shape[1]
103
+ if num_frames > 0 and length >= num_frames:
104
+ break
105
+ assert frames
106
+ # If the above assert fails, it is likely because we seeked past the end of file point,
107
+ # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
108
+ # This will need proper debugging, in due time.
109
+ wav = torch.cat(frames, dim=1)
110
+ assert wav.shape[0] == stream.channels
111
+ if num_frames > 0:
112
+ wav = wav[:, :num_frames]
113
+ return f32_pcm(wav), sr
114
+
115
+
116
+ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117
+ duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118
+ """Read audio by picking the most appropriate backend tool based on the audio format.
119
+
120
+ Args:
121
+ filepath (str or Path): Path to audio file to read.
122
+ seek_time (float): Time at which to start reading in the file.
123
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
124
+ pad (bool): Pad output audio if not reaching expected duration.
125
+ Returns:
126
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
127
+ """
128
+ fp = Path(filepath)
129
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
130
+ # There is some bug with ffmpeg and reading flac
131
+ info = _soundfile_info(filepath)
132
+ frames = -1 if duration <= 0 else int(duration * info.sample_rate)
133
+ frame_offset = int(seek_time * info.sample_rate)
134
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
135
+ assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
136
+ wav = torch.from_numpy(wav).t().contiguous()
137
+ if len(wav.shape) == 1:
138
+ wav = torch.unsqueeze(wav, 0)
139
+ else:
140
+ wav, sr = _av_read(filepath, seek_time, duration)
141
+ if pad and duration > 0:
142
+ expected_frames = int(duration * sr)
143
+ wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
144
+ return wav, sr
145
+
146
+
147
+ def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
148
+ # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
149
+ assert wav.dim() == 2, wav.shape
150
+ command = [
151
+ 'ffmpeg',
152
+ '-loglevel', 'error',
153
+ '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
154
+ '-i', '-'] + flags + [str(out_path)]
155
+ input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
156
+ sp.run(command, input=input_, check=True)
157
+
158
+
159
+ def audio_write(stem_name: tp.Union[str, Path],
160
+ wav: torch.Tensor, sample_rate: int,
161
+ format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
162
+ normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
163
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
164
+ loudness_compressor: bool = False,
165
+ log_clipping: bool = True, make_parent_dir: bool = True,
166
+ add_suffix: bool = True) -> Path:
167
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
168
+
169
+ Args:
170
+ stem_name (str or Path): Filename without extension which will be added automatically.
171
+ wav (torch.Tensor): Audio data to save.
172
+ sample_rate (int): Sample rate of audio data.
173
+ format (str): Either "wav", "mp3", "ogg", or "flac".
174
+ mp3_rate (int): kbps when using mp3s.
175
+ ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
176
+ normalize (bool): if `True` (default), normalizes according to the prescribed
177
+ strategy (see after). If `False`, the strategy is only used in case clipping
178
+ would happen.
179
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
180
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
181
+ with extra headroom to avoid clipping. 'clip' just clips.
182
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
183
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
184
+ than the `peak_clip` one to avoid further clipping.
185
+ loudness_headroom_db (float): Target loudness for loudness normalization.
186
+ loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
187
+ when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
188
+ occurs despite strategy (only for 'rms').
189
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
190
+ Returns:
191
+ Path: Path of the saved audio.
192
+ """
193
+ assert wav.dtype.is_floating_point, "wav is not floating point"
194
+ if wav.dim() == 1:
195
+ wav = wav[None]
196
+ elif wav.dim() > 2:
197
+ raise ValueError("Input wav should be at most 2 dimension.")
198
+ assert wav.isfinite().all()
199
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
200
+ rms_headroom_db, loudness_headroom_db, loudness_compressor,
201
+ log_clipping=log_clipping, sample_rate=sample_rate,
202
+ stem_name=str(stem_name))
203
+ if format == 'mp3':
204
+ suffix = '.mp3'
205
+ flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
206
+ elif format == 'wav':
207
+ suffix = '.wav'
208
+ flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
209
+ elif format == 'ogg':
210
+ suffix = '.ogg'
211
+ flags = ['-f', 'ogg', '-c:a', 'libvorbis']
212
+ if ogg_rate is not None:
213
+ flags += ['-b:a', f'{ogg_rate}k']
214
+ elif format == 'flac':
215
+ suffix = '.flac'
216
+ flags = ['-f', 'flac']
217
+ else:
218
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
219
+ if not add_suffix:
220
+ suffix = ''
221
+ path = Path(str(stem_name) + suffix)
222
+ if make_parent_dir:
223
+ path.parent.mkdir(exist_ok=True, parents=True)
224
+ try:
225
+ _piping_to_ffmpeg(path, wav, sample_rate, flags)
226
+ except Exception:
227
+ if path.exists():
228
+ # we do not want to leave half written files around.
229
+ path.unlink()
230
+ raise
231
+ return path
audiocraft/data/audio_dataset.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """AudioDataset support. In order to handle a larger number of files
7
+ without having to scan again the folders, we precompute some metadata
8
+ (filename, sample rate, duration), and use that to efficiently sample audio segments.
9
+ """
10
+ import argparse
11
+ import copy
12
+ from concurrent.futures import ThreadPoolExecutor, Future
13
+ from dataclasses import dataclass, fields
14
+ from contextlib import ExitStack
15
+ from functools import lru_cache
16
+ import gzip
17
+ import json
18
+ import logging
19
+ import os
20
+ from pathlib import Path
21
+ import random
22
+ import sys
23
+ import typing as tp
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+
28
+ from .audio import audio_read, audio_info
29
+ from .audio_utils import convert_audio
30
+ from .zip import PathInZip
31
+
32
+ try:
33
+ import dora
34
+ except ImportError:
35
+ dora = None # type: ignore
36
+
37
+
38
+ @dataclass(order=True)
39
+ class BaseInfo:
40
+
41
+ @classmethod
42
+ def _dict2fields(cls, dictionary: dict):
43
+ return {
44
+ field.name: dictionary[field.name]
45
+ for field in fields(cls) if field.name in dictionary
46
+ }
47
+
48
+ @classmethod
49
+ def from_dict(cls, dictionary: dict):
50
+ _dictionary = cls._dict2fields(dictionary)
51
+ return cls(**_dictionary)
52
+
53
+ def to_dict(self):
54
+ return {
55
+ field.name: self.__getattribute__(field.name)
56
+ for field in fields(self)
57
+ }
58
+
59
+
60
+ @dataclass(order=True)
61
+ class AudioMeta(BaseInfo):
62
+ path: str
63
+ duration: float
64
+ sample_rate: int
65
+ amplitude: tp.Optional[float] = None
66
+ weight: tp.Optional[float] = None
67
+ # info_path is used to load additional information about the audio file that is stored in zip files.
68
+ info_path: tp.Optional[PathInZip] = None
69
+
70
+ @classmethod
71
+ def from_dict(cls, dictionary: dict):
72
+ base = cls._dict2fields(dictionary)
73
+ if 'info_path' in base and base['info_path'] is not None:
74
+ base['info_path'] = PathInZip(base['info_path'])
75
+ return cls(**base)
76
+
77
+ def to_dict(self):
78
+ d = super().to_dict()
79
+ if d['info_path'] is not None:
80
+ d['info_path'] = str(d['info_path'])
81
+ return d
82
+
83
+
84
+ @dataclass(order=True)
85
+ class SegmentInfo(BaseInfo):
86
+ meta: AudioMeta
87
+ seek_time: float
88
+ # The following values are given once the audio is processed, e.g.
89
+ # at the target sample rate and target number of channels.
90
+ n_frames: int # actual number of frames without padding
91
+ total_frames: int # total number of frames, padding included
92
+ sample_rate: int # actual sample rate
93
+ channels: int # number of audio channels.
94
+
95
+
96
+ DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
97
+
98
+ logger = logging.getLogger(__name__)
99
+
100
+
101
+ def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
102
+ """AudioMeta from a path to an audio file.
103
+
104
+ Args:
105
+ file_path (str): Resolved path of valid audio file.
106
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
107
+ Returns:
108
+ AudioMeta: Audio file path and its metadata.
109
+ """
110
+ info = audio_info(file_path)
111
+ amplitude: tp.Optional[float] = None
112
+ if not minimal:
113
+ wav, sr = audio_read(file_path)
114
+ amplitude = wav.abs().max().item()
115
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
116
+
117
+
118
+ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
119
+ """If Dora is available as a dependency, try to resolve potential relative paths
120
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
121
+
122
+ Args:
123
+ m (AudioMeta): Audio meta to resolve.
124
+ fast (bool): If True, uses a really fast check for determining if a file
125
+ is already absolute or not. Only valid on Linux/Mac.
126
+ Returns:
127
+ AudioMeta: Audio meta with resolved path.
128
+ """
129
+ def is_abs(m):
130
+ if fast:
131
+ return str(m)[0] == '/'
132
+ else:
133
+ os.path.isabs(str(m))
134
+
135
+ if not dora:
136
+ return m
137
+
138
+ if not is_abs(m.path):
139
+ m.path = dora.git_save.to_absolute_path(m.path)
140
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
141
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
142
+ return m
143
+
144
+
145
+ def find_audio_files(path: tp.Union[Path, str],
146
+ exts: tp.List[str] = DEFAULT_EXTS,
147
+ resolve: bool = True,
148
+ minimal: bool = True,
149
+ progress: bool = False,
150
+ workers: int = 0) -> tp.List[AudioMeta]:
151
+ """Build a list of AudioMeta from a given path,
152
+ collecting relevant audio files and fetching meta info.
153
+
154
+ Args:
155
+ path (str or Path): Path to folder containing audio files.
156
+ exts (list of str): List of file extensions to consider for audio files.
157
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
158
+ progress (bool): Whether to log progress on audio files collection.
159
+ workers (int): number of parallel workers, if 0, use only the current thread.
160
+ Returns:
161
+ list of AudioMeta: List of audio file path and its metadata.
162
+ """
163
+ audio_files = []
164
+ futures: tp.List[Future] = []
165
+ pool: tp.Optional[ThreadPoolExecutor] = None
166
+ with ExitStack() as stack:
167
+ if workers > 0:
168
+ pool = ThreadPoolExecutor(workers)
169
+ stack.enter_context(pool)
170
+
171
+ if progress:
172
+ print("Finding audio files...")
173
+ for root, folders, files in os.walk(path, followlinks=True):
174
+ for file in files:
175
+ full_path = Path(root) / file
176
+ if full_path.suffix.lower() in exts:
177
+ audio_files.append(full_path)
178
+ if pool is not None:
179
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
180
+ if progress:
181
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
182
+
183
+ if progress:
184
+ print("Getting audio metadata...")
185
+ meta: tp.List[AudioMeta] = []
186
+ for idx, file_path in enumerate(audio_files):
187
+ try:
188
+ if pool is None:
189
+ m = _get_audio_meta(str(file_path), minimal)
190
+ else:
191
+ m = futures[idx].result()
192
+ if resolve:
193
+ m = _resolve_audio_meta(m)
194
+ except Exception as err:
195
+ print("Error with", str(file_path), err, file=sys.stderr)
196
+ continue
197
+ meta.append(m)
198
+ if progress:
199
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
200
+ meta.sort()
201
+ return meta
202
+
203
+
204
+ def load_audio_meta(path: tp.Union[str, Path],
205
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
206
+ """Load list of AudioMeta from an optionally compressed json file.
207
+
208
+ Args:
209
+ path (str or Path): Path to JSON file.
210
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
211
+ fast (bool): activates some tricks to make things faster.
212
+ Returns:
213
+ list of AudioMeta: List of audio file path and its total duration.
214
+ """
215
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
216
+ with open_fn(path, 'rb') as fp: # type: ignore
217
+ lines = fp.readlines()
218
+ meta = []
219
+ for line in lines:
220
+ d = json.loads(line)
221
+ m = AudioMeta.from_dict(d)
222
+ if resolve:
223
+ m = _resolve_audio_meta(m, fast=fast)
224
+ meta.append(m)
225
+ return meta
226
+
227
+
228
+ def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
229
+ """Save the audio metadata to the file pointer as json.
230
+
231
+ Args:
232
+ path (str or Path): Path to JSON file.
233
+ metadata (list of BaseAudioMeta): List of audio meta to save.
234
+ """
235
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
236
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
237
+ with open_fn(path, 'wb') as fp: # type: ignore
238
+ for m in meta:
239
+ json_str = json.dumps(m.to_dict()) + '\n'
240
+ json_bytes = json_str.encode('utf-8')
241
+ fp.write(json_bytes)
242
+
243
+
244
+ class AudioDataset:
245
+ """Base audio dataset.
246
+
247
+ The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
248
+ and potentially additional information, by creating random segments from the list of audio
249
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
250
+ mixing of channels, padding, etc.
251
+
252
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
253
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
254
+ duration, applying padding if required.
255
+
256
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
257
+ allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
258
+ original audio meta.
259
+
260
+ Note that you can call `start_epoch(epoch)` in order to get
261
+ a deterministic "randomization" for `shuffle=True`.
262
+ For a given epoch and dataset index, this will always return the same extract.
263
+ You can get back some diversity by setting the `shuffle_seed` param.
264
+
265
+ Args:
266
+ meta (list of AudioMeta): List of audio files metadata.
267
+ segment_duration (float, optional): Optional segment duration of audio to load.
268
+ If not specified, the dataset will load the full audio segment from the file.
269
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
270
+ sample_rate (int): Target sample rate of the loaded audio samples.
271
+ channels (int): Target number of channels of the loaded audio samples.
272
+ sample_on_duration (bool): Set to `True` to sample segments with probability
273
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
274
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
275
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
276
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
277
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
278
+ is shorter than the desired segment.
279
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
280
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
281
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
282
+ audio shorter than this will be filtered out.
283
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
284
+ audio longer than this will be filtered out.
285
+ shuffle_seed (int): can be used to further randomize
286
+ load_wav (bool): if False, skip loading the wav but returns a tensor of 0
287
+ with the expected segment_duration (which must be provided if load_wav is False).
288
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
289
+ are False. Will ensure a permutation on files when going through the dataset.
290
+ In that case the epoch number must be provided in order for the model
291
+ to continue the permutation across epochs. In that case, it is assumed
292
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
293
+ `total_batch_size` the overall batch size accounting for all gpus.
294
+ """
295
+ def __init__(self,
296
+ meta: tp.List[AudioMeta],
297
+ segment_duration: tp.Optional[float] = None,
298
+ shuffle: bool = True,
299
+ num_samples: int = 10_000,
300
+ sample_rate: int = 48_000,
301
+ channels: int = 2,
302
+ pad: bool = True,
303
+ sample_on_duration: bool = True,
304
+ sample_on_weight: bool = True,
305
+ min_segment_ratio: float = 0.5,
306
+ max_read_retry: int = 10,
307
+ return_info: bool = False,
308
+ min_audio_duration: tp.Optional[float] = None,
309
+ max_audio_duration: tp.Optional[float] = None,
310
+ shuffle_seed: int = 0,
311
+ load_wav: bool = True,
312
+ permutation_on_files: bool = False,
313
+ ):
314
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
315
+ assert segment_duration is None or segment_duration > 0
316
+ assert segment_duration is None or min_segment_ratio >= 0
317
+ self.segment_duration = segment_duration
318
+ self.min_segment_ratio = min_segment_ratio
319
+ self.max_audio_duration = max_audio_duration
320
+ self.min_audio_duration = min_audio_duration
321
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
322
+ assert self.min_audio_duration <= self.max_audio_duration
323
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
324
+ assert len(self.meta) # Fail fast if all data has been filtered.
325
+ self.total_duration = sum(d.duration for d in self.meta)
326
+
327
+ if segment_duration is None:
328
+ num_samples = len(self.meta)
329
+ self.num_samples = num_samples
330
+ self.shuffle = shuffle
331
+ self.sample_rate = sample_rate
332
+ self.channels = channels
333
+ self.pad = pad
334
+ self.sample_on_weight = sample_on_weight
335
+ self.sample_on_duration = sample_on_duration
336
+ self.sampling_probabilities = self._get_sampling_probabilities()
337
+ self.max_read_retry = max_read_retry
338
+ self.return_info = return_info
339
+ self.shuffle_seed = shuffle_seed
340
+ self.current_epoch: tp.Optional[int] = None
341
+ self.load_wav = load_wav
342
+ if not load_wav:
343
+ assert segment_duration is not None
344
+ self.permutation_on_files = permutation_on_files
345
+ if permutation_on_files:
346
+ assert not self.sample_on_duration
347
+ assert not self.sample_on_weight
348
+ assert self.shuffle
349
+
350
+ def start_epoch(self, epoch: int):
351
+ self.current_epoch = epoch
352
+
353
+ def __len__(self):
354
+ return self.num_samples
355
+
356
+ def _get_sampling_probabilities(self, normalized: bool = True):
357
+ """Return the sampling probabilities for each file inside `self.meta`."""
358
+ scores: tp.List[float] = []
359
+ for file_meta in self.meta:
360
+ score = 1.
361
+ if self.sample_on_weight and file_meta.weight is not None:
362
+ score *= file_meta.weight
363
+ if self.sample_on_duration:
364
+ score *= file_meta.duration
365
+ scores.append(score)
366
+ probabilities = torch.tensor(scores)
367
+ if normalized:
368
+ probabilities /= probabilities.sum()
369
+ return probabilities
370
+
371
+ @staticmethod
372
+ @lru_cache(16)
373
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
374
+ # Used to keep the most recent files permutation in memory implicitely.
375
+ # will work unless someone is using a lot of Datasets in parallel.
376
+ rng = torch.Generator()
377
+ rng.manual_seed(base_seed + permutation_index)
378
+ return torch.randperm(num_files, generator=rng)
379
+
380
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
381
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
382
+ This is only called if `segment_duration` is not None.
383
+
384
+ You must use the provided random number generator `rng` for reproducibility.
385
+ You can further make use of the index accessed.
386
+ """
387
+ if self.permutation_on_files:
388
+ assert self.current_epoch is not None
389
+ total_index = self.current_epoch * len(self) + index
390
+ permutation_index = total_index // len(self.meta)
391
+ relative_index = total_index % len(self.meta)
392
+ permutation = AudioDataset._get_file_permutation(
393
+ len(self.meta), permutation_index, self.shuffle_seed)
394
+ file_index = permutation[relative_index]
395
+ return self.meta[file_index]
396
+
397
+ if not self.sample_on_weight and not self.sample_on_duration:
398
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
399
+ else:
400
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
401
+
402
+ return self.meta[file_index]
403
+
404
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
405
+ # Override this method in subclass if needed.
406
+ if self.load_wav:
407
+ return audio_read(path, seek_time, duration, pad=False)
408
+ else:
409
+ assert self.segment_duration is not None
410
+ n_frames = int(self.sample_rate * self.segment_duration)
411
+ return torch.zeros(self.channels, n_frames), self.sample_rate
412
+
413
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
414
+ if self.segment_duration is None:
415
+ file_meta = self.meta[index]
416
+ out, sr = audio_read(file_meta.path)
417
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
418
+ n_frames = out.shape[-1]
419
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
420
+ sample_rate=self.sample_rate, channels=out.shape[0])
421
+ else:
422
+ rng = torch.Generator()
423
+ if self.shuffle:
424
+ # We use index, plus extra randomness, either totally random if we don't know the epoch.
425
+ # otherwise we make use of the epoch number and optional shuffle_seed.
426
+ if self.current_epoch is None:
427
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
428
+ else:
429
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
430
+ else:
431
+ # We only use index
432
+ rng.manual_seed(index)
433
+
434
+ for retry in range(self.max_read_retry):
435
+ file_meta = self.sample_file(index, rng)
436
+ # We add some variance in the file position even if audio file is smaller than segment
437
+ # without ending up with empty segments
438
+ max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
439
+ seek_time = torch.rand(1, generator=rng).item() * max_seek
440
+ try:
441
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
442
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
443
+ n_frames = out.shape[-1]
444
+ target_frames = int(self.segment_duration * self.sample_rate)
445
+ if self.pad:
446
+ out = F.pad(out, (0, target_frames - n_frames))
447
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
448
+ sample_rate=self.sample_rate, channels=out.shape[0])
449
+ except Exception as exc:
450
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
451
+ if retry == self.max_read_retry - 1:
452
+ raise
453
+ else:
454
+ break
455
+
456
+ if self.return_info:
457
+ # Returns the wav and additional information on the wave segment
458
+ return out, segment_info
459
+ else:
460
+ return out
461
+
462
+ def collater(self, samples):
463
+ """The collater function has to be provided to the dataloader
464
+ if AudioDataset has return_info=True in order to properly collate
465
+ the samples of a batch.
466
+ """
467
+ if self.segment_duration is None and len(samples) > 1:
468
+ assert self.pad, "Must allow padding when batching examples of different durations."
469
+
470
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
471
+ to_pad = self.segment_duration is None and self.pad
472
+ if to_pad:
473
+ max_len = max([wav.shape[-1] for wav, _ in samples])
474
+
475
+ def _pad_wav(wav):
476
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
477
+
478
+ if self.return_info:
479
+ if len(samples) > 0:
480
+ assert len(samples[0]) == 2
481
+ assert isinstance(samples[0][0], torch.Tensor)
482
+ assert isinstance(samples[0][1], SegmentInfo)
483
+
484
+ wavs = [wav for wav, _ in samples]
485
+ segment_infos = [copy.deepcopy(info) for _, info in samples]
486
+
487
+ if to_pad:
488
+ # Each wav could be of a different duration as they are not segmented.
489
+ for i in range(len(samples)):
490
+ # Determines the total length of the signal with padding, so we update here as we pad.
491
+ segment_infos[i].total_frames = max_len
492
+ wavs[i] = _pad_wav(wavs[i])
493
+
494
+ wav = torch.stack(wavs)
495
+ return wav, segment_infos
496
+ else:
497
+ assert isinstance(samples[0], torch.Tensor)
498
+ if to_pad:
499
+ samples = [_pad_wav(s) for s in samples]
500
+ return torch.stack(samples)
501
+
502
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
503
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
504
+ orig_len = len(meta)
505
+
506
+ # Filter data that is too short.
507
+ if self.min_audio_duration is not None:
508
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
509
+
510
+ # Filter data that is too long.
511
+ if self.max_audio_duration is not None:
512
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
513
+
514
+ filtered_len = len(meta)
515
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
516
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
517
+ if removed_percentage < 10:
518
+ logging.debug(msg)
519
+ else:
520
+ logging.warning(msg)
521
+ return meta
522
+
523
+ @classmethod
524
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
525
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
526
+
527
+ Args:
528
+ root (str or Path): Path to root folder containing audio files.
529
+ kwargs: Additional keyword arguments for the AudioDataset.
530
+ """
531
+ root = Path(root)
532
+ if root.is_dir():
533
+ if (root / 'data.jsonl').exists():
534
+ root = root / 'data.jsonl'
535
+ elif (root / 'data.jsonl.gz').exists():
536
+ root = root / 'data.jsonl.gz'
537
+ else:
538
+ raise ValueError("Don't know where to read metadata from in the dir. "
539
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
540
+ meta = load_audio_meta(root)
541
+ return cls(meta, **kwargs)
542
+
543
+ @classmethod
544
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
545
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
546
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
547
+
548
+ Args:
549
+ root (str or Path): Path to root folder containing audio files.
550
+ minimal_meta (bool): Whether to only load minimal metadata or not.
551
+ exts (list of str): Extensions for audio files.
552
+ kwargs: Additional keyword arguments for the AudioDataset.
553
+ """
554
+ root = Path(root)
555
+ if root.is_file():
556
+ meta = load_audio_meta(root, resolve=True)
557
+ else:
558
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
559
+ return cls(meta, **kwargs)
560
+
561
+
562
+ def main():
563
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
564
+ parser = argparse.ArgumentParser(
565
+ prog='audio_dataset',
566
+ description='Generate .jsonl files by scanning a folder.')
567
+ parser.add_argument('root', help='Root folder with all the audio files')
568
+ parser.add_argument('output_meta_file',
569
+ help='Output file to store the metadata, ')
570
+ parser.add_argument('--complete',
571
+ action='store_false', dest='minimal', default=True,
572
+ help='Retrieve all metadata, even the one that are expansive '
573
+ 'to compute (e.g. normalization).')
574
+ parser.add_argument('--resolve',
575
+ action='store_true', default=False,
576
+ help='Resolve the paths to be absolute and with no symlinks.')
577
+ parser.add_argument('--workers',
578
+ default=10, type=int,
579
+ help='Number of workers.')
580
+ args = parser.parse_args()
581
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
582
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
583
+ save_audio_meta(args.output_meta_file, meta)
584
+
585
+
586
+ if __name__ == '__main__':
587
+ main()
audiocraft/data/audio_utils.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Various utilities for audio convertion (pcm format, sample rate and channels),
7
+ and volume normalization."""
8
+ import sys
9
+ import typing as tp
10
+
11
+ import julius
12
+ import torch
13
+ import torchaudio
14
+
15
+
16
+ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
17
+ """Convert audio to the given number of channels.
18
+
19
+ Args:
20
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
21
+ channels (int): Expected number of channels as output.
22
+ Returns:
23
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
24
+ """
25
+ *shape, src_channels, length = wav.shape
26
+ if src_channels == channels:
27
+ pass
28
+ elif channels == 1:
29
+ # Case 1:
30
+ # The caller asked 1-channel audio, and the stream has multiple
31
+ # channels, downmix all channels.
32
+ wav = wav.mean(dim=-2, keepdim=True)
33
+ elif src_channels == 1:
34
+ # Case 2:
35
+ # The caller asked for multiple channels, but the input file has
36
+ # a single channel, replicate the audio over all channels.
37
+ wav = wav.expand(*shape, channels, length)
38
+ elif src_channels >= channels:
39
+ # Case 3:
40
+ # The caller asked for multiple channels, and the input file has
41
+ # more channels than requested. In that case return the first channels.
42
+ wav = wav[..., :channels, :]
43
+ else:
44
+ # Case 4: What is a reasonable choice here?
45
+ raise ValueError('The audio file has less channels than requested but is not mono.')
46
+ return wav
47
+
48
+
49
+ def convert_audio(wav: torch.Tensor, from_rate: float,
50
+ to_rate: float, to_channels: int) -> torch.Tensor:
51
+ """Convert audio to new sample rate and number of audio channels."""
52
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
+ wav = convert_audio_channels(wav, to_channels)
54
+ return wav
55
+
56
+
57
+ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
58
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
59
+ """Normalize an input signal to a user loudness in dB LKFS.
60
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
61
+
62
+ Args:
63
+ wav (torch.Tensor): Input multichannel audio data.
64
+ sample_rate (int): Sample rate.
65
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
66
+ loudness_compressor (bool): Uses tanh for soft clipping.
67
+ energy_floor (float): anything below that RMS level will not be rescaled.
68
+ Returns:
69
+ torch.Tensor: Loudness normalized output data.
70
+ """
71
+ energy = wav.pow(2).mean().sqrt().item()
72
+ if energy < energy_floor:
73
+ return wav
74
+ transform = torchaudio.transforms.Loudness(sample_rate)
75
+ input_loudness_db = transform(wav).item()
76
+ # calculate the gain needed to scale to the desired loudness level
77
+ delta_loudness = -loudness_headroom_db - input_loudness_db
78
+ gain = 10.0 ** (delta_loudness / 20.0)
79
+ output = gain * wav
80
+ if loudness_compressor:
81
+ output = torch.tanh(output)
82
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
83
+ return output
84
+
85
+
86
+ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
87
+ """Utility function to clip the audio with logging if specified."""
88
+ max_scale = wav.abs().max()
89
+ if log_clipping and max_scale > 1:
90
+ clamp_prob = (wav.abs() > 1).float().mean().item()
91
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
92
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
93
+ wav.clamp_(-1, 1)
94
+
95
+
96
+ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
97
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
98
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
99
+ loudness_compressor: bool = False, log_clipping: bool = False,
100
+ sample_rate: tp.Optional[int] = None,
101
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
102
+ """Normalize the audio according to the prescribed strategy (see after).
103
+
104
+ Args:
105
+ wav (torch.Tensor): Audio data.
106
+ normalize (bool): if `True` (default), normalizes according to the prescribed
107
+ strategy (see after). If `False`, the strategy is only used in case clipping
108
+ would happen.
109
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
110
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
111
+ with extra headroom to avoid clipping. 'clip' just clips.
112
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
113
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
114
+ than the `peak_clip` one to avoid further clipping.
115
+ loudness_headroom_db (float): Target loudness for loudness normalization.
116
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
117
+ log_clipping (bool): If True, basic logging on stderr when clipping still
118
+ occurs despite strategy (only for 'rms').
119
+ sample_rate (int): Sample rate for the audio data (required for loudness).
120
+ stem_name (str, optional): Stem name for clipping logging.
121
+ Returns:
122
+ torch.Tensor: Normalized audio.
123
+ """
124
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
125
+ scale_rms = 10 ** (-rms_headroom_db / 20)
126
+ if strategy == 'peak':
127
+ rescaling = (scale_peak / wav.abs().max())
128
+ if normalize or rescaling < 1:
129
+ wav = wav * rescaling
130
+ elif strategy == 'clip':
131
+ wav = wav.clamp(-scale_peak, scale_peak)
132
+ elif strategy == 'rms':
133
+ mono = wav.mean(dim=0)
134
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
135
+ if normalize or rescaling < 1:
136
+ wav = wav * rescaling
137
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
138
+ elif strategy == 'loudness':
139
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
140
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
141
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
142
+ else:
143
+ assert wav.abs().max() < 1
144
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
145
+ return wav
146
+
147
+
148
+ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
149
+ """Convert audio to float 32 bits PCM format.
150
+ """
151
+ if wav.dtype.is_floating_point:
152
+ return wav
153
+ elif wav.dtype == torch.int16:
154
+ return wav.float() / 2**15
155
+ elif wav.dtype == torch.int32:
156
+ return wav.float() / 2**31
157
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
158
+
159
+
160
+ def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
161
+ """Convert audio to int 16 bits PCM format.
162
+
163
+ ..Warning:: There exist many formula for doing this conversion. None are perfect
164
+ due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
165
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
166
+ it is possible that `i16_pcm(f32_pcm)) != Identity`.
167
+ """
168
+ if wav.dtype.is_floating_point:
169
+ assert wav.abs().max() <= 1
170
+ candidate = (wav * 2 ** 15).round()
171
+ if candidate.max() >= 2 ** 15: # clipping would occur
172
+ candidate = (wav * (2 ** 15 - 1)).round()
173
+ return candidate.short()
174
+ else:
175
+ assert wav.dtype == torch.int16
176
+ return wav
audiocraft/data/info_audio_dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Base classes for the datasets that also provide non-audio metadata,
7
+ e.g. description, text transcription etc.
8
+ """
9
+ from dataclasses import dataclass
10
+ import logging
11
+ import math
12
+ import re
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .audio_dataset import AudioDataset, AudioMeta
18
+ from ..environment import AudioCraftEnvironment
19
+ from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
+ """Monkey-patch meta to match cluster specificities."""
27
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
+ if meta.info_path is not None:
29
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
+ return meta
31
+
32
+
33
+ def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
+ """Monkey-patch all meta to match cluster specificities."""
35
+ return [_clusterify_meta(m) for m in meta]
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo(SegmentWithAttributes):
40
+ """Dummy SegmentInfo with empty attributes.
41
+
42
+ The InfoAudioDataset is expected to return metadata that inherits
43
+ from SegmentWithAttributes class and can return conditioning attributes.
44
+
45
+ This basically guarantees all datasets will be compatible with current
46
+ solver that contain conditioners requiring this.
47
+ """
48
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
+
50
+ def to_condition_attributes(self) -> ConditioningAttributes:
51
+ return ConditioningAttributes()
52
+
53
+
54
+ class InfoAudioDataset(AudioDataset):
55
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
+
57
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
+ """
59
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
+ super().__init__(clusterify_all_meta(meta), **kwargs)
61
+
62
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
+ if not self.return_info:
64
+ wav = super().__getitem__(index)
65
+ assert isinstance(wav, torch.Tensor)
66
+ return wav
67
+ wav, meta = super().__getitem__(index)
68
+ return wav, AudioInfo(**meta.to_dict())
69
+
70
+
71
+ def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72
+ """Preprocess a single keyword or possible a list of keywords."""
73
+ if isinstance(value, list):
74
+ return get_keyword_list(value)
75
+ else:
76
+ return get_keyword(value)
77
+
78
+
79
+ def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80
+ """Preprocess a single keyword."""
81
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82
+ return None
83
+ else:
84
+ return value.strip()
85
+
86
+
87
+ def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88
+ """Preprocess a single keyword."""
89
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90
+ return None
91
+ else:
92
+ return value.strip().lower()
93
+
94
+
95
+ def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96
+ """Preprocess a list of keywords."""
97
+ if isinstance(values, str):
98
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
99
+ elif isinstance(values, float) and math.isnan(values):
100
+ values = []
101
+ if not isinstance(values, list):
102
+ logger.debug(f"Unexpected keyword list {values}")
103
+ values = [str(values)]
104
+
105
+ kws = [get_keyword(v) for v in values]
106
+ kw_list = [k for k in kws if k is not None]
107
+ if len(kw_list) == 0:
108
+ return None
109
+ else:
110
+ return kw_list
audiocraft/data/music_dataset.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of music tracks with rich metadata.
7
+ """
8
+ from dataclasses import dataclass, field, fields, replace
9
+ import gzip
10
+ import json
11
+ import logging
12
+ from pathlib import Path
13
+ import random
14
+ import typing as tp
15
+
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ AudioInfo,
21
+ get_keyword_list,
22
+ get_keyword,
23
+ get_string
24
+ )
25
+ from ..modules.conditioners import (
26
+ ConditioningAttributes,
27
+ JointEmbedCondition,
28
+ WavCondition,
29
+ )
30
+ from ..utils.utils import warn_once
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class MusicInfo(AudioInfo):
38
+ """Segment info augmented with music metadata.
39
+ """
40
+ # music-specific metadata
41
+ title: tp.Optional[str] = None
42
+ artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
43
+ key: tp.Optional[str] = None
44
+ bpm: tp.Optional[float] = None
45
+ genre: tp.Optional[str] = None
46
+ moods: tp.Optional[list] = None
47
+ keywords: tp.Optional[list] = None
48
+ description: tp.Optional[str] = None
49
+ name: tp.Optional[str] = None
50
+ instrument: tp.Optional[str] = None
51
+ # original wav accompanying the metadata
52
+ self_wav: tp.Optional[WavCondition] = None
53
+ # dict mapping attributes names to tuple of wav, text and metadata
54
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
55
+
56
+ @property
57
+ def has_music_meta(self) -> bool:
58
+ return self.name is not None
59
+
60
+ def to_condition_attributes(self) -> ConditioningAttributes:
61
+ out = ConditioningAttributes()
62
+ for _field in fields(self):
63
+ key, value = _field.name, getattr(self, _field.name)
64
+ if key == 'self_wav':
65
+ out.wav[key] = value
66
+ elif key == 'joint_embed':
67
+ for embed_attribute, embed_cond in value.items():
68
+ out.joint_embed[embed_attribute] = embed_cond
69
+ else:
70
+ if isinstance(value, list):
71
+ value = ' '.join(value)
72
+ out.text[key] = value
73
+ return out
74
+
75
+ @staticmethod
76
+ def attribute_getter(attribute):
77
+ if attribute == 'bpm':
78
+ preprocess_func = get_bpm
79
+ elif attribute == 'key':
80
+ preprocess_func = get_musical_key
81
+ elif attribute in ['moods', 'keywords']:
82
+ preprocess_func = get_keyword_list
83
+ elif attribute in ['genre', 'name', 'instrument']:
84
+ preprocess_func = get_keyword
85
+ elif attribute in ['title', 'artist', 'description']:
86
+ preprocess_func = get_string
87
+ else:
88
+ preprocess_func = None
89
+ return preprocess_func
90
+
91
+ @classmethod
92
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
93
+ _dictionary: tp.Dict[str, tp.Any] = {}
94
+
95
+ # allow a subset of attributes to not be loaded from the dictionary
96
+ # these attributes may be populated later
97
+ post_init_attributes = ['self_wav', 'joint_embed']
98
+ optional_fields = ['keywords']
99
+
100
+ for _field in fields(cls):
101
+ if _field.name in post_init_attributes:
102
+ continue
103
+ elif _field.name not in dictionary:
104
+ if fields_required and _field.name not in optional_fields:
105
+ raise KeyError(f"Unexpected missing key: {_field.name}")
106
+ else:
107
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
108
+ value = dictionary[_field.name]
109
+ if preprocess_func:
110
+ value = preprocess_func(value)
111
+ _dictionary[_field.name] = value
112
+ return cls(**_dictionary)
113
+
114
+
115
+ def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
116
+ drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
117
+ """Augment MusicInfo description with additional metadata fields and potential dropout.
118
+ Additional textual attributes are added given probability 'merge_text_conditions_p' and
119
+ the original textual description is dropped from the augmented description given probability drop_desc_p.
120
+
121
+ Args:
122
+ music_info (MusicInfo): The music metadata to augment.
123
+ merge_text_p (float): Probability of merging additional metadata to the description.
124
+ If provided value is 0, then no merging is performed.
125
+ drop_desc_p (float): Probability of dropping the original description on text merge.
126
+ if provided value is 0, then no drop out is performed.
127
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
128
+ Returns:
129
+ MusicInfo: The MusicInfo with augmented textual description.
130
+ """
131
+ def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
132
+ valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
133
+ valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
134
+ keep_field = random.uniform(0, 1) < drop_other_p
135
+ return valid_field_name and valid_field_value and keep_field
136
+
137
+ def process_value(v: tp.Any) -> str:
138
+ if isinstance(v, (int, float, str)):
139
+ return str(v)
140
+ if isinstance(v, list):
141
+ return ", ".join(v)
142
+ else:
143
+ raise ValueError(f"Unknown type for text value! ({type(v), v})")
144
+
145
+ description = music_info.description
146
+
147
+ metadata_text = ""
148
+ if random.uniform(0, 1) < merge_text_p:
149
+ meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
150
+ for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
151
+ random.shuffle(meta_pairs)
152
+ metadata_text = ". ".join(meta_pairs)
153
+ description = description if not random.uniform(0, 1) < drop_desc_p else None
154
+ logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
155
+
156
+ if description is None:
157
+ description = metadata_text if len(metadata_text) > 1 else None
158
+ else:
159
+ description = ". ".join([description.rstrip('.'), metadata_text])
160
+ description = description.strip() if description else None
161
+
162
+ music_info = replace(music_info)
163
+ music_info.description = description
164
+ return music_info
165
+
166
+
167
+ class Paraphraser:
168
+ def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
169
+ self.paraphrase_p = paraphrase_p
170
+ open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
171
+ with open_fn(paraphrase_source, 'rb') as f: # type: ignore
172
+ self.paraphrase_source = json.loads(f.read())
173
+ logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
174
+
175
+ def sample_paraphrase(self, audio_path: str, description: str):
176
+ if random.random() >= self.paraphrase_p:
177
+ return description
178
+ info_path = Path(audio_path).with_suffix('.json')
179
+ if info_path not in self.paraphrase_source:
180
+ warn_once(logger, f"{info_path} not in paraphrase source!")
181
+ return description
182
+ new_desc = random.choice(self.paraphrase_source[info_path])
183
+ logger.debug(f"{description} -> {new_desc}")
184
+ return new_desc
185
+
186
+
187
+ class MusicDataset(InfoAudioDataset):
188
+ """Music dataset is an AudioDataset with music-related metadata.
189
+
190
+ Args:
191
+ info_fields_required (bool): Whether to enforce having required fields.
192
+ merge_text_p (float): Probability of merging additional metadata to the description.
193
+ drop_desc_p (float): Probability of dropping the original description on text merge.
194
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
195
+ joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
196
+ paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
197
+ paraphrases for the description. The json should be a dict with keys are the
198
+ original info path (e.g. track_path.json) and each value is a list of possible
199
+ paraphrased.
200
+ paraphrase_p (float): probability of taking a paraphrase.
201
+
202
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
203
+ """
204
+ def __init__(self, *args, info_fields_required: bool = True,
205
+ merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
206
+ joint_embed_attributes: tp.List[str] = [],
207
+ paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
208
+ **kwargs):
209
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
210
+ super().__init__(*args, **kwargs)
211
+ self.info_fields_required = info_fields_required
212
+ self.merge_text_p = merge_text_p
213
+ self.drop_desc_p = drop_desc_p
214
+ self.drop_other_p = drop_other_p
215
+ self.joint_embed_attributes = joint_embed_attributes
216
+ self.paraphraser = None
217
+ if paraphrase_source is not None:
218
+ self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
219
+
220
+ def __getitem__(self, index):
221
+ wav, info = super().__getitem__(index)
222
+ info_data = info.to_dict()
223
+ music_info_path = Path(info.meta.path).with_suffix('.json')
224
+
225
+ if Path(music_info_path).exists():
226
+ with open(music_info_path, 'r') as json_file:
227
+ music_data = json.load(json_file)
228
+ music_data.update(info_data)
229
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
230
+ if self.paraphraser is not None:
231
+ music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
232
+ if self.merge_text_p:
233
+ music_info = augment_music_info_description(
234
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
235
+ else:
236
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
237
+
238
+ music_info.self_wav = WavCondition(
239
+ wav=wav[None], length=torch.tensor([info.n_frames]),
240
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
241
+
242
+ for att in self.joint_embed_attributes:
243
+ att_value = getattr(music_info, att)
244
+ joint_embed_cond = JointEmbedCondition(
245
+ wav[None], [att_value], torch.tensor([info.n_frames]),
246
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
247
+ music_info.joint_embed[att] = joint_embed_cond
248
+
249
+ return wav, music_info
250
+
251
+
252
+ def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
253
+ """Preprocess key keywords, discarding them if there are multiple key defined."""
254
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
255
+ return None
256
+ elif ',' in value:
257
+ # For now, we discard when multiple keys are defined separated with comas
258
+ return None
259
+ else:
260
+ return value.strip().lower()
261
+
262
+
263
+ def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
264
+ """Preprocess to a float."""
265
+ if value is None:
266
+ return None
267
+ try:
268
+ return float(value)
269
+ except ValueError:
270
+ return None
audiocraft/data/sound_dataset.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of audio with a simple description.
7
+ """
8
+
9
+ from dataclasses import dataclass, fields, replace
10
+ import json
11
+ from pathlib import Path
12
+ import random
13
+ import typing as tp
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ get_keyword_or_keyword_list
21
+ )
22
+ from ..modules.conditioners import (
23
+ ConditioningAttributes,
24
+ SegmentWithAttributes,
25
+ WavCondition,
26
+ )
27
+
28
+
29
+ EPS = torch.finfo(torch.float32).eps
30
+ TARGET_LEVEL_LOWER = -35
31
+ TARGET_LEVEL_UPPER = -15
32
+
33
+
34
+ @dataclass
35
+ class SoundInfo(SegmentWithAttributes):
36
+ """Segment info augmented with Sound metadata.
37
+ """
38
+ description: tp.Optional[str] = None
39
+ self_wav: tp.Optional[torch.Tensor] = None
40
+
41
+ @property
42
+ def has_sound_meta(self) -> bool:
43
+ return self.description is not None
44
+
45
+ def to_condition_attributes(self) -> ConditioningAttributes:
46
+ out = ConditioningAttributes()
47
+
48
+ for _field in fields(self):
49
+ key, value = _field.name, getattr(self, _field.name)
50
+ if key == 'self_wav':
51
+ out.wav[key] = value
52
+ else:
53
+ out.text[key] = value
54
+ return out
55
+
56
+ @staticmethod
57
+ def attribute_getter(attribute):
58
+ if attribute == 'description':
59
+ preprocess_func = get_keyword_or_keyword_list
60
+ else:
61
+ preprocess_func = None
62
+ return preprocess_func
63
+
64
+ @classmethod
65
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
66
+ _dictionary: tp.Dict[str, tp.Any] = {}
67
+
68
+ # allow a subset of attributes to not be loaded from the dictionary
69
+ # these attributes may be populated later
70
+ post_init_attributes = ['self_wav']
71
+
72
+ for _field in fields(cls):
73
+ if _field.name in post_init_attributes:
74
+ continue
75
+ elif _field.name not in dictionary:
76
+ if fields_required:
77
+ raise KeyError(f"Unexpected missing key: {_field.name}")
78
+ else:
79
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
80
+ value = dictionary[_field.name]
81
+ if preprocess_func:
82
+ value = preprocess_func(value)
83
+ _dictionary[_field.name] = value
84
+ return cls(**_dictionary)
85
+
86
+
87
+ class SoundDataset(InfoAudioDataset):
88
+ """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
89
+
90
+ Args:
91
+ info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
92
+ external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
93
+ The metadata files contained in this folder are expected to match the stem of the audio file with
94
+ a json extension.
95
+ aug_p (float): Probability of performing audio mixing augmentation on the batch.
96
+ mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
97
+ mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
98
+ mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
99
+ mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
100
+ kwargs: Additional arguments for AudioDataset.
101
+
102
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
103
+ """
104
+ def __init__(
105
+ self,
106
+ *args,
107
+ info_fields_required: bool = True,
108
+ external_metadata_source: tp.Optional[str] = None,
109
+ aug_p: float = 0.,
110
+ mix_p: float = 0.,
111
+ mix_snr_low: int = -5,
112
+ mix_snr_high: int = 5,
113
+ mix_min_overlap: float = 0.5,
114
+ **kwargs
115
+ ):
116
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
117
+ super().__init__(*args, **kwargs)
118
+ self.info_fields_required = info_fields_required
119
+ self.external_metadata_source = external_metadata_source
120
+ self.aug_p = aug_p
121
+ self.mix_p = mix_p
122
+ if self.aug_p > 0:
123
+ assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
124
+ assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
125
+ self.mix_snr_low = mix_snr_low
126
+ self.mix_snr_high = mix_snr_high
127
+ self.mix_min_overlap = mix_min_overlap
128
+
129
+ def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
130
+ """Get path of JSON with metadata (description, etc.).
131
+ If there exists a JSON with the same name as 'path.name', then it will be used.
132
+ Else, such JSON will be searched for in an external json source folder if it exists.
133
+ """
134
+ info_path = Path(path).with_suffix('.json')
135
+ if Path(info_path).exists():
136
+ return info_path
137
+ elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
138
+ return Path(self.external_metadata_source) / info_path.name
139
+ else:
140
+ raise Exception(f"Unable to find a metadata JSON for path: {path}")
141
+
142
+ def __getitem__(self, index):
143
+ wav, info = super().__getitem__(index)
144
+ info_data = info.to_dict()
145
+ info_path = self._get_info_path(info.meta.path)
146
+ if Path(info_path).exists():
147
+ with open(info_path, 'r') as json_file:
148
+ sound_data = json.load(json_file)
149
+ sound_data.update(info_data)
150
+ sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
151
+ # if there are multiple descriptions, sample one randomly
152
+ if isinstance(sound_info.description, list):
153
+ sound_info.description = random.choice(sound_info.description)
154
+ else:
155
+ sound_info = SoundInfo.from_dict(info_data, fields_required=False)
156
+
157
+ sound_info.self_wav = WavCondition(
158
+ wav=wav[None], length=torch.tensor([info.n_frames]),
159
+ sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
160
+
161
+ return wav, sound_info
162
+
163
+ def collater(self, samples):
164
+ # when training, audio mixing is performed in the collate function
165
+ wav, sound_info = super().collater(samples) # SoundDataset always returns infos
166
+ if self.aug_p > 0:
167
+ wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
168
+ snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
169
+ min_overlap=self.mix_min_overlap)
170
+ return wav, sound_info
171
+
172
+
173
+ def rms_f(x: torch.Tensor) -> torch.Tensor:
174
+ return (x ** 2).mean(1).pow(0.5)
175
+
176
+
177
+ def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
178
+ """Normalize the signal to the target level."""
179
+ rms = rms_f(audio)
180
+ scalar = 10 ** (target_level / 20) / (rms + EPS)
181
+ audio = audio * scalar.unsqueeze(1)
182
+ return audio
183
+
184
+
185
+ def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
186
+ return (abs(audio) > clipping_threshold).any(1)
187
+
188
+
189
+ def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
190
+ start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
191
+ remainder = src.shape[1] - start
192
+ if dst.shape[1] > remainder:
193
+ src[:, start:] = src[:, start:] + dst[:, :remainder]
194
+ else:
195
+ src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
196
+ return src
197
+
198
+
199
+ def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
200
+ target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
201
+ """Function to mix clean speech and noise at various SNR levels.
202
+
203
+ Args:
204
+ clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
205
+ noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
206
+ snr (int): SNR level when mixing.
207
+ min_overlap (float): Minimum overlap between the two mixed sources.
208
+ target_level (int): Gain level in dB.
209
+ clipping_threshold (float): Threshold for clipping the audio.
210
+ Returns:
211
+ torch.Tensor: The mixed audio, of shape [B, T].
212
+ """
213
+ if clean.shape[1] > noise.shape[1]:
214
+ noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
215
+ else:
216
+ noise = noise[:, :clean.shape[1]]
217
+
218
+ # normalizing to -25 dB FS
219
+ clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
220
+ clean = normalize(clean, target_level)
221
+ rmsclean = rms_f(clean)
222
+
223
+ noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
224
+ noise = normalize(noise, target_level)
225
+ rmsnoise = rms_f(noise)
226
+
227
+ # set the noise level for a given SNR
228
+ noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
229
+ noisenewlevel = noise * noisescalar
230
+
231
+ # mix noise and clean speech
232
+ noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
233
+
234
+ # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
235
+ # there is a chance of clipping that might happen with very less probability, which is not a major issue.
236
+ noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
237
+ rmsnoisy = rms_f(noisyspeech)
238
+ scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
239
+ noisyspeech = noisyspeech * scalarnoisy
240
+ clean = clean * scalarnoisy
241
+ noisenewlevel = noisenewlevel * scalarnoisy
242
+
243
+ # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
244
+ clipped = is_clipped(noisyspeech)
245
+ if clipped.any():
246
+ noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
247
+ noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
248
+
249
+ return noisyspeech
250
+
251
+
252
+ def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
253
+ if snr_low == snr_high:
254
+ snr = snr_low
255
+ else:
256
+ snr = np.random.randint(snr_low, snr_high)
257
+ mix = snr_mixer(src, dst, snr, min_overlap)
258
+ return mix
259
+
260
+
261
+ def mix_text(src_text: str, dst_text: str):
262
+ """Mix text from different sources by concatenating them."""
263
+ if src_text == dst_text:
264
+ return src_text
265
+ return src_text + " " + dst_text
266
+
267
+
268
+ def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
269
+ snr_low: int, snr_high: int, min_overlap: float):
270
+ """Mix samples within a batch, summing the waveforms and concatenating the text infos.
271
+
272
+ Args:
273
+ wavs (torch.Tensor): Audio tensors of shape [B, C, T].
274
+ infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
275
+ aug_p (float): Augmentation probability.
276
+ mix_p (float): Proportion of items in the batch to mix (and merge) together.
277
+ snr_low (int): Lowerbound for sampling SNR.
278
+ snr_high (int): Upperbound for sampling SNR.
279
+ min_overlap (float): Minimum overlap between mixed samples.
280
+ Returns:
281
+ tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
282
+ and mixed SoundInfo for the given batch.
283
+ """
284
+ # no mixing to perform within the batch
285
+ if mix_p == 0:
286
+ return wavs, infos
287
+
288
+ if random.uniform(0, 1) < aug_p:
289
+ # perform all augmentations on waveforms as [B, T]
290
+ # randomly picking pairs of audio to mix
291
+ assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
292
+ wavs = wavs.mean(dim=1, keepdim=False)
293
+ B, T = wavs.shape
294
+ k = int(mix_p * B)
295
+ mixed_sources_idx = torch.randperm(B)[:k]
296
+ mixed_targets_idx = torch.randperm(B)[:k]
297
+ aug_wavs = snr_mix(
298
+ wavs[mixed_sources_idx],
299
+ wavs[mixed_targets_idx],
300
+ snr_low,
301
+ snr_high,
302
+ min_overlap,
303
+ )
304
+ # mixing textual descriptions in metadata
305
+ descriptions = [info.description for info in infos]
306
+ aug_infos = []
307
+ for i, j in zip(mixed_sources_idx, mixed_targets_idx):
308
+ text = mix_text(descriptions[i], descriptions[j])
309
+ m = replace(infos[i])
310
+ m.description = text
311
+ aug_infos.append(m)
312
+
313
+ # back to [B, C, T]
314
+ aug_wavs = aug_wavs.unsqueeze(1)
315
+ assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
316
+ assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
317
+ assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
318
+
319
+ return aug_wavs, aug_infos # [B, C, T]
320
+ else:
321
+ # randomly pick samples in the batch to match
322
+ # the batch size when performing audio mixing
323
+ B, C, T = wavs.shape
324
+ k = int(mix_p * B)
325
+ wav_idx = torch.randperm(B)[:k]
326
+ wavs = wavs[wav_idx]
327
+ infos = [infos[i] for i in wav_idx]
328
+ assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
329
+
330
+ return wavs, infos # [B, C, T]
audiocraft/data/zip.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Utility for reading some info from inside a zip file.
7
+ """
8
+
9
+ import typing
10
+ import zipfile
11
+
12
+ from dataclasses import dataclass
13
+ from functools import lru_cache
14
+ from typing_extensions import Literal
15
+
16
+
17
+ DEFAULT_SIZE = 32
18
+ MODE = Literal['r', 'w', 'x', 'a']
19
+
20
+
21
+ @dataclass(order=True)
22
+ class PathInZip:
23
+ """Hold a path of file within a zip file.
24
+
25
+ Args:
26
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
+ Let's assume there is a zip file /some/location/foo.zip
28
+ and inside of it is a json file located at /data/file1.json,
29
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
+ """
31
+
32
+ INFO_PATH_SEP = ':'
33
+ zip_path: str
34
+ file_path: str
35
+
36
+ def __init__(self, path: str) -> None:
37
+ split_path = path.split(self.INFO_PATH_SEP)
38
+ assert len(split_path) == 2
39
+ self.zip_path, self.file_path = split_path
40
+
41
+ @classmethod
42
+ def from_paths(cls, zip_path: str, file_path: str):
43
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44
+
45
+ def __str__(self) -> str:
46
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
47
+
48
+
49
+ def _open_zip(path: str, mode: MODE = 'r'):
50
+ return zipfile.ZipFile(path, mode)
51
+
52
+
53
+ _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54
+
55
+
56
+ def set_zip_cache_size(max_size: int):
57
+ """Sets the maximal LRU caching for zip file opening.
58
+
59
+ Args:
60
+ max_size (int): the maximal LRU cache.
61
+ """
62
+ global _cached_open_zip
63
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
64
+
65
+
66
+ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67
+ """Opens a file stored inside a zip and returns a file-like object.
68
+
69
+ Args:
70
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
+ mode (str): The mode in which to open the file with.
72
+ Returns:
73
+ A file-like object for PathInZip.
74
+ """
75
+ zf = _cached_open_zip(path_in_zip.zip_path)
76
+ return zf.open(path_in_zip.file_path)