# File size: 7,591 Bytes
# 26fd00c
# https://github.com/LAION-AI/CLAP/blob/df65ca0f6c3062dc554132cb40e74f4915084b21/src/training/data.py#L469
from functools import partial
import soundfile as sf
import io
import numpy as np
import torch
import torchaudio
import torchvision
import torch.nn.functional as F
# Default audio front-end configuration; get_mel reads sample_rate,
# window_size, hop_size, mel_bins, fmin and fmax from a dict like this one.
AUDIO_CFG = dict(
    sample_rate=48000,
    audio_length=1024,
    clip_samples=480000,  # 480000 / 48000 = 10 s at the configured sample_rate
    mel_bins=64,          # number of mel filters (n_mels)
    window_size=1024,     # used as both n_fft and win_length in get_mel
    hop_size=480,         # STFT hop length in samples
    fmin=50,              # lowest mel filter edge passed as f_min
    fmax=14000,           # highest mel filter edge passed as f_max
    class_num=527,
)
class dotdict(dict):
    """A ``dict`` whose keys can also be used as attributes.

    Reading an attribute that is not a key yields ``None``; setting an
    attribute stores a key; deleting an attribute deletes the key (raising
    ``KeyError`` if it is absent).
    """

    def __getattr__(self, name):
        # Called only after normal attribute lookup fails.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
class Map(dict):
    """Dictionary whose entries are also readable and writable as attributes.

    Every key assigned through ``__setitem__`` is mirrored into the instance
    ``__dict__``. Reading a missing attribute evaluates to ``None`` instead of
    raising ``AttributeError``.

    Example:
        m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dict.__init__ bypasses our __setitem__, so re-assign every merged
        # entry to mirror it into __dict__. Iterating the merged contents
        # (rather than only dict arguments) also covers key/value-pair input.
        for key, value in list(self.items()):
            self[key] = value

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails; absent keys -> None.
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super().__delitem__(key)
        # Keys added via dict.update()/setdefault() never went through
        # __setitem__, so they may be missing from __dict__.
        self.__dict__.pop(key, None)
def int16_to_float32(x):
    """Divide int16-range samples by 32767 and return them as float32.

    ``x`` must support ``/`` with a float and return an object with
    ``.astype`` (e.g. a NumPy array).
    """
    scaled = x / 32767.0
    return scaled.astype(np.float32)
def float32_to_int16(x):
    """Clip samples to [-1, 1], scale by 32767 and truncate to int16."""
    limited = np.clip(x, -1.0, 1.0)
    scaled = limited * 32767.0
    return scaled.astype(np.int16)
def get_mel(audio_data, audio_cfg):
    """Compute a dB-scaled mel power spectrogram with time on the first axis.

    Args:
        audio_data: waveform tensor; for a 1-D input the mel spectrogram is
            (n_mels, T) before the final transpose.
        audio_cfg: mapping providing 'sample_rate', 'window_size',
            'hop_size', 'mel_bins', 'fmin' and 'fmax'.

    Returns:
        The transposed log-mel spectrogram, (T, n_mels) for 1-D input.
    """
    fft_size = audio_cfg['window_size']
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg['sample_rate'],
        n_fft=fft_size,
        win_length=fft_size,
        hop_length=audio_cfg['hop_size'],
        f_min=audio_cfg['fmin'],
        f_max=audio_cfg['fmax'],
        n_mels=audio_cfg['mel_bins'],
        power=2.0,
        center=True,
        pad_mode="reflect",
        norm=None,
        onesided=True,
    )
    # Power spectrogram (power=2.0) -> decibels, without top_db clamping.
    to_decibels = torchaudio.transforms.AmplitudeToDB(top_db=None)
    power_mel = mel_transform(audio_data)   # (n_mels, T)
    log_mel = to_decibels(power_mel)
    return log_mel.T                        # (T, n_mels)
def _zero_pad(audio_data, max_len):
    """Right-pad a 1-D tensor with zeros so it has exactly ``max_len`` samples."""
    if len(audio_data) == 0:
        # Nothing to extend; build the zero signal directly (same dtype/device).
        return audio_data.new_zeros(max_len)
    return F.pad(
        audio_data,
        (0, max_len - len(audio_data)),
        mode="constant",
        value=0,
    )


def _fill_to_length(audio_data, max_len, data_filling):
    """Extend a waveform shorter than ``max_len`` to ``max_len`` samples.

    "repeatpad" tiles whole copies then zero-pads, "pad" zero-pads, and
    "repeat" tiles and truncates. Empty input is zero-filled in every mode
    because there is nothing to repeat.
    """
    n = len(audio_data)
    if data_filling == "repeatpad":
        if n > 0:
            audio_data = audio_data.repeat(max_len // n)
        return _zero_pad(audio_data, max_len)
    if data_filling == "pad":
        return _zero_pad(audio_data, max_len)
    if data_filling == "repeat":
        if n == 0:
            return _zero_pad(audio_data, max_len)
        return audio_data.repeat(max_len // n + 1)[:max_len]
    raise NotImplementedError(
        f"data_filling {data_filling} not implemented"
    )


def _sample_fusion_mel(audio_data, max_len, audio_cfg):
    """Build the 4-view ``mel_fusion`` stack for audio longer than ``max_len``.

    Returns ``(mel_fusion, longer)``. The views are a random front, middle
    and back chunk of the log-mel, plus the whole log-mel resized to the
    chunk length.
    """
    mel = get_mel(audio_data, audio_cfg)  # (T, n_mels)
    # Frames spanned by max_len samples; the +1 is related to how the
    # (centred) spectrogram in get_mel is framed.
    chunk_frames = max_len // audio_cfg['hop_size'] + 1
    total_frames = mel.shape[0]
    if chunk_frames == total_frames:
        # Corner case: the audio is longer than max_len but shorter than
        # max_len + hop_size, so just use the whole mel for every view.
        return torch.stack([mel, mel, mel, mel], dim=0), torch.tensor([False])
    # Split the valid chunk start offsets into front / middle / back thirds.
    ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
    # If the audio is too short to fill a third, fall back to offset 0.
    ranges = [r if len(r) > 0 else [0] for r in ranges]
    # Randomly choose one start per third (front, middle, back, in that order).
    starts = [np.random.choice(r) for r in ranges]
    chunks = [mel[s:s + chunk_frames, :] for s in starts]
    # Shrink the whole mel to the chunk length. Use the real mel width rather
    # than a hard-coded 64 so the stack also works when mel_bins != 64.
    mel_shrink = torchvision.transforms.Resize(
        size=[chunk_frames, mel.shape[1]]
    )(mel[None])[0]
    return torch.stack([*chunks, mel_shrink], dim=0), torch.tensor([True])


def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg):
    """
    Calculate and add audio features to sample.

    Args:
        sample: dict for the current sample; mutated in place. Gets
            "longer" and "waveform", plus "mel_fusion" when
            data_truncating == "fusion".
        audio_data: 1-D tensor of shape (T) containing audio data.
        max_len: target number of samples for the output waveform.
        data_truncating: "rand_trunc" or "fusion"; how to handle audio
            longer than max_len.
        data_filling: "repeatpad", "pad" or "repeat"; how to handle audio
            shorter than max_len.
        audio_cfg: a dict containing audio configuration. Comes from
            model_cfg['audio_cfg'].

    Returns:
        The same ``sample`` dict.

    Raises:
        NotImplementedError: for an unknown data_truncating when the audio is
            longer than max_len, or an unknown data_filling when it is shorter.
    """
    with torch.no_grad():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                sample["mel_fusion"], longer = _sample_fusion_mel(
                    audio_data, max_len, audio_cfg
                )
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # random crop to max_len (for compatibility)
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx: idx + max_len]
        else:
            if len(audio_data) < max_len:  # do nothing if equal
                audio_data = _fill_to_length(audio_data, max_len, data_filling)
            if data_truncating == 'fusion':
                mel = get_mel(audio_data, audio_cfg)
                sample["mel_fusion"] = torch.stack([mel, mel, mel, mel], dim=0)
            longer = torch.tensor([False])

    sample["longer"] = longer
    sample["waveform"] = audio_data
    return sample