# https://github.com/LAION-AI/CLAP/blob/df65ca0f6c3062dc554132cb40e74f4915084b21/src/training/data.py#L469
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision

AUDIO_CFG = {
    "sample_rate": 48000,    # Hz
    "audio_length": 1024,
    "clip_samples": 480000,  # 10 s at 48 kHz
    "mel_bins": 64,
    "window_size": 1024,     # STFT window length / n_fft
    "hop_size": 480,         # 10 ms hop at 48 kHz
    "fmin": 50,              # Hz
    "fmax": 14000,           # Hz
    "class_num": 527,        # number of AudioSet classes
}

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class Map(dict):
    """
    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():  # dict.iteritems() is Python 2 only
                    self[k] = v
        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super().__delitem__(key)
        del self.__dict__[key]

def int16_to_float32(x):
    """Convert int16 PCM samples to float32 values in [-1, 1]."""
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    """Convert float32 samples in [-1, 1] to int16 PCM (values are clipped first)."""
    x = np.clip(x, a_min=-1.0, a_max=1.0)
    return (x * 32767.0).astype(np.int16)
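
# Illustration (not in the original source): the converters round-trip int16
# PCM through float32 with at most one LSB of quantization error.
def _example_pcm_roundtrip():
    pcm = np.array([-32767, 0, 1000, 32767], dtype=np.int16)
    as_float = int16_to_float32(pcm)   # values in [-1, 1]
    back = float32_to_int16(as_float)  # clipped, then rescaled
    assert np.abs(back.astype(np.int32) - pcm.astype(np.int32)).max() <= 1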

def get_mel(audio_data, audio_cfg):
    # mel shape: (n_mels, T)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg['sample_rate'],
        n_fft=audio_cfg['window_size'],
        win_length=audio_cfg['window_size'],
        hop_length=audio_cfg['hop_size'],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=audio_cfg['mel_bins'],
        f_min=audio_cfg['fmin'],
        f_max=audio_cfg['fmax'],
    )(audio_data)
    # we use the log-mel spectrogram as input
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel.T  # (T, n_mels)
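
# Illustration (not in the original source): with center=True the STFT yields
# T // hop_size + 1 frames, so 10 s of 48 kHz audio gives 1001 x 64 log-mel frames.
def _example_get_mel_shape():
    audio = torch.randn(AUDIO_CFG["clip_samples"])  # 480000 samples = 10 s
    mel = get_mel(audio, AUDIO_CFG)
    assert mel.shape == (AUDIO_CFG["clip_samples"] // AUDIO_CFG["hop_size"] + 1,
                         AUDIO_CFG["mel_bins"])  # (1001, 64)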

def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg):
    """
    Calculate and add audio features to the sample.
    sample: a dict containing all data for the current sample.
    audio_data: a 1-D tensor of shape (T,) containing the waveform.
    max_len: the maximum length of the audio data, in samples.
    data_truncating: the truncation method ("rand_trunc" or "fusion").
    data_filling: the padding method ("repeatpad", "pad", or "repeat").
    audio_cfg: a dict with the audio configuration; comes from model_cfg['audio_cfg'].
    """
    with torch.no_grad():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                # fusion: keep three random local chunks plus a resized global view
                mel = get_mel(audio_data, audio_cfg)
                # split the mel frames into three parts
                chunk_frames = max_len // audio_cfg['hop_size'] + 1  # +1 for the extra frame from center padding in the STFT
                total_frames = mel.shape[0]
                if chunk_frames == total_frames:
                    # corner case: the audio is longer than max_len but shorter
                    # than max_len + hop_size, so the whole mel already fits one chunk
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([False])
                else:
                    ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
                    if len(ranges[1]) == 0:
                        # if the audio is too short, fall back to the first chunk
                        ranges[1] = [0]
                    if len(ranges[2]) == 0:
                        # if the audio is too short, fall back to the first chunk
                        ranges[2] = [0]
                    # randomly choose a start index within each third
                    idx_front = np.random.choice(ranges[0])
                    idx_middle = np.random.choice(ranges[1])
                    idx_back = np.random.choice(ranges[2])
                    # select one mel chunk per part
                    mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                    mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                    mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
                    # shrink the full mel to chunk size as a global view
                    mel_shrink = torchvision.transforms.Resize(
                        size=[chunk_frames, audio_cfg['mel_bins']]
                    )(mel[None])[0]
                    # stack the three local chunks and the global view
                    mel_fusion = torch.stack(
                        [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], dim=0
                    )
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([True])
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # random crop to max_len (for compatibility)
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx:idx + max_len]
        else:  # pad if too short
            if len(audio_data) < max_len:  # do nothing if exactly equal
                if data_filling == "repeatpad":
                    # tile the clip as many whole times as fits, then zero-pad the rest
                    n_repeat = max_len // len(audio_data)
                    audio_data = audio_data.repeat(n_repeat)
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "pad":
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "repeat":
                    # tile the clip and cut it off at exactly max_len
                    n_repeat = max_len // len(audio_data)
                    audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                else:
                    raise NotImplementedError(
                        f"data_filling {data_filling} not implemented"
                    )
            if data_truncating == 'fusion':
                mel = get_mel(audio_data, audio_cfg)
                mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                sample["mel_fusion"] = mel_fusion
            longer = torch.tensor([False])

    sample["longer"] = longer
    sample["waveform"] = audio_data

    return sample
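
# Illustration (not in the original source): a minimal end-to-end call using
# the "fusion" truncation and "repeatpad" filling modes defined above. The
# 30 s input exceeds clip_samples, so mel_fusion holds three random chunks
# plus a globally resized copy, and longer is True.
def _example_get_audio_features():
    long_audio = torch.randn(AUDIO_CFG["sample_rate"] * 30)  # 30 s of noise
    sample = get_audio_features(
        {},
        long_audio,
        max_len=AUDIO_CFG["clip_samples"],
        data_truncating="fusion",
        data_filling="repeatpad",
        audio_cfg=AUDIO_CFG,
    )
    assert sample["mel_fusion"].shape[0] == 4
    assert sample["waveform"].shape == (AUDIO_CFG["clip_samples"],)
    assert bool(sample["longer"])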